tensor torch 构造_详解Pytorch中的网络构造
背景
在PyTroch框架中,如果要自定義一個(gè)Net(網(wǎng)絡(luò),或者model,在本文中,model和Net擁有同樣的意思),通常需要繼承自nn.Module然后實(shí)現(xiàn)自己的layer。比如,在下面的示例中,gemfield(tiande亦有貢獻(xiàn))使用Pytorch實(shí)現(xiàn)了一個(gè)Net(可以看到其父類為nn.Module):
import torch import torch.nn as nn import torch.nn.functional as Fclass CivilNet(nn.Module):def __init__(self):super(CivilNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)self.gemfield = "gemfield.org"self.syszux = torch.zeros([1,1])def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x這就帶來了一系列的問題:
1,為什么要繼承自nn.Module?
2,網(wǎng)絡(luò)的各個(gè)layer或者module為什么要直接定義在構(gòu)造函數(shù)中,而不能(比方說)放在構(gòu)造函數(shù)中的一個(gè)list里?
3,forward函數(shù)什么時(shí)候會(huì)被調(diào)用?為什么要使用net(input)而不是net.forward(input)來做前向呢?
4,保存模型時(shí),保存的究竟是什么?
5,重新載入一個(gè)pth模型時(shí),究竟發(fā)生了什么?
你肯定要問了,為什么沒說到反向?因?yàn)榉聪蚴莖ptimizer和tensor的grad共同完成的,本文只討論Net部分,這一系列文章的后續(xù)部分會(huì)討論反向。
CivilNet的實(shí)例化
一個(gè)Net,也就是繼承自nn.Module的類,當(dāng)實(shí)例化后,本質(zhì)上就是維護(hù)了以下8個(gè)字典(OrderedDict):
_parameters _buffers _backward_hooks _forward_hooks _forward_pre_hooks _state_dict_hooks _load_state_dict_pre_hooks _modules這8個(gè)字典用于網(wǎng)絡(luò)的前向、反向、序列化、反序列化中。
因此,當(dāng)實(shí)例化你定義的Net(nn.Module的子類)時(shí),要確保父類的構(gòu)造函數(shù)首先被調(diào)用,這樣才能確保上述8個(gè)OrderedDict被create出來,否則,后續(xù)任何的初始化操作將拋出類似這樣的異常:cannot assign module before Module.__init__() call。
對(duì)于前述的CivilNet而言,當(dāng)CivilNet被實(shí)例化后,CivilNet本身維護(hù)了這8個(gè)OrderedDict,更重要的是,CivilNet中的conv1和conv2(類型為nn.modules.conv.Conv2d)、pool(類型為nn.modules.pooling.MaxPool2d)、fc1、fc2、fc3(類型為torch.nn.modules.linear.Linear)均維護(hù)了8個(gè)OrderedDict,因?yàn)樗鼈兊母割惗际莕n.Module,而gemfield(類型為str)、syszux(類型為torch.Tensor)則沒有這8個(gè)OrderedDict。
也因此,在你定義的網(wǎng)絡(luò)投入運(yùn)行前,必然要確保和上面一樣——構(gòu)造出那8個(gè)OrderedDict,這個(gè)構(gòu)造,就在nn.Module的構(gòu)造函數(shù)中。如此以來,你定義的Net就必須繼承自nn.Module;如果你的Net定義了__init__()方法,則必須在你的__init__方法中調(diào)用nn.Module的構(gòu)造函數(shù),比如super(your_class).__init__() ,注意,如果你的子類沒有定義__init__()方法,則在實(shí)例化的時(shí)候會(huì)默認(rèn)用nn.Module的,這種情況也對(duì)。
nn.Module通過使用__setattr__機(jī)制,使得定義在類中(不一定要定義在構(gòu)造函數(shù)里)的成員(比如各種layer),被有序歸屬到_parameters、_modules、_buffers或者普通的attribute里;那具體怎么歸屬呢?很簡(jiǎn)單,當(dāng)類成員的type 派生于Parameter類時(shí)(比如conv的weight,在CivilNet類中,就是self.conv1中的weight屬性),該屬性就會(huì)被劃歸為_parameters;當(dāng)類成員的type派生于Module時(shí)(比如CivilNet中的self.conv1,其實(shí)除了gemfield和syszux外都是),該成員就會(huì)劃歸為_modules。
如果知道了這個(gè)機(jī)制,就會(huì)自然而然的知道,如果上面的CivilNet里的成員封裝到一個(gè)list里,像下面這樣:
class CivilNet(nn.Module):def __init__(self):super(CivilNet, self).__init__()conv1 = nn.Conv2d(3, 6, 5)pool = nn.MaxPool2d(2, 2)conv2 = nn.Conv2d(6, 16, 5)self.layer1 = [conv1, pool, conv2]...那么在運(yùn)行的時(shí)候,可能optimizer就會(huì)提示parameters為empty。這就是因?yàn)槌蓡Tlayer1的type派生自list,而非Module;而像CivilNet這樣的Net,在取所有的parameters的時(shí)候,都是通過_modules橋梁去取得的......
1,_parameters
前述說到了parameters就是Net的權(quán)重參數(shù)(比如conv的weight、conv的bias、fc的weight、fc的bias),類型為tensor,用于前向和反向;比如,你針對(duì)Net使用cpu()、cuda()等調(diào)用的時(shí)候,實(shí)際上調(diào)用的就是parameter這個(gè)tensor的cpu()、cuda()等方法;再比如,你保存模型或者重新加載pth文件的時(shí)候,針對(duì)的都是parameter的操作或者賦值。
如果你針對(duì)的是CivilNet直接取_parameters屬性的值的話,很遺憾是空的,因?yàn)镃ivilNet的成員并沒有直接派生自Parameter類;但是當(dāng)針對(duì)CivilNet取parameters()函數(shù)的返回值(是個(gè)iter)時(shí),則會(huì)遞歸拿到所有的,比如conv的weight、bias等;
2,_buffers
該成員值的填充是通過register_buffer API來完成的,通常用來將一些需要持久化的狀態(tài)(但又不是網(wǎng)絡(luò)的參數(shù))放到_buffer里;一些極其個(gè)別的操作,比如BN,會(huì)將running_mean的值放入進(jìn)來;
3,_modules
_modules成員起很重要的橋梁作用,在獲取一個(gè)net的所有的parameters的時(shí)候,是通過遞歸遍歷該net的所有_modules來實(shí)現(xiàn)的。
像前述提到的那個(gè)問題,如果將這些成員都放倒一個(gè)python list里:self.layer1 = [conv1, pool, conv2] ——會(huì)導(dǎo)致CivilNet不能將conv1, pool, conv2等劃歸到_modules里,從而通過CivilNet的parameters()獲取所有權(quán)重參數(shù)時(shí),拿到的東西為空,就會(huì)報(bào)optimizer got an empty parameter list這樣的錯(cuò)誤。針對(duì)這種情況,那怎么辦呢?
ModuleList就是為了解決這個(gè)問題的,首先,ModuleList類的基類正是Module:
class ModuleList(Module)其次,ModuleList實(shí)現(xiàn)了python的list的功能;
最后,在使用ModuleList的時(shí)候,該類會(huì)使用基類(也就是Module)的add_module()方法,或者直接操作_modules成員來將list中的module成功注冊(cè)。
Sequential模塊也具備ModuleList這樣的注冊(cè)功能,另外其還實(shí)現(xiàn)了forward,這是和ModuleList不同的地方:
def forward(self, input):for module in self._modules.values():input = module(input)return inputCivilNet的前向
網(wǎng)絡(luò)的前向需要通過諸如CivilNet(input)這樣的形式來調(diào)用,而非CivilNet.forward(input),是因?yàn)榍罢邔?shí)現(xiàn)了額外的功能:
1,先執(zhí)行完所有的_forward_pre_hooks里的hooks 2, 再調(diào)用CivilNet的forward函數(shù) 3, 再執(zhí)行完所有的_forward_hooks中的hooks 4, 再執(zhí)行完所有的_backward_hooks中的hooks可以看到:
1,_forward_pre_hooks是在網(wǎng)絡(luò)的forward之前執(zhí)行的。這些hooks通過網(wǎng)絡(luò)的register_forward_pre_hook() API來完成注冊(cè),通常只有一些Norm操作會(huì)定義_forward_pre_hooks。這種hook不能改變input的內(nèi)容。
2,_forward_hooks是通過register_forward_hook來完成注冊(cè)的。這些hooks是在forward完之后被調(diào)用的,并且不應(yīng)該改變input和output。目前就是方便自己測(cè)試的時(shí)候可以用下。
3,_backward_hooks和_forward_hooks類似。
所以總結(jié)起來就是,如果你的網(wǎng)絡(luò)中沒有Norm操作,那么使用CivilNet(input)和CivilNet.forward(input)是等價(jià)的。
另外,你必須使用CivilNet.eval()操作來將dropout和BN這些op設(shè)置為eval模式,否則你將得到不一致的前向返回值。eval()調(diào)用會(huì)將Net的實(shí)例中的training成員設(shè)置為False。
CivilNet模型的保存和重新加載
如果我們要保存一個(gè)訓(xùn)練好哦PyTorch模型的話,會(huì)使用下面的API:
cn = CivilNet() ...... torch.save(cn.state_dict(), "your_model_path.pth")可以看到使用了網(wǎng)絡(luò)的state_dict() API調(diào)用以及torch模塊的save調(diào)用。一言以蔽之,模型的保存就是先通過state_dict() API的調(diào)用獲得一個(gè)關(guān)于網(wǎng)絡(luò)參數(shù)的字典,再通過pickle模塊序列化成文件的形式。
而如果我們要load一個(gè)pth模型來進(jìn)行前向的時(shí)候,會(huì)使用下面的API:
cn = CivilNet()#參數(shù)反序列化為python dict state_dict = torch.load("your_model_path.pth") #加載訓(xùn)練好的參數(shù) cn.load_state_dict(state_dict)#變成測(cè)試模式,dropout和BN在訓(xùn)練和測(cè)試時(shí)不一樣 #eval()會(huì)把模型中的每個(gè)module的self.training設(shè)置為False cn = cn.cuda().eval()可以看到使用了torch模塊的load調(diào)用和網(wǎng)絡(luò)的load_state_dict() API調(diào)用。一言以蔽之,模型的重新加載就是先通過torch.load反序列化pickle文件得到一個(gè)Dict,然后再使用該Dict去初始化當(dāng)前網(wǎng)絡(luò)的state_dict。torch的save和load API在python2中使用的是cPickle,在python3中使用的是pickle。另外需要注意的是,序列化的pth文件會(huì)被寫入header信息,包括magic number、version信息等。
關(guān)于模型的保存,我們需要弄清楚以下概念:1, state_dict;2, 序列化一個(gè)pth模型用于以后的前向;3, 為之后的再訓(xùn)練保存一個(gè)中間的checkpoint;4,將多個(gè)模型保存為一個(gè)文件;5,用其它模型的參數(shù)來初始化當(dāng)前的網(wǎng)絡(luò);6,跨設(shè)備的模型的保存和加載。
1, state_dict
在Pytorch中,可學(xué)習(xí)的參數(shù)(如Module中的weights和biases)是包含在網(wǎng)絡(luò)的parameters()調(diào)用返回的字典中的,這就是一個(gè)普通的OrderedDict,這里面的key-value是通過網(wǎng)絡(luò)及遞歸網(wǎng)絡(luò)里的Module成員獲取到的:它的key是每一個(gè)layer的成員的名字(加上prefix),而對(duì)應(yīng)的value是一個(gè)tensor。比如本文前述的CivilNet類,它的state_dict中的key如下所示:
conv1.weight conv1.bias conv2.weight conv2.bias fc1.weight fc1.bias fc2.weight fc2.bias fc3.weight fc3.bias那如果你使用了DataParallel來訓(xùn)練的話:
cn = nn.DataParallel(cn)那么state_dict中的key將如下所示:
module.conv1.weight module.conv1.bias module.conv2.weight module.conv2.bias module.fc1.weight module.fc1.bias module.fc2.weight module.fc2.bias module.fc3.weight module.fc3.bias如果你使用了ModuleList的話,比如前述CivilNet的定義你寫作了:
class CivilNet(nn.Module):def __init__(self):super(CivilNet, self).__init__()conv1 = nn.Conv2d(3, 6, 5)pool = nn.MaxPool2d(2, 2)conv2 = nn.Conv2d(6, 16, 5)fc1 = nn.Linear(16 * 5 * 5, 120)fc2 = nn.Linear(120, 84)fc3 = nn.Linear(84, 10)self.gemfield = nn.ModuleList([conv1, pool, conv2, fc1, fc2, fc3])那state_dict中的key將如下所示:
gemfield.1.weight gemfield.1.bias gemfield.2.weight gemfield.2.bias gemfield.3.weight gemfield.3.bias gemfield.4.weight gemfield.4.bias gemfield.5.weight gemfield.5.bias還有很多的變種,不過大抵上你也知道規(guī)律了。
2,load_state_dict
load_state_dict()調(diào)用是nn.Module的一個(gè)API,用模型文件反序列化后得到的Dict來初始化當(dāng)前的模型。需要提及的是這個(gè)函數(shù)上的 strict參數(shù),默認(rèn)值是True。因此在初始化時(shí)候,該函數(shù)會(huì)嚴(yán)格比較源Dict和目標(biāo)Dict的key是否一樣,不能多也不能少,必須嚴(yán)格一樣。
如果將strict參數(shù)設(shè)置為False,則將不會(huì)進(jìn)行這樣嚴(yán)格的check。只有key一樣的才會(huì)進(jìn)行賦值。
3,序列化模型以保存state_dict
這種情況是PyTorch中最常用的保存模型的方法。
#save torch.save(model.state_dict(), PATH)#load model = CivilNet(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()不再贅述。
4,序列化整個(gè)模型
#save torch.save(model, PATH) #load model = torch.load(PATH) model.eval()這種方式不推薦,其是通過Pickle模塊將整個(gè)class序列化了,序列化過程中依賴很多具體的東西,比如定義model class的路徑。這樣反序列化的時(shí)候就喪失了靈活性。
5,序列化中間過程中的checkpoint
這種序列化的目的是為了之后以這個(gè)狀態(tài)為基點(diǎn)重新開始訓(xùn)練。和前述序列化模型的本質(zhì)不同就在于還需要序列化optimizer的Dict(比如學(xué)習(xí)率等參數(shù))。傳統(tǒng)上,checkpoint文件用.tar作為后綴:
#save torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)#load model = CivilNet(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss']model.train() #model.eval()6,將多個(gè)模型序列化到一個(gè)文件里
比如,decoder-encoder這種結(jié)構(gòu)會(huì)有多個(gè)Net。傳統(tǒng)上,checkpoint文件用.tar作為后綴。
#save torch.save({'modelA_state_dict': modelA.state_dict(),'modelB_state_dict': modelB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),...}, PATH)#load modelA = TheModelAClass(*args, **kwargs) modelB = TheModelBClass(*args, **kwargs) optimizerA = TheOptimizerAClass(*args, **kwargs) optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH) modelA.load_state_dict(checkpoint['modelA_state_dict']) modelB.load_state_dict(checkpoint['modelB_state_dict']) optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']) optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])7,用一個(gè)模型的部分參數(shù)初始化另一個(gè)模型(遷移學(xué)習(xí))
這種情況的目的是為了復(fù)用一個(gè)模型的部分layer,以實(shí)現(xiàn)遷移學(xué)習(xí)。
#save torch.save(modelA.state_dict(), PATH)#load modelB = TheModelBClass(*args, **kwargs) modelB.load_state_dict(torch.load(PATH), strict=False)和前述保存模型相比,序列化部分一樣,反序列化只需要將strict參數(shù)設(shè)置為False。在前述load_state_dict章節(jié)中已經(jīng)解釋過,此處不再贅述。
8,跨device(cpu/gpu)來save/load模型
比如模型是在GPU上訓(xùn)練的,現(xiàn)在要load到cpu上。或者反之,或者在CPU上訓(xùn)練,在GPU上load。這三種情況下,save的方法是一樣的:
torch.save(model.state_dict(), PATH)而load的方法就不一樣了:
###############Save on GPU, Load on CPU ######### device = torch.device('cpu') model = CivilNet(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location=device))###############Save on GPU, Load on GPU ######### device = torch.device("cuda") model = CivilNet(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.to(device) #確保在輸入給網(wǎng)絡(luò)的tensor上調(diào)用input = input.to(device)###############Save on CPU, Load on GPU ######### device = torch.device("cuda") model = CivilNet(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want model.to(device) #確保在輸入給網(wǎng)絡(luò)的tensor上調(diào)用input = input.to(device)9,使用torch.nn.DataParallel訓(xùn)練的模型如何序列化
torch.nn.DataParallel 是一個(gè)wrapper,用來幫助在多個(gè)GPU上并行進(jìn)行運(yùn)算。這種情況下要保存訓(xùn)練好的模型,最好使用model.module.state_dict(),請(qǐng)參考本章第1節(jié):state_dict。這種情況下你在重新加載pth模型文件的時(shí)候,就會(huì)有極大的靈活性,而不是出現(xiàn)一大堆unexpected keys和missed keys:
torch.save(model.module.state_dict(), PATH)打印CivilNet
這個(gè)是靠__repr__機(jī)制,不再贅述;
cn = CivilNet() print(cn)另外,你的類可以重寫nn.Module的extra_repr()方法來實(shí)現(xiàn)定制化的打印。
總結(jié)
以上是生活随笔為你收集整理的tensor torch 构造_详解Pytorch中的网络构造的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python自动测试u_自动化测试——S
- 下一篇: python包里面的dll是什么_Pyt