让书写的Matlab代码运行更快 Recipes for Faster Matlab Code
Matlab 在 Research 中用得非常多,確實(shí)也是非常方便實(shí)用,只是有一個問題就是寫 Matlab 代碼的時候經(jīng)常需要用一些比較奇怪獨(dú)特的方式來思考和處理問題,否則寫出來的代碼雖然同樣能工作,但是速度上可能會差上幾百幾千倍。這里有幾個關(guān)鍵詞:向量化、緩存、稀疏性等。不過由于 Matlab 在這方面確實(shí)“問題”比較大,所以關(guān)于如何寫更高效的 Matlab 代碼的文章也已經(jīng)非常多了。但是剛巧最近 bdahz 小朋友問我一段 Matlab 代碼為什么又短又奇怪但是速度又快,所以我覺得正好可以拿這個作為一個例子,羅嗦一下 Matlab 編程時候的一些注意事項(xiàng),也許會對其他人也有所幫助。
這次我們的例子是?K-medoids 算法,我在以前的 blog 中也介紹過。簡單地來說就是 K-means 算法的一個變種,只是在選取中心的時候 K-means 是直接計(jì)算所有點(diǎn)的平均值,而 K-medoids 則要求中心點(diǎn)必須是數(shù)據(jù)點(diǎn)中某一個,所以 K-means 的優(yōu)化如果是一個數(shù)值計(jì)算問題的話,K-medoids 應(yīng)該屬于離散優(yōu)化,通常離散優(yōu)化需要窮舉搜索來求解,所以計(jì)算上會更難一些。不過實(shí)現(xiàn)得好的話,也是可以比較高效的,比如這個版本的 Matlab K-medoids。
接下來我們就用這個作為例子分析一下寫高效的 Matlab 代碼需要注意的一些問題,先把代碼貼出來吧:
從這段代碼里我們可以看到寫高效的 Matlab 代碼的首要注意事項(xiàng)是:把代碼寫得晦澀難懂……呃,開各玩笑^_^bb,不過也確實(shí)是這樣,其實(shí)這樣的問題在各種語言中都是存在的:教學(xué)或者示例用的代碼通常和真正實(shí)際項(xiàng)目中的代碼差別很大,實(shí)際中往往摻雜各種錯誤處理呀邊界處理之類的,變得很復(fù)雜;不過代碼清晰度最大的敵人往往還是優(yōu)化。為了讓代碼運(yùn)行效率更高效所做的各種努力幾乎都會很嚴(yán)重或者非常嚴(yán)重地破壞代碼的可理解性,使得原本很清晰的算法變得面目全非。
通常人們解決這類問題的辦法就是把一些通用的優(yōu)化機(jī)制總結(jié)起來,實(shí)現(xiàn)到編譯器里面去,讓編譯器來做這些 dirty work。就 C/C++ 來說的話,現(xiàn)在的編譯器雖然離完美還差的很遠(yuǎn),但是在優(yōu)化方面也算是已經(jīng)非常強(qiáng)大了。可惜的是 Matlab 在這方面似乎做得不是很好——雖然 Matlab 嚴(yán)格地來說是沒有編譯器的。比如說,Matlab 里面用?for?循環(huán)是非常慢的,導(dǎo)致大家都不太敢用?for?循環(huán),于是 Matlab 后來說提供了?for?的 JIT 機(jī)制,據(jù)說加快速度,但是似乎結(jié)果仍然是非常慢。所以沒辦法,幸運(yùn)的是 Matlab 代碼通常都是比較短的。
回到我們的例子上,首先看第 5、6 兩行,這兩行做的事情實(shí)際上就是計(jì)算所以數(shù)據(jù)點(diǎn)之間的 pair-wise distance,放在變量?D?里。由于 pair-wise distance 在算法中要被用到很多次,并且是不會變化的,所以一開始把它計(jì)算并存儲下來后面直接用,這是所有語言里都通用的一個加速方法,或者也可以說成是空間換時間,因?yàn)槿绻麛?shù)據(jù)量比較大 pair-wise distance 矩陣在內(nèi)存中無法存下來的話,就沒有足夠的空間來換時間了。
然后我們來看看這個 pair-wise distance 矩陣是怎么計(jì)算出來的。首先第 5 行?dot?函數(shù)參考 Matlab 的幫助文檔就知道計(jì)算了矩陣?X?每一列和自己的點(diǎn)乘。然后第 6 行用了一個奇怪的函數(shù)?bsxfun。讓我們先忽略這個函數(shù),來看一下兩個點(diǎn)??和??之間的距離應(yīng)該是怎么計(jì)算的,定義如下:
但是由于我們這里只需要比較距離之間的相對大小,所以可以省略一個開平方根的計(jì)算,使用“平方距離”:
當(dāng)然,我們說了,在 Matlab 里用?for?循環(huán)來計(jì)算是很慢的,所以我們要用向量化的方法來計(jì)算,可以這樣寫?sum((x-y).^2)。但是這里的問題是,我們要計(jì)算很多點(diǎn)之間的 pair-wise distance,雖然每一對點(diǎn)之間的距離可以這樣算的話,要計(jì)算所有點(diǎn)之間的距離,好像仍然無法避免兩重?for?循環(huán)來遍歷所有的點(diǎn)。但是那樣又會很慢了,所以我們需要更加深層次的向量化,首先展開距離公式
這樣把距離的計(jì)算分成了三個部分,前面兩個部分都是計(jì)算向量的 norm (的平方),而第三個部分是計(jì)算向量內(nèi)積。這樣的形式的好處是可以方便地對一堆點(diǎn)同時進(jìn)行計(jì)算:例如,對于矩陣??的每一列的 norm 平方,就可以用我們剛才提到的?dot?函數(shù)一次算出來,也是代碼中第 5 行干的事情。接下來是內(nèi)積,這個也簡單,通過矩陣乘法的公式就可以知道,如果??的話,那么
其中??是矩陣??的第??列。所以一次矩陣乘法?X'*X?就可以把所有 pair-wise 內(nèi)積全部算出來,不用任何循環(huán)。所以接下來只要把三個部分加起來就可以了,不過這里還有一個問題:雖然?X'*X?是得到的一個形狀合適的矩陣,但是?dot(X,X)?得到的卻是一個向量。為了看得更清楚一點(diǎn),我們分別用?、和??表示 pair-wise distance 計(jì)算的三個部分,按理說應(yīng)該計(jì)算得到三個形狀相同的矩陣,然后相加起來:
顯然?,而?,所以對于??來說,列坐標(biāo)是無關(guān)緊要的,如果記之前?dot?得到的結(jié)果向量為?v?的話,?應(yīng)該是向量?v?按列不斷重復(fù)而得到的矩陣;類似的,?應(yīng)該是?v?轉(zhuǎn)置之后按行重復(fù)得到的矩陣。在 Matlab 中經(jīng)常需要這樣的操作,用?repmat?即可完成,所以,下面的代碼實(shí)際上就可以計(jì)算 pair-wise distance 矩陣:
v = dot(X,X); D = repmat(v, length(v),1) + repmat(v', 1, length(v)) - 2*(X'*X);這里又碰到了一個空間換時間的問題:由于我們希望用向量化的方式“同時”計(jì)算所有點(diǎn)對的距離,所以我們需要把?v?擴(kuò)張成??和??這兩個矩陣,需要的存儲空間從??變到了?,并且存儲的都是重復(fù)的元素,如果用?for?循環(huán)一個一個地計(jì)算的話,這些多余的空間當(dāng)然是可以避免的,但是 Matlab 的?for又很慢。不過由于這個問題出現(xiàn)得非常多,于是 Matlab 提供了一個解決方案:bsxfun。詳細(xì)的文檔可以看 Matlab 的幫助,講得很清楚,簡單地來說,bsxfun?就是對矩陣的每個元素做同一個操作,基本等價于于寫一些?for?來對矩陣元素做計(jì)算,不同的是速度快了許多許多倍。另外還有一個特點(diǎn)就是傳給bsxfun?的矩陣如果某一個維度上 size 是 1 的話,在那個維度上它會根據(jù)傳進(jìn)來的其他矩陣做“重復(fù)擴(kuò)展”,所做的事情和我們?nèi)巳庥?repmat?是一樣的,只是實(shí)現(xiàn)方式并不是這樣,它并不會生成臨時矩陣,所以在內(nèi)存方面絕對占有。
原來代碼里其實(shí)就是用?bsxfun?做了我們剛才用?repmat?做的事情。下面的代碼對比了三種方法:
用 Matlab 的 Profiler 運(yùn)行一下(順便說一下,Matlab 的 Profiler 是非常好用的工具,也是提升代碼性能的重要工具,用善用),在我這里,bsxfun、repmat和用循環(huán)的方式的運(yùn)行時間(m=1000,n=1000)分別是 0.26、0.18 和 8.22。循環(huán)比?repmat?慢了近 50 倍,bsxfun?和?repmat?速度差不多,但是內(nèi)存更省一些,一般推薦使用?bsxfun。
然后是第 8 行,先是用?randsample?隨機(jī)選出?k?個點(diǎn)作為初始 center,然后為每個數(shù)據(jù)點(diǎn)計(jì)算 label:也就是找出它們與?k?個 center 距離最近的那個所對應(yīng)的 index。這也是用向量化的方法一次性計(jì)算的,因?yàn)?Matlab 的?min?函數(shù)能夠支持向量化操作,事實(shí)上 Matlab 的大多數(shù)基本函數(shù)都支持向量化操作,多看一下文檔會有好處。
然后是第 11 行,這一行的目的是根據(jù)每一類的數(shù)據(jù)點(diǎn)重新選點(diǎn)每類的中心點(diǎn),這一步中就是 K-means 和 K-medoids 不同的地方:K-medoids 由于要求類中心必須是數(shù)據(jù)點(diǎn)中的某一個,所以這里需要用遍歷搜索的方法:遍歷該類中的所有數(shù)據(jù)點(diǎn),選中最優(yōu)的中心。這里最優(yōu)的定義是:該中心到該類的其他點(diǎn)的距離之和最小,這個是和 K-means 的定義一致的。不過從代碼里來看,這里顯然又用了向量化的方法而不是循環(huán)來處理了搜索。
讓我們來看一下代碼里是怎么做的:代碼里的?sparse?函數(shù)(具體用法請參考 Matlab 幫助)構(gòu)造了一個??的稀疏矩陣,不妨?xí)簳r記為?,如果第??個數(shù)據(jù)點(diǎn)屬于第??類的話,那么?,否則等于?。然后用 pair-wise distance 矩陣??去乘上?,得到一個??的矩陣暫時記為?。來看一下?,它是??的第??行和??的第??列內(nèi)積的結(jié)果。?的第??列標(biāo)記了所有屬于第??類的點(diǎn),其他位置全部是零,因此這樣內(nèi)積的結(jié)果就是所有第??類中的數(shù)據(jù)點(diǎn)到數(shù)據(jù)點(diǎn)??的距離之和。因此,對于第??類來說,只要求得??的第??列中數(shù)值最小的那個下標(biāo)對應(yīng)的數(shù)據(jù)點(diǎn),即是最優(yōu)的中心點(diǎn),而?min?函數(shù)是可以對于一個矩陣所有列同時求最小的,也就是代碼中該行達(dá)到的目的。
這里除了向量化之外還有一個注意事項(xiàng)就是稀疏矩陣。稀疏矩陣并不一定是很高效的,比如對里面的元素進(jìn)行下標(biāo)隨機(jī)訪問就會很慢,但是有許多其他操作則可以很快(如果用了合適的函數(shù)的話),比如矩陣相乘、矩陣遍歷(尋找非零元素或者尋找最大、最小值等)、解方程、求特征向量和特征值。比如說求特征向量,如果矩陣是稀疏的,那么可以用?eigs?來進(jìn)行求解,它的一個優(yōu)點(diǎn)是可以只求想要的幾個解,而不像?eig?那樣必須把所有解全部求出來,并且由于它是用迭代法,其中主要涉及到一些矩陣向量乘積之類的,用稀疏矩陣進(jìn)行運(yùn)算也會很快。當(dāng)然迭代法的缺點(diǎn)就是可能誤差比直接求解更大一些,數(shù)值穩(wěn)定性也更差一些。另外就是當(dāng)數(shù)據(jù)矩陣本身維度非常大但是又非常稀疏的時候,用稀疏矩陣非常節(jié)約內(nèi)存。下面是一個簡單的測試?yán)?#xff08;運(yùn)行時間在注釋里):
運(yùn)行時間差異還是比較清楚的,也就不用我多解釋了。不過有一點(diǎn)需要注意的是,eigs?在沒有指定個數(shù)的情況下默認(rèn)是只求 6 個特征值和特征向量的,所以和?eig?把所有的特征值特征向量全部求出來其實(shí)也并不是可以直接比較的。但是許多時候我們需要的都不是所有的特征向量和特征值,如果真的需要全部的話,用?eigs?來計(jì)算可能并不是一個合適的選擇,屆時可以自己嘗試和比較一下。由于求特征值和特征向量在各種基于 Graph 的方法(像 Laplacian Eigenmaps、Laplacian Regularized Least Square 等)中用得非常多,所以這些還是很有用的。這里可以順便簡單說一下稀疏矩陣的存儲。當(dāng)然是有各種存儲方式的,比較基本的比如按行存儲:矩陣是一個鏈表,把矩陣的每一行鏈起來,而每一行也是一個鏈表,把該行的非零元素鏈起來;類似的有按列存儲;此外可能還有完全按元素存儲,可以看成一個表格,如果(i,j)?位置有非零值的話,就索引到該值。根據(jù)不同的存儲方式,計(jì)算效率也會不一樣,比如一個按行存儲的矩陣乘以一個按列存儲的矩陣,就可以很快,因?yàn)榫仃囅喑说挠?jì)算方式就是左邊的行和右邊的列做內(nèi)積;但是如果反過來,一個按列存儲的矩陣乘以一個按行存儲的矩陣的話,就會比較麻煩了。Matlab 里做得比較好的是把稀疏矩陣搞得很透明,你不用關(guān)心它底層到底是怎么存儲的,大多數(shù)時候就像使用普通矩陣一樣用就 OK 了,并且性能也挺不錯。
回到我們原來的代碼,第 13 行和第 8 行是一樣的,只是現(xiàn)在使用計(jì)算出來的中心而不是隨機(jī)選出來的。剩下的就沒有什么好解釋的了。這里的 stop condition 是中心點(diǎn)不再變動,intuitively 想一想對于 K-medoids 來說這樣的 stop condition 似乎在一些比較特殊的情況下可能會出現(xiàn)來回振蕩不停下來的結(jié)果,不過那個不是我們今天要關(guān)注的問題了。
最后總結(jié)一下,要寫出高效的 Matlab 代碼的一些注意事項(xiàng):
- Profiler: Matlab 的 Profiler 是非常好用的,要善于利用這個工具,同所有其他編程語言一樣,找準(zhǔn) bottleneck 是進(jìn)行優(yōu)化的最重要的一步,如果只是想當(dāng)然地去搞的話可能浪費(fèi)了大把的精力又沒有把性能改善多少而且還把代碼搞得一團(tuán)糟。
- Sparsity: 如果問題是有稀疏性質(zhì)的,那么可以嘗試一下用稀疏矩陣和配套的那些操作。
- Vectorization: 向量化可以說是 Matlab 編程的一個特點(diǎn),就好像函數(shù)式編程總是一堆?map?呀filter?呀?reduce?呀之類的一樣。用好向量化是改善 Matlab 性能的關(guān)鍵。要多嘗試和練習(xí),逐漸習(xí)慣向量化的思維方式。特別是矩陣相乘呀、分塊之類的要熟練,例如我們在介紹代碼第 11 行的時候構(gòu)造的那個矩陣?,通常稱作 indicator matrix,元素只有 0 和 1,一般用于表示哪些元素被選出來了。這個矩陣不論是在計(jì)算上還是在公式推導(dǎo)上都經(jīng)常被用到。
總結(jié)
以上是生活随笔為你收集整理的让书写的Matlab代码运行更快 Recipes for Faster Matlab Code的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 稀疏性和L1正则化基础 Sparsity
- 下一篇: 会议投稿相关推荐