生活随笔
收集整理的這篇文章主要介紹了
Pytorch(七) --加载数据集
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
主要用到了Pytorch中的Dataset和DataLoadder這兩個(gè)方法,其中Dataset是抽象類,不能實(shí)例化對象,只能繼承用于構(gòu)造數(shù)據(jù)集,DataLoader是幫助加載數(shù)據(jù)的,可以做shuffle、batch_size,能拿Mini-batch進(jìn)行訓(xùn)練。
代碼如下:
import torch
import numpy
as np
from torch
.utils
.data
import Dataset
from torch
.utils
.data
import DataLoader
class DiabetesDataset(Dataset
):def __init__(self
,filepath
):xy
= np
.loadtxt
(filepath
,delimiter
=',',dtype
=np
.float32
)self
.len = xy
.shape
[0]self
.x_data
=torch
.from_numpy
(xy
[:,:-1])self
.y_data
= torch
.from_numpy
(xy
[:,[-1]])def __getitem__(self
,index
):return self
.x_data
[index
],self
.y_data
[index
]def __len__(self
):return self
.lendataset
= DiabetesDataset
('E:\\tmp\\.keras\\datasets\\diabetes.csv\\diabetes.csv')
train_loader
= DataLoader
(dataset
=dataset
,batch_size
= 32,shuffle
= True,num_workers
=0
)
class Model(torch
.nn
.Module
):def __init__(self
):super(Model
,self
).__init__
()self
.linear1
= torch
.nn
.Linear
(8,6)self
.linear2
= torch
.nn
.Linear
(6,4)self
.linear3
= torch
.nn
.Linear
(4,1)self
.relu
= torch
.nn
.ReLU
()self
.sigmoid
= torch
.nn
.Sigmoid
()def forward(self
,x
):x
= self
.relu
(self
.linear1
(x
))x
= self
.relu
(self
.linear2
(x
))x
= self
.sigmoid
(self
.linear3
(x
))return x
model
= Model
()
criterion
= torch
.nn
.BCELoss
(reduction
='mean')
optimizer
= torch
.optim
.SGD
(model
.parameters
(),lr
=0.01)
if __name__
== '__main__':for epoch
in range(1000):for i
,data
in enumerate(train_loader
,0):inputs
,labels
= datay_pred
= model
(inputs
)loss
=criterion
(y_pred
,labels
)print(epoch
,i
,loss
.item
())optimizer
.zero_grad
()loss
.backward
()optimizer
.step
()
努力加油a啊
總結(jié)
以上是生活随笔為你收集整理的Pytorch(七) --加载数据集的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網(wǎng)站內(nèi)容還不錯(cuò),歡迎將生活随笔推薦給好友。