PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析
生活随笔
收集整理的這篇文章主要介紹了
PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析
小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
PyTorch:采用sklearn 工具生成這樣的合成數(shù)據(jù)集+利用PyTorch實(shí)現(xiàn)簡(jiǎn)單合成數(shù)據(jù)集上的線性回歸進(jìn)行數(shù)據(jù)分析
?
目錄
輸出結(jié)果
核心代碼
?
?
輸出結(jié)果
?
?
核心代碼
#PyTorch:采用sklearn 工具生成這樣的合成數(shù)據(jù)集+利用PyTorch實(shí)現(xiàn)簡(jiǎn)單合成數(shù)據(jù)集上的線性回歸進(jìn)行數(shù)據(jù)分析 from sklearn.datasets import make_regression import seaborn as sns import pandas as pd import matplotlib.pyplot as pltsns.set()x_train, y_train, W_target = make_regression(n_samples=100, n_features=1, noise=10, coef = True)df = pd.DataFrame(data = {'X':x_train.ravel(), 'Y':y_train.ravel()})sns.lmplot(x='X', y='Y', data=df, fit_reg=True) plt.show()x_torch = torch.FloatTensor(x_train) y_torch = torch.FloatTensor(y_train) y_torch = y_torch.view(y_torch.size()[0], 1) class LinearRegression(torch.nn.Module): #定義LR的類。torch.nn庫(kù)構(gòu)建模型#PyTorch 的 nn 庫(kù)中有大量有用的模塊,其中一個(gè)就是線性模塊。如名字所示,它對(duì)輸入執(zhí)行線性變換,即線性回歸。def __init__(self, input_size, output_size):super(LinearRegression, self).__init__()self.linear = torch.nn.Linear(input_size, output_size) def forward(self, x):return self.linear(x)model = LinearRegression(1, 1)criterion = torch.nn.MSELoss() #訓(xùn)練線性回歸,我們需要從 nn 庫(kù)中添加合適的損失函數(shù)。對(duì)于線性回歸,我們將使用 MSELoss()——均方差損失函數(shù) optimizer = torch.optim.SGD(model.parameters(), lr=0.1)#還需要使用優(yōu)化函數(shù)(SGD),并運(yùn)行與之前示例類似的反向傳播。本質(zhì)上,我們重復(fù)上文定義的 train() 函數(shù)中的步驟。 #不能直接使用該函數(shù)的原因是我們實(shí)現(xiàn)它的目的是分類而不是回歸,以及我們使用交叉熵?fù)p失和最大元素的索引作為模型預(yù)測(cè)。而對(duì)于線性回歸,我們使用線性層的輸出作為預(yù)測(cè)。for epoch in range(50):data, target = Variable(x_torch), Variable(y_torch)output = model(data)optimizer.zero_grad()loss = criterion(output, target)loss.backward()optimizer.step()predicted = model(Variable(x_torch)).data.numpy()#打印出原始數(shù)據(jù)和適合 PyTorch 的線性回歸 plt.plot(x_train, y_train, 'o', label='Original data') plt.plot(x_train, predicted, label='Fitted line')plt.legend() plt.title(u'Py:PyTorch實(shí)現(xiàn)簡(jiǎn)單合成數(shù)據(jù)集上的線性回歸進(jìn)行數(shù)據(jù)分析') plt.show()?
?
總結(jié)
以上是生活随笔為你收集整理的PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: Py之seaborn:seaborn库的
- 下一篇: DL之LeNet-5:LeNet-5算法