用pytorch及numpy计算成对余弦相似性矩阵,并用numpy实现kmeans聚类
生活随笔
收集整理的這篇文章主要介紹了
用pytorch及numpy计算成对余弦相似性矩阵,并用numpy实现kmeans聚类
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
??sklearn和scipy里面都提供了kmeans聚類的庫,但是它們都是根據(jù)向量直接進(jìn)行計(jì)算歐氏距離、閔氏距離或余弦相似度,如果使用其他的度量函數(shù)或者向量維度非常高需要先計(jì)算好度量距離然后再聚類時(shí),似乎這些庫函數(shù)都不能直接實(shí)現(xiàn),于是我用numpy自己寫了一個(gè),運(yùn)行也非常快。這里記錄下來以后備用:
import numpy as np import matplotlib.pyplot as plt import time t0 = time.time()Num = 512 corr = np.load('corrs20000.npy') #相關(guān)系數(shù)矩陣 u = np.arange(Num) #設(shè)置初始中心點(diǎn) #u = np.random.choice(20000,Num,replace=False) for n in range(1000): #設(shè)置1000次循環(huán)cluster = [[v] for v in u] #每個(gè)簇放在一個(gè)列表中,總體再有一個(gè)大列表存放others = np.array([v for v in range(20000) if v not in u]) #其他未歸類的點(diǎn)temp = corr[:,u]temp = temp[others,:] #通過兩步提取出所有其他未歸類點(diǎn)和各中心點(diǎn)的子相關(guān)矩陣inds = temp.argmax(axis=1) #計(jì)算每個(gè)未歸類點(diǎn)與各中心點(diǎn)的最大關(guān)系那個(gè)點(diǎn)的序號(hào)new_u = []for i in range(Num): #對(duì)每個(gè)簇分別計(jì)算(暫未想到矢量化方法)ind = np.where(inds==i)[0] #提取各簇中所有新點(diǎn)在未歸類點(diǎn)中的序號(hào)points = others[ind] #根據(jù)序號(hào)查找對(duì)應(yīng)的未歸類點(diǎn)實(shí)際編號(hào)cluster[i] = cluster[i] + points.tolist() #把本簇未歸類點(diǎn)加入到簇中temp = corr[cluster[i],:]temp = temp[:,cluster[i]] #通過兩步計(jì)算提取本簇各點(diǎn)子相關(guān)矩陣ind_ = temp.sum(axis=0).argmax() #計(jì)算各點(diǎn)和其他各點(diǎn)的總相關(guān)系數(shù)之和,取最大的一個(gè)的序號(hào)ind_new_center = cluster[i][ind_] #根據(jù)序號(hào)轉(zhuǎn)換為實(shí)際編號(hào),得到新的本簇中心點(diǎn)new_u.append(ind_new_center) #加入到新中心點(diǎn)向量new_u = np.asarray(new_u,dtype=np.int32)if (new_u==u).sum() == Num: #如果新的中心點(diǎn)向量已不再變化,停止循環(huán)breakprint(n,(new_u==u).sum(),time.time()-t0)u = new_u.copy() #計(jì)算全部結(jié)束后得到cluster就是各簇的點(diǎn)集和,u是中心點(diǎn)向量--------------------------后續(xù)補(bǔ)充:
??然而,快速計(jì)算一組向量的自相關(guān)性矩陣或者兩組向量的相互成對(duì)相關(guān)系數(shù)矩陣也是很常用的,在pytorch中用torch.cosine_similarity只能計(jì)算兩個(gè)向量間的,不能批量整體處理,如果循環(huán)計(jì)算,或者把向量通過repeat方法擴(kuò)展顯然計(jì)算速度比較慢。這里給出一種使用torch.matmul批量計(jì)算的方法,可以在cuda中計(jì)算,速度非常快。記錄備用:
總結(jié)
以上是生活随笔為你收集整理的用pytorch及numpy计算成对余弦相似性矩阵,并用numpy实现kmeans聚类的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pytorch几种损失函数CrossEn
- 下一篇: 声学、音乐计算常用工具总结(soundf