深入浅出线性判别分析(LDA),从理论到代码实现
?作者|善財童子
學校|西北工業大學
研究方向|機器學習/射頻微波
在知乎看到一篇講解線性判別分析(LDA,Linear Discriminant Analysis)的文章,感覺數學概念講得不是很清楚,而且沒有代碼實現。所以童子在參考相關文章的基礎上在這里做一個學習總結,與大家共勉,歡迎各位批評指正~~
注意:在不加說明的情況下,所有公式的向量均是列向量,這個也會反映到代碼中。
本文的基本思路來自以下文章:
https://www.adeveloperdiary.com/data-science/machine-learning/linear-discriminant-analysis-from-theory-to-code/
基本概念和目標
線性判別分析是一種很重要的分類算法,同時也是一種降維方法(這個我還沒想懂)。和 PCA 一樣,LDA 也是通過投影的方式達到去除數據之間冗余的一種算法。
如下圖所示的 2 類數據,為了正確的分類,我們希望這 2 類數據投影之后,同類的數據盡可能的集中(距離近,有重疊),不同類的數據盡可能的分開(距離遠,無重疊),左圖的投影不好,因為 2 類數據投影后有重疊,而右圖投影之后可以很好地進行分類,因為投影之后的 2 類數據之間幾乎沒有重疊,只是類內重疊得很厲害,而這正是我們想要的結果。
正交投影
因為 LDA 用到了投影,所以這里有必要科普一下投影的知識。以二維平面為例,如圖所示
我們要計算向量 在 上的投影 ,很顯然 與 成比例關系:,其中 是一個常數。我們使用向量正交的概念來求出這個常數 。在上圖中,向量 , 與 垂直,它們的內積為 0,即 ,即
注意:對于兩個向量 x 和 y,?,所以有 。
假設 w 是一個單位向量,則 ,這樣,對于任意向量 x,其在 w 上的投影 可表示為:
其中,??是一個常數。
對于一個數據集 ,其中 ,i=1,2,3,...m 是 d 維列向量。同樣假設 w 是一個單位向量,那么每一個 在 w 的投影是:
上述公式的 是叫做 在 w 上的偏移或者坐標。這一系列的值 表示我們做了一個映射 ,即通過投影,我們將 d 維向量降維到了 1 維。
投影數據的均值
為簡化起見,我們先假設有 2 類數據,定義樣本 :,其中 。
我們再定義 :
其中 是類別, 是所有類別為 的樣本的集合。所有數據 投影到 w 后,求其均值:
其中, 是 數據集的均值,同理 的均值是 ,投影后的均值 。為了使投影之后數據可正確地分類,我們希望這 2 類數據的中心離得越遠越好,也就是要使 最大,但是單獨這個條件并不能保證能夠正確地對每一個數據進行分類,我們還需要考慮每一類數據的方差,大的方差表示 2 類數據之間有重疊,小的方差表示 2 類數據之間沒有重疊。
LDA 并沒有直接使用方差的計算公式,而是采用如下的定義:
這個有個名稱叫 scatter matrix,本文暫時將其翻譯成散步矩陣吧。
總結一下,LDA 主要就兩點:
(1)最大化各類數據中心的距離,也就是各類數據的均值之間的距離要最大;
(2)各類數據的散步矩陣之和要小,也就是每個類別中的數據盡可能地集中。
將上述兩點整合在一起,得到一個優化公式:
這個公式也叫做 Fisher LDA,這樣,LDA 的問題就是關于??最優化上述的公式。我們重寫上述公式如下:
同理有??:
這樣:
這樣,LDA 目標優化函數就可以重寫為:
對公式(9)關于??求導,并令其導數為 0,可得:
整理得:
公式(11)中 做了替代:, 是一個常量。如果 S 是非奇異矩陣,那么公式(11)左乘 得到:
最終,LDA 問題其實就是求 對應最大特征值,而我們前面要求的投影方向就是最大特征值對應的特征向量,我們將 LDA 問題化成了矩陣的特征值和特征向量的問題了。
上述推導針對二分類問題進行的,對于多分類問題, 矩陣的計算方式不變,而 矩陣需要采用如下的公式計算:
其中:
C 表示類別的個數; 表示第 i 類中樣本的個數; 表示第 i 類樣本的均值; 表示整個樣本的均值。
關于矩陣微分可參考如下文章:
https://zhuanlan.zhihu.com/p/24709748
https://zhuanlan.zhihu.com/p/24863977
這里提醒一下,對 關于 x 求導的結果是 ,如果 A 是對稱矩陣,即 ,則 。公式(10)中因為 B 和 S 都是對稱矩陣(由它們的定義可以看出是對稱矩陣),所以對 關于 w 求導的結果是? 2Bw ,即 ,同理 。
代碼實現
import?numpy?as?np from?sklearn?import?datasetsfrom?sklearn.datasets?import?make_blobs import?matplotlib.pyplot?as?pltclass?MyLDA:def?__init__(self):passdef?fit(self,?X,?y):#?獲取所有的類別labels?=?np.unique(y)#print(labels)means?=?[]for?label?in?labels:#?計算每一個類別的樣本均值means.append(np.mean(X[y?==?label],?axis=0))#?如果是二分類的話if?len(labels)?==?2:mu?=?(means[0]?-?means[1])mu?=?mu[:,None]?#?轉成列向量B?=?mu?@?mu.Telse:total_mu?=?np.mean(X,?axis=0)B?=?np.zeros((X.shape[1],?X.shape[1]))for?i,?m?in?enumerate(means):n?=?X[y==i].shape[0]mu_i?=?m?-?total_mumu_i?=?mu_i[:,None]?#?轉成列向量B?+=?n?*?np.dot(mu_i,?mu_i.T)#?計算S矩陣S_t?=?[]for?label,?m?in?enumerate(means):S_i?=?np.zeros((X.shape[1],?X.shape[1]))for?row?in?X[y?==?label]:t?=?(row?-?m)t?=?t[:,None]?#?轉成列向量S_i?+=?t?@?t.TS_t.append(S_i)S?=?np.zeros((X.shape[1],?X.shape[1]))for?s?in?S_t:S?+=?s#?S^-1B進行特征分解S_inv?=?np.linalg.inv(S)S_inv_B?=?S_inv?@?Beig_vals,?eig_vecs?=?np.linalg.eig(S_inv_B)#從大到小排序ind?=?eig_vals.argsort()[::-1]eig_vals?=?eig_vals[ind]eig_vecs?=?eig_vecs[:,?ind]return?eig_vecs#構造數據集 def?make_data(centers=3,?cluster_std=[1.0,?3.0,?2.5],?n_samples=150,?n_features=2):????X,?y?=?make_blobs(n_samples,?n_features,?centers,?cluster_std)return?X,?yif?__name__?==?"__main__":X,?y?=?make_data(2,?[1.0,?3.0])print(X.shape)lda?=?MyLDA()eig_vecs?=?lda.fit(X,?y)W?=?eig_vecs[:,?:1]colors?=?['red',?'green',?'blue']fig,?ax?=?plt.subplots(figsize=(10,?8))for?point,?pred?in?zip(X,?y):#?畫出原始數據的散點圖ax.scatter(point[0],?point[1],?color=colors[pred],?alpha=0.5)#?每個數據點在W上的投影proj?=?(np.dot(point,?W)?*?W)?/?np.dot(W.T,?W)#畫出所有數據的投影ax.scatter(proj[0],?proj[1],?color=colors[pred],?alpha=0.5)plt.show()4.1 2類2個特征
if?__name__?==?"__main__":X,?y?=?make_data(2,?[1.0,?3.0])?#rint(X.shape)lda?=?MyLDA()eig_vecs?=?lda.fit(X,?y)W?=?eig_vecs[:,?:1]colors?=?['red',?'green',?'blue']fig,?ax?=?plt.subplots(figsize=(10,?8))for?point,?pred?in?zip(X,?y):#?畫出原始數據的散點圖ax.scatter(point[0],?point[1],?color=colors[pred],?alpha=0.5)#?每個數據點在W上的投影proj?=?(np.dot(point,?W)?*?W)?/?np.dot(W.T,?W)#畫出所有數據的投影ax.scatter(proj[0],?proj[1],?color=colors[pred],?alpha=0.5)plt.show()運行結果是:
可見,數據投影后在 1 維上可以很好的分類。
4.2 3類2個特征
if?__name__?==?"__main__":#?3類X,?y?=?make_data([[2.0,?1.0],?[15.0,?5.0],?[31.0,?12.0]],?[1.0,?3.0,?2.5])print(X.shape)lda?=?MyLDA()eig_vecs?=?lda.fit(X,?y)W?=?eig_vecs[:,?:1]colors?=?['red',?'green',?'blue']fig,?ax?=?plt.subplots(figsize=(10,?8))for?point,?pred?in?zip(X,?y):#?畫出原始數據的散點圖ax.scatter(point[0],?point[1],?color=colors[pred],?alpha=0.5)#?每個數據點在W上的投影proj?=?(np.dot(point,?W)?*?W)?/?np.dot(W.T,?W)#畫出所有數據的投影ax.scatter(proj[0],?proj[1],?color=colors[pred],?alpha=0.5)plt.show()運行結果是:
4.3 3類4個特征
if?__name__?==?"__main__":#X,?y?=?load_data(cols,?load_all=True,?head=True)X,?y?=?make_data([[2.0,?1.0],?[15.0,?5.0],?[31.0,?12.0]],?[1.0,?3.0,?2.5],?n_features=4)print(X.shape)lda?=?MyLDA()eig_vecs?=?lda.fit(X,?y)#?取前2個最大特征值對應的特征向量W?=?eig_vecs[:,?:2]#?將數據投影到這兩個特征向量上,從而達到降維的目的transformed?=?X?@?Wplt.subplots(figsize=(10,?8))plt.scatter(transformed[:,?0],?transformed[:,?1],?c=y,?cmap=plt.cm.Set1)plt.show()運行結果如下:
對上述結果使用 sklearn 官方實現的 LDA 進行對比驗證:
if?__name__?==?"__main__":X,?y?=?make_data([[2.0,?1.0],?[15.0,?5.0],?[31.0,?12.0]],?[1.0,?3.0,?2.5],?n_features=4)print(X.shape)lda?=?MyLDA()eig_vecs?=?lda.fit(X,?y)#?取前2個最大特征值對應的特征向量W?=?eig_vecs[:,?:2]#?將數據投影到這兩個特征向量上,從而達到降維的目的transformed?=?X?@?Wplt.subplots(figsize=(10,?8))plt.scatter(transformed[:,?0],?transformed[:,?1],?c=y,?cmap=plt.cm.Set1)plt.title('self-implementation')from?sklearn.discriminant_analysis?import?LinearDiscriminantAnalysissk_lda?=?LinearDiscriminantAnalysis()sk_lda.fit(X,?y)transformed?=?sk_lda.transform(X)plt.subplots(figsize=(10,?8))plt.scatter(transformed[:,?0],?transformed[:,?1],?c=y,?cmap=plt.cm.Set1)plt.title("sklearn's?offical?implementation")plt.show()左圖是本文實現的 LDA 分類結果,右圖是官方實現的 LDA 分類結果,可見,兩者的結果是一致的。
總結
LDA 是一個很強大的工具,但它是一個有監督的分類算法,PCA 是一個無監督的算法,這是和 PCA 的一個很重要的區別。
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
以上是生活随笔為你收集整理的深入浅出线性判别分析(LDA),从理论到代码实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 因为航空原因时间推迟 中转时间就半个小时
- 下一篇: 直播 | 清华大学李一鸣:后门攻击简介