pytorch动态网络以及权重共享
生活随笔
收集整理的這篇文章主要介紹了
pytorch动态网络以及权重共享
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
pytorch 動態網絡+權值共享
pytorch以動態圖著稱,下面以一個栗子來實現動態網絡和權值共享技術:
# -*- coding: utf-8 -*- import random import torchclass DynamicNet(torch.nn.Module):def __init__(self, D_in, H, D_out):"""這里構造了幾個向前傳播過程中用到的線性函數"""super(DynamicNet, self).__init__()self.input_linear = torch.nn.Linear(D_in, H)self.middle_linear = torch.nn.Linear(H, H)self.output_linear = torch.nn.Linear(H, D_out)def forward(self, x):"""For the forward pass of the model, we randomly choose either 0, 1, 2, or 3and reuse the middle_linear Module that many times to compute hidden layerrepresentations.Since each forward pass builds a dynamic computation graph, we can use normalPython control-flow operators like loops or conditional statements whendefining the forward pass of the model.Here we also see that it is perfectly safe to reuse the same Module manytimes when defining a computational graph. This is a big improvement from LuaTorch, where each Module could be used only once.這里中間層每次向前過程中都是隨機添加0-3層,而且中間層都是使用的同一個線性層,這樣計算時,權值也是用的同一個。"""h_relu = self.input_linear(x).clamp(min=0)for _ in range(random.randint(0, 3)):h_relu = self.middle_linear(h_relu).clamp(min=0)y_pred = self.output_linear(h_relu)return y_pred# N is batch size; D_in is input dimension;# H is hidden dimension; D_out is output dimension.N, D_in, H, D_out = 64, 1000, 100, 10# Create random Tensors to hold inputs and outputsx = torch.randn(N, D_in)y = torch.randn(N, D_out)# Construct our model by instantiating the class defined abovemodel = DynamicNet(D_in, H, D_out)# Construct our loss function and an Optimizer. Training this strange model with# vanilla stochastic gradient descent is tough, so we use momentumcriterion = torch.nn.MSELoss(reduction='sum')optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)for t in range(500):# Forward pass: Compute predicted y by passing x to the modely_pred = model(x)# Compute and print lossloss = criterion(y_pred, y)print(t, loss.item())# Zero gradients, perform a backward pass, and update the weights.optimizer.zero_grad()loss.backward()optimizer.step()這個程序實際上是一種RNN結構,在執行過程中動態的構建計算圖
References: Pytorch Documentations.
總結
以上是生活随笔為你收集整理的pytorch动态网络以及权重共享的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 周杰伦讲给快手的“独家秘密”
- 下一篇: Unity3D 动态加载本地/网络GLB