CV算法复现(分类算法2/6):AlexNet(2012年 Hinton组)
生活随笔
收集整理的這篇文章主要介紹了
CV算法复现(分类算法2/6):AlexNet(2012年 Hinton组)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考。
致謝:霹靂吧啦Wz:https://space.bilibili.com/18161609
目錄
致謝:霹靂吧啦Wz:https://space.bilibili.com/18161609
1 本次要點(diǎn)
1.1 深度學(xué)習(xí)理論
1.2 pytorch框架語(yǔ)法
2 網(wǎng)絡(luò)簡(jiǎn)介
2.1 歷史意義
2.2 網(wǎng)絡(luò)亮點(diǎn)
2.3 網(wǎng)絡(luò)架構(gòu)
3 代碼結(jié)構(gòu)
3.1 model.py
3.2 train.py
3.3 predict.py
3.4 split_data.py
1 本次要點(diǎn)
1.1 深度學(xué)習(xí)理論
- 經過一次卷積操作后,圖像新尺寸計算公式:N = (W − F + 2P) / S + 1(W 為輸入尺寸,F 為卷積核大小,P 為 padding,S 為步長)。(如果 padding [p1, p2] 中 p1、p2 不相等,那么公式中的 2P 就變為 P1 + P2。)(如果結果值不是整數,pytorch 中會自動忽略最后一行以及最后一列,以保證 N 為整數。)
- ?
1.2 pytorch框架語(yǔ)法
- pytorch可以自定義網絡權重的初始化方法(見model.py)。
-
pata = list(net.parameters())  # 查看模型參數
2 網(wǎng)絡(luò)簡(jiǎn)介
2.1 歷史意義
- 2012年ImageNet圖像分類冠軍網絡,分類準確率由傳統的 70%+ 直接提升到 80%+。在那年之后,深度學習開始迅速發展。
2.2 網(wǎng)絡(luò)亮點(diǎn)
- 首次利用 GPU 進(jìn)行網(wǎng)絡(luò)加速訓(xùn)練。
- 使用了 ReLU 激活函數(shù),而不是傳統(tǒng)的 Sigmoid 激活函數(shù)以及 Tanh 激活函數(shù)。
- 在前兩層的全連接層中使用了 Dropout 隨機(jī)失活神經(jīng)元操作,以減少過(guò)擬合。
2.3 網(wǎng)絡(luò)架構(gòu)
備注:padding: [1, 2] 即圖像最左邊緣加1列0,最右邊緣加2列0;圖像最上邊緣加1行0,圖像最下邊緣加2行0。
3 代碼結(jié)構(gòu)
- model.py
- train.py
- predict.py
- split_data.py(數(shù)據(jù)集劃分)
3.1 model.py
import torch.nn as nn
import torch"""
本AlexNet復(fù)現(xiàn)相比原論文,每層的卷積核個(gè)數(shù)減半。
"""
class AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()# nn.Sequential():將一系列層結(jié)構(gòu)進(jìn)行打包。省去每一層都用一個(gè)變量去表示。self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2), # input[3, 224, 224] output[48, 55, 55]nn.ReLU(inplace=True), #inplace:通過(guò)增加計(jì)算量來(lái)降低內(nèi)存使用,從而可以載入更大模型(默認(rèn)False)。nn.MaxPool2d(kernel_size=3, stride=2), # output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2), # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1), # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1), # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1), # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048), # 輸入:128通道*6*6(特征圖大小)(到此之前會(huì)拉成1維)nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1) # torch中順序[B,C,H,W],start_dim=1就是將C維度拉平。x = self.classifier(x)return x# 初始化權(quán)重方式(框架有默認(rèn),如果要自定義可如下方式寫(xiě))def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)
3.2 train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

# Dataset: flower classification (5 classes).


def main():
    """Train AlexNet on the flower dataset, validating after every epoch
    and saving the weights whenever validation accuracy improves."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),  # random horizontal flip (augmentation)
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # os.getcwd() is the current absolute path; "../.." climbs two levels up.
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    # Invert the mapping so a predicted index can be looked up directly.
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    # FIX: original message had a typo ("fot validation").
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # Sanity-check snippet for inspecting a validation batch (debug only):
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())  # inspect model parameters (debugging)
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = './AlexNet.pth'
    best_acc = 0.0
    for epoch in range(10):
        # --- training phase ---
        net.train()  # enables training behaviour of dropout / batch-norm layers
        running_loss = 0.0
        t1 = time.perf_counter()
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()       # back-propagation
            optimizer.step()      # update parameters

            # accumulate statistics and print a progress bar
            running_loss += loss.item()
            rate = (step + 1) / len(train_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
        print()
        print(time.perf_counter() - t1)

        # --- validation phase ---
        net.eval()  # disables dropout / switches batch-norm to eval mode
        acc = 0.0   # number of correctly classified validation samples
        with torch.no_grad():  # no gradient computation during validation
            for val_data in validate_loader:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
            val_accurate = acc / val_num
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
            # FIX: average the loss over the number of batches. The original
            # divided by `step` (the last batch *index*), which is off by one
            # and raises ZeroDivisionError with a single-batch loader.
            print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
                  (epoch + 1, running_loss / len(train_loader), val_accurate))

    print('Finished Training')


if __name__ == '__main__':
    main()
訓(xùn)練結(jié)果:
3.3 predict.py
import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# Preprocessing: resize to the network's input size and normalise exactly
# as during training.
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict (index -> class-name mapping written by train.py)
try:
    # FIX: the original opened the file without ever closing it; a context
    # manager guarantees the handle is released.
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()  # disable dropout for inference
with torch.no_grad():  # no gradients needed for prediction
    # torch.squeeze() drops the size-1 batch dimension from the output
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)  # turn logits into a probability distribution
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()
輸出:
3.4 split_data.py
import os
from shutil import copy, rmtree
import random

# Usage:
# (1) create a new folder "flower_data" under the data_set folder
# (2) download the flower classification dataset:
#     http://download.tensorflow.org/example_images/flower_photos.tgz
# (3) extract the dataset into the flower_data folder
# (4) run the "split_data.py" script to split it automatically into a
#     training set (train) and a validation set (val)
#     ├── flower_data
#            ├── flower_photos (extracted dataset folder, 3670 samples)
#            ├── train (generated training set, 3306 samples)
#            └── val (generated validation set, 364 samples)


def mk_file(file_path: str):
    """Create an empty directory; if one already exists at the path,
    delete it first and recreate it."""
    if os.path.exists(file_path):
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    """Split flower_photos into train/val folder trees (10% validation)."""
    # fixed seed so the split is reproducible
    random.seed(0)

    # fraction of each class routed to the validation set
    split_rate = 0.1

    # points at the extracted flower_photos folder
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # build the training-set folder tree, one subfolder per class
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        mk_file(os.path.join(train_root, cla))

    # build the validation-set folder tree, one subfolder per class
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # randomly sample the validation subset for this class
        eval_index = random.sample(images, k=int(num * split_rate))
        for index, image in enumerate(images):
            src = os.path.join(cla_path, image)
            # copy sampled files into val, the rest into train
            dst_root = val_root if image in eval_index else train_root
            copy(src, os.path.join(dst_root, cla))
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()
輸出:
總結(jié)
以上是生活随笔為你收集整理的CV算法复现(分类算法2/6):AlexNet(2012年 Hinton组)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: CV算法复现(分类算法1/6):LeNe
- 下一篇: CV算法复现(分类算法3/6):VGG(