快速掌握TensorFlow中张量运算的广播机制
相信大家在使用numpy和tensorflow的時候都會遇到如下的錯誤
ValueError: operands could not be broadcast together with shapes (4,3) (4,)這是由于numpy和tensorflow中的張量在進行運算的時候形狀不滿足廣播機制的要求,不理解廣播機制的同學可能會通過各種魔改代碼來讓代碼正常運行起來,但是卻不知道為什么那樣改就可以。
本文將從原理上介紹張量運算中經常用到的廣播機制。
廣播(broadcasting)指的是不同形狀的張量之間的算數運算的執行方式。
通過兩個例子直觀了解廣播
數組與標量值的乘法
import numpy as np arr = np.arange(5) arr #-> array([0, 1, 2, 3, 4]) arr * 4 #-> array([ 0, 4, 8, 12, 16])在上面的乘法運算中,標量值4被廣播到了其他所有元素上
通過減去列平均值的方式對數組每一列進行距平化處理
arr = np.random.randn(4,3) arr #-> array([[ 1.83518156, 0.86096695, 0.18681254],# [ 1.32276051, 0.97987486, 0.27828887],# [ 0.65269467, 0.91924574, -0.71780692],# [-0.05431312, 0.58711748, -1.21710134]]) arr.mean(axis=0) #-> array([ 0.93908091, 0.83680126, -0.36745171])關于mean中的axis參數,可以這樣理解:
在numpy中,axis = 0為行軸(豎直方向),axis = 1為列軸(水平方向),指定axis表示該操作沿axis進行,得到結果將是一個shape為除去該axis的array,對于多維張量,axis=i是指運算操作沿著第i個張量下標變化的方向進行。
在上例中,arr.mean(axis=0)表示對arr沿著軸0(豎直方向)求均值。顯然,第0個下標變化的方向即為豎直方向,以第一列為例,4個元素的下標分別為[(0,0),(1,0),(2,0),(3,0)]。
而arr的shape為(4,3),除去axis=0的shape,結果為(1,3)或者(3,),這與上面的代碼運行結果相符。
廣播機制的原理
★如果兩個數組的后緣維度(從末尾開始算起的維度)的軸長度相符或其中一方的長度為1,則認為它們是廣播兼容的。廣播會在缺失維度和(或)軸長度為1的維度上進行。
”demeaned = arr - arr.mean(axis=0) demeaned > array([[ 0.89610065, 0.02416569, 0.55426426],[ 0.3836796 , 0.1430736 , 0.64574058],[-0.28638623, 0.08244448, -0.35035521],[-0.99339402, -0.24968378, -0.84964963]]) demeaned.mean(axis=0) > array([ -5.55111512e-17, -5.55111512e-17, 0.00000000e+00])在上面的對arr每一列減去列平均值的例子中,arr的后緣維度為3,arr.mean(0)后緣維度也是3,滿足軸長度相符的條件,廣播會在缺失維度進行。
這里有點奇怪的是缺失維度不是axis=1,而是axis=0,個人理解是缺失維度指的是兩個arr除了軸長度匹配的維度,在上面的例子中,正好是axis=0。
arr.mean(0)沿著axis=0廣播,可以看作是把arr.mean(0)沿著豎直方向復制4份,即廣播的時候arr.mean(0)相當于一個shape=(4,3)的數組,數組的每一行均相同,均為arr.mean(0)
各行減去行均值
row_means = arr.mean(axis=1) row_means.shape > (4,) arr - row_means > ---------------------------------------------------------------------------ValueError Traceback (most recent call last)<ipython-input-10-3d1314c7e700> in <module>()----> 1 arr - row_meansValueError: operands could not be broadcast together with shapes (4,3) (4,)直接相減,報錯,無法進行廣播。
回顧上面的原則,要么滿足后緣維度軸長度相等,要么滿足其中一方長度為1。在這個例子中,兩者均不滿足,所以報錯。根據廣播原則,較小數組的廣播維必須為1。解決方案是為較小的數組添加一個長度為1的新軸。
numpy提供了一種通過索引機制插入軸的特殊語法。通過特殊的np.newaxis屬性以及“全”切片來插入新軸。
下面的例子中,我們通過插入新軸的方式實現二維數組各行減去行均值。這里將行均值沿著水平方向進行廣播,廣播軸為axis=1,對row_means添加一個新軸axis=1
row_means[:,np.newaxis].shape > (4, 1) arr - row_means[:,np.newaxis] > array([[ 0.87419454, -0.10002007, -0.77417447],[ 0.46245243, 0.11956678, -0.58201921],[ 0.36798351, 0.63453458, -1.00251808],[ 0.17378588, 0.81521647, -0.98900235]])另一個例子
a = np.array([1,2,3]) a.shape # -> (3,) b = np.array([[1,],[2,],[3]]) # -> (3,1) b - a # -> array([[ 0, -1, -2],# [ 1, 0, -1],# [ 2, 1, 0]])上面的例子輸出為什么是一個3*3的數組??
我們來分析一下,根據廣播原則,b滿足其中一方軸長度為1,那么廣播會沿著長度為1的軸,及axis=1進行,對數組b沿著axis=1即水平方向進行復制,相當于b變成一個shape為(3,3)且各列均為[1,2,3]的數組。
一個維度為(3,3)的數組減去一個維度為(3,)的數組,滿足后緣維度軸長度相等,數組a沿著axis=0即豎直方向進行廣播,相當遠a變成一個shape為(3,3)且個行均為[1,2,3]的數組。
b-a的時候,
?b被廣播成為
[[1,1,1],[2,2,2],[3,3,3]]a被廣播成為
[[1,2,3],[1,2,3],[1,2,3]]所以b-a的結果是
[[0,-1,-2],[1, 0,-1],[2, 1, 0]]三維情況
下面的例子中,構造一個3*4*5的隨機數組arr_3d,我們希望實現對arr_3d的每個元素減去其深度(axis=2)方向的均值
#構造三維數組 arr_3d = np.random.randn(3,4,5) #求深度方向的均值,想想結果的shape是什么?原始shape是(3,4,5) #除去axis=2后還剩(3,4) depth_means = arr_3d.mean(axis=2) depth_means.shape > (3, 4) #arr(3,4,5)和depth_means(3,4)不能直接廣播,后緣維度不相符且不存在軸長度為1的軸 arr_3d_new = arr_3d - depth_means[:,:,np.newaxis] #所以我們添加廣播軸 arr_3d_new.mean(axis=2) #結果應該為0,這里是接近0的浮點數,符合預期> array([[ -5.55111512e-17, 4.44089210e-17, 4.44089210e-17, 4.44089210e-17],[ -8.88178420e-17, -1.11022302e-16, -6.66133815e-17,0.00000000e+00],[ 0.00000000e+00, -7.77156117e-17, -2.22044605e-17,-2.22044605e-17]])以上就是關于張量運算中廣播機制的一點介紹,歡迎關注公眾號淺夢的學習筆記,一起討論交流!
關于本站
“機器學習初學者”公眾號由是黃海廣博士創建,黃博個人知乎粉絲23000+,github排名全球前110名(32000+)。本公眾號致力于人工智能方向的科普性文章,為初學者提供學習路線和基礎資料。原創作品有:吳恩達機器學習個人筆記、吳恩達深度學習筆記等。
往期精彩回顧
那些年做的學術公益-你不是一個人在戰斗
適合初學者入門人工智能的路線及資料下載
吳恩達機器學習課程筆記及資源(github標星12000+,提供百度云鏡像)
吳恩達深度學習筆記及視頻等資源(github標星8500+,提供百度云鏡像)
《統計學習方法》的python代碼實現(github標星7200+)
精心整理和翻譯的機器學習的相關數學資料
首發:深度學習入門寶典-《python深度學習》原文代碼中文注釋版及電子書
備注:加入本站微信群或者qq群,請回復“加群”
加入知識星球(4300+用戶,ID:92416895),請回復“知識星球”
與50位技術專家面對面20年技術見證,附贈技術全景圖總結
以上是生活随笔為你收集整理的快速掌握TensorFlow中张量运算的广播机制的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 原创:机器学习代码练习(一、回归)
- 下一篇: 员外陪你读论文:DeepWalk: On