【深度学习】一文看懂 (Transfer Learning)迁移学习(pytorch实现)
前言
你會發現聰明人都喜歡”偷懶”, 因為這樣的偷懶能幫我們節省大量的時間, 提高效率. 還有一種偷懶是 “站在巨人的肩膀上”. 不僅能看得更遠, 還能看到更多. 這也用來表達我們要善于學習先輩的經驗, 一個人的成功往往還取決于先輩們累積的知識. 這句話, 放在機器學習中, 這就是今天要說的遷移學習了, transfer learning.
什么是遷移學習?
遷移學習通俗來講,就是運用已有的知識來學習新的知識,核心是找到已有知識和新知識之間的相似性,用成語來說就是舉一反三。由于直接對目標域從頭開始學習成本太高,我們故而轉向運用已有的相關知識來輔助盡快地學習新知識。比如,已經會下中國象棋,就可以類比著來學習國際象棋;已經會編寫Java程序,就可以類比著來學習C#;已經學會英語,就可以類比著來學習法語;等等。世間萬事萬物皆有共性,如何合理地找尋它們之間的相似性,進而利用這個橋梁來幫助學習新知識,是遷移學習的核心問題。
為什么需要遷移學習?
現在的機器人視覺已經非常先進了, 有些甚至超過了人類. 99.99%的識別準確率都不在話下. 這樣的成功, 依賴于強大的機器學習技術, 其中, 神經網絡成為了領軍人物. 而 CNN 等, 像人一樣擁有千千萬萬個神經聯結的結構, 為這種成功貢獻了巨大力量. 但是為了更厲害的 CNN, 我們的神經網絡設計, 也從簡單的幾層網絡, 變得越來越多, 越來越多, 越來越多… 為什么會越來越多?
因為計算機硬件, 比如 GPU 變得越來越強大, 能夠更快速地處理龐大的信息. 在同樣的時間內, 機器能學到更多東西. 可是, 不是所有人都擁有這么龐大的計算能力. 而且有時候面對類似的任務時, 我們希望能夠借鑒已有的資源.
如何做遷移學習?
這就好比, Google 和百度的關系, facebook 和人人的關系, KFC 和 麥當勞的關系, 同一類型的事業, 不用自己完全從頭做, 借鑒對方的經驗, 往往能節省很多時間. 有這樣的思路, 我們也能偷偷懶, 不用花時間重新訓練一個無比龐大的神經網絡, 借鑒借鑒一個已經訓練好的神經網絡就行.
比如這樣的一個神經網絡, 我花了兩天訓練完之后, 它已經能正確區分圖片中具體描述的是男人, 女人還是眼鏡. 說明這個神經網絡已經具備對圖片信息一定的理解能力. 這些理解能力就以參數的形式存放在每一個神經節點中. 不巧, 領導下達了一個緊急任務,
要求今天之內訓練出來一個預測圖片里實物價值的模型. 我想這可完蛋了, 上一個圖片模型都要花兩天, 如果要再搭個模型重新訓練, 今天肯定出不來呀.
這時, 遷移學習來拯救我了. 因為這個訓練好的模型中已經有了一些對圖片的理解能力, 而模型最后輸出層的作用是分類之前的圖片, 對于現在計算價值的任務是用不到的, #所以我將最后一層替換掉, 變為服務于現在這個任務的輸出層. #接著只訓練新加的輸出層, 讓理解力保持始終不變. 前面的神經層龐大的參數不用再訓練, 節省了我很多時間, 我也在一天時間內, 將這個任務順利完成.
但并不是所有時候我們都需要遷移學習. 比如神經網絡很簡單, 相比起計算機視覺中龐大的 CNN 或者語音識別的 RNN, 訓練小的神經網絡并不需要特別多的時間, 我們完全可以直接重頭開始訓練. 從頭開始訓練也是有好處的.
如果固定住之前的理解力, 或者使用更小的學習率來更新借鑒來的模型, 就變得有點像認識一個人時的第一印象, 如果遷移前的數據和遷移后的數據差距很大, 或者說我對于這個人的第一印象和后續印象差距很大, 我還不如不要管我的第一印象, 同理, 這時, 遷移來的模型并不會起多大作用, 還可能干擾我后續的決策.
遷移學習的限制
比如說,我們不能隨意移除預訓練網絡中的卷積層。但由于參數共享的關系,我們可以很輕松地在不同空間尺寸的圖像上運行一個預訓練網絡。這在卷積層和池化層和情況下是顯而易見的,因為它們的前向函數(forward function)獨立于輸入內容的空間尺寸。在全連接層(FC)的情形中,這仍然成立,因為全連接層可被轉化成一個卷積層。所以當我們導入一個預訓練的模型時,網絡結構需要與預訓練的網絡結構相同,然后再針對特定的場景和任務進行訓練。
常見的遷移學習方式:
載權重后訓練所有參數
載入權重后只訓練最后幾層參數
載入權重后在原網絡基礎上再添加一層全鏈接層,僅訓練最后一個全鏈接層
衍生
了解了一般的遷移學習玩法后, 我們看看前輩們還有哪些新玩法. 多任務學習, 或者強化學習中的 learning to learn, 遷移機器人對運作形式的理解, 解決不同的任務. 炒個蔬菜, 紅燒肉, 番茄蛋花湯雖然菜色不同, 但是做菜的原則是類似的.
又或者 google 的翻譯模型, 在某些語言上訓練, 產生出對語言的理解模型, 將這個理解模型當做遷移模型在另外的語言上訓練. 其實說白了, 那個遷移的模型就能看成機器自己發明的一種只有它自己才能看懂的語言. 然后用自己的這個語言模型當成翻譯中轉站, 將某種語言轉成自己的語言, 然后再翻譯成另外的語言. 遷移學習的腦洞還有很多, 相信這種站在巨人肩膀上繼續學習的方法, 還會帶來更多有趣的應用.
使用圖像數據進行遷移學習
牛津 VGG 模型(http://www.robots.ox.ac.uk/~vgg/research/very_deep/)
谷歌 Inception模型(https://github.com/tensorflow/models/tree/master/inception)
微軟 ResNet 模型(https://github.com/KaimingHe/deep-residual-networks)
可以在 Caffe Model Zoo(https://github.com/BVLC/caffe/wiki/Model-Zoo)中找到更多的例子,那里分享了很多預訓練的模型。
實例:
注:如何獲取官方的.pth文件,以resnet為例子
import torchvision.models.resnet在腳本中輸入以上代碼,將鼠標對住resnet并按ctrl鍵,發現改變顏色,點擊進入resnet.py腳本,在最開始有url,如下圖所示選擇你要下載的模型,copy到瀏覽器即可,若是覺得慢可以用迅雷等等。
ResNet詳細講解在這篇博文里:ResNet——CNN經典網絡模型詳解(pytorch實現)
#train.pyimport?torch import?torch.nn?as?nn from?torchvision?import?transforms,?datasets import?json import?matplotlib.pyplot?as?plt import?os import?torch.optim?as?optim from?model?import?resnet34,?resnet101 import?torchvision.models.resnetdevice?=?torch.device("cuda:0"?if?torch.cuda.is_available()?else?"cpu") print(device)data_transform?=?{"train":?transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])]),#來自官網參數"val":?transforms.Compose([transforms.Resize(256),#將最小邊長縮放到256transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])])}data_root?=?os.getcwd() image_path?=?data_root?+?"/flower_data/"??#?flower?data?set?pathtrain_dataset?=?datasets.ImageFolder(root=image_path?+?"train",transform=data_transform["train"]) train_num?=?len(train_dataset)#?{'daisy':0,?'dandelion':1,?'roses':2,?'sunflower':3,?'tulips':4} flower_list?=?train_dataset.class_to_idx cla_dict?=?dict((val,?key)?for?key,?val?in?flower_list.items()) #?write?dict?into?json?file json_str?=?json.dumps(cla_dict,?indent=4) with?open('class_indices.json',?'w')?as?json_file:json_file.write(json_str)batch_size?=?16 train_loader?=?torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,?shuffle=True,num_workers=0)validate_dataset?=?datasets.ImageFolder(root=image_path?+?"/val",transform=data_transform["val"]) val_num?=?len(validate_dataset) validate_loader?=?torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size,?shuffle=False,num_workers=0) net?=?resnet34() #?net?=?resnet34(num_classes=5) #?load?pretrain?weightsmodel_weight_path?=?"./resnet34-pre.pth" missing_keys,?unexpected_keys?=?net.load_state_dict(torch.load(model_weight_path),?strict=False)#載入模型參數#?for?param?in?net.parameters(): #?????param.requires_grad?=?False #?change?fc?layer?structureinchannel?=?net.fc.in_features net.fc?=?nn.Linear(inchannel,?5)net.to(device)loss_function?=?nn.CrossEntropyLoss() optimizer?=?optim.Adam(net.parameters(),?lr=0.0001)best_acc?=?0.0 save_path?=?'./resNet34.pth' for?epoch?in?range(3):#?trainnet.train()running_loss?=?0.0for?step,?data?in?enumerate(train_loader,?start=0):images,?labels?=?dataoptimizer.zero_grad()logits?=?net(images.to(device))loss?=?loss_function(logits,?labels.to(device))loss.backward()optimizer.step()#?print?statisticsrunning_loss?+=?loss.item()#?print?train?processrate?=?(step+1)/len(train_loader)a?=?"*"?*?int(rate?*?50)b?=?"."?*?int((1?-?rate)?*?50)print("\rtrain?loss:?{:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100),?a,?b,?loss),?end="")print()#?validatenet.eval()acc?=?0.0??#?accumulate?accurate?number?/?epochwith?torch.no_grad():for?val_data?in?validate_loader:val_images,?val_labels?=?val_dataoutputs?=?net(val_images.to(device))??#?eval?model?only?have?last?output?layer#?loss?=?loss_function(outputs,?test_labels)predict_y?=?torch.max(outputs,?dim=1)[1]acc?+=?(predict_y?==?val_labels.to(device)).sum().item()val_accurate?=?acc?/?val_numif?val_accurate?>?best_acc:best_acc?=?val_accuratetorch.save(net.state_dict(),?save_path)print('[epoch?%d]?train_loss:?%.3f??test_accuracy:?%.3f'?%(epoch?+?1,?running_loss?/?step,?val_accurate))print('Finished?Training')未使用遷移學習VGG16
#train.pyimport?torch.nn?as?nn from?torchvision?import?transforms,?datasets import?json import?os import?torch.optim?as?optim from?model?import?vgg import?torch import?time import?torchvision.models.vgg from?torchvision?import?modelsdevice?=?torch.device("cuda:0"?if?torch.cuda.is_available()?else?"cpu") print(device)#數據預處理,從頭data_transform?=?{"train":?transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))]),"val":?transforms.Compose([transforms.Resize((224,?224)),transforms.ToTensor(),transforms.Normalize((0.5,?0.5,?0.5),?(0.5,?0.5,?0.5))])}data_root?=?os.path.abspath(os.path.join(os.getcwd(),?"../../.."))??#?get?data?root?path image_path?=?data_root?+?"/data_set/flower_data/"??#?flower?data?set?pathhtrain_dataset?=?datasets.ImageFolder(root=image_path?+?"/train",transform=data_transform["train"]) train_num?=?len(train_dataset)#?{'daisy':0,?'dandelion':1,?'roses':2,?'sunflower':3,?'tulips':4} flower_list?=?train_dataset.class_to_idx cla_dict?=?dict((val,?key)?for?key,?val?in?flower_list.items()) #?write?dict?into?json?file json_str?=?json.dumps(cla_dict,?indent=4) with?open('class_indices.json',?'w')?as?json_file:json_file.write(json_str)batch_size?=?20 train_loader?=?torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,?shuffle=True,num_workers=0)validate_dataset?=?datasets.ImageFolder(root=image_path?+?"val",transform=data_transform["val"]) val_num?=?len(validate_dataset) validate_loader?=?torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size,?shuffle=False,num_workers=0)#?test_data_iter?=?iter(validate_loader) #?test_image,?test_label?=?test_data_iter.next()#?model #?=?models.vgg16(pretrained=True)# #?model_name?=?"vgg16" #?net?=?vgg(model_name=model_name,?init_weights=True)#?load?pretrain?weights net?=?models.vgg16(pretrained=False) pre?=?torch.load("./vgg16.pth") net.load_state_dict(pre)for?parma?in?net.parameters():parma.requires_grad?=?Falsenet.classifier?=?torch.nn.Sequential(torch.nn.Linear(25088,?4096),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(4096,?4096),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(4096,?5))#?model_weight_path?=?"./vgg16.pth" #?missing_keys,?unexpected_keys?=?net.load_state_dict(torch.load(model_weight_path),?strict=False)#載入模型參數#?#?for?param?in?net.parameters(): #?#?????param.requires_grad?=?False #?#?change?fc?layer?structure # #?inchannel?=?512 #?net.classifier?=?nn.Linear(inchannel,?5)loss_function?=?torch.nn.CrossEntropyLoss() optimizer?=?optim.Adam(net.classifier.parameters(),?lr=0.001)#?loss_function?=?nn.CrossEntropyLoss() #?optimizer?=?optim.Adam(net.parameters(),?lr=0.0001)?#learn?rate net.to(device)best_acc?=?0.0 #save_path?=?'./{}Net.pth'.format(model_name) save_path?=?'./vgg16Net.pth' for?epoch?in?range(15):#?trainnet.train()running_loss?=?0.0?#統計訓練過程中的平均損失t1?=?time.perf_counter()for?step,?data?in?enumerate(train_loader,?start=0):images,?labels?=?dataoptimizer.zero_grad()#with?torch.no_grad():?#用來消除驗證階段的loss,由于梯度在驗證階段不能傳回,造成梯度的累計outputs?=?net(images.to(device))loss?=?loss_function(outputs,?labels.to(device))??#得到預測值與真實值的一個損失loss.backward()optimizer.step()#更新結點參數#?print?statisticsrunning_loss?+=?loss.item()#?print?train?processrate?=?(step?+?1)?/?len(train_loader)a?=?"*"?*?int(rate?*?50)b?=?"."?*?int((1?-?rate)?*?50)print("\rtrain?loss:?{:^3.0f}%[{}->{}]{:.3f}".format(int(rate?*?100),?a,?b,?loss),?end="")print()print(time.perf_counter()?-?t1)#?validatenet.eval()acc?=?0.0??#?accumulate?accurate?number?/?epochwith?torch.no_grad():#不去跟蹤損失梯度for?val_data?in?validate_loader:val_images,?val_labels?=?val_data#optimizer.zero_grad()outputs?=?net(val_images.to(device))predict_y?=?torch.max(outputs,?dim=1)[1]acc?+=?(predict_y?==?val_labels.to(device)).sum().item()val_accurate?=?acc?/?val_numif?val_accurate?>?best_acc:best_acc?=?val_accuratetorch.save(net.state_dict(),?save_path)print('[epoch?%d]?train_loss:?%.3f??test_accuracy:?%.3f'?%(epoch?+?1,?running_loss?/?step,?val_accurate))print('Finished?Training')densenet121
#train.pyimport?torch import?torch.nn?as?nn from?torchvision?import?transforms,?datasets import?json import?matplotlib.pyplot?as?plt from?model?import?densenet121 import?os import?torch.optim?as?optim import?torchvision.models.densenet import?torchvision.models?as?modelsdevice?=?torch.device("cuda:0"?if?torch.cuda.is_available()?else?"cpu") print(device)data_transform?=?{"train":?transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])]),#來自官網參數"val":?transforms.Compose([transforms.Resize(256),#將最小邊長縮放到256transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])])}data_root?=?os.path.abspath(os.path.join(os.getcwd(),?"../../.."))??#?get?data?root?path image_path?=?data_root?+?"/data_set/flower_data/"??#?flower?data?set?pathtrain_dataset?=?datasets.ImageFolder(root=image_path?+?"train",transform=data_transform["train"]) train_num?=?len(train_dataset)#?{'daisy':0,?'dandelion':1,?'roses':2,?'sunflower':3,?'tulips':4} flower_list?=?train_dataset.class_to_idx cla_dict?=?dict((val,?key)?for?key,?val?in?flower_list.items()) #?write?dict?into?json?file json_str?=?json.dumps(cla_dict,?indent=4) with?open('class_indices.json',?'w')?as?json_file:json_file.write(json_str)batch_size?=?16 train_loader?=?torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,?shuffle=True,num_workers=0)validate_dataset?=?datasets.ImageFolder(root=image_path?+?"/val",transform=data_transform["val"]) val_num?=?len(validate_dataset) validate_loader?=?torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size,?shuffle=False,num_workers=0)#遷移學習 net?=?models.densenet121(pretrained=False) model_weight_path="./densenet121-a.pth" missing_keys,?unexpected_keys?=?net.load_state_dict(torch.load(model_weight_path),?strict=?False)inchannel?=?net.classifier.in_features net.classifier?=?nn.Linear(inchannel,?5) net.to(device)loss_function?=?nn.CrossEntropyLoss() optimizer?=?optim.Adam(net.parameters(),?lr=0.0001)#普通#?model_name?=?"densenet121" #?net?=?densenet121(model_name=model_name,?num_classes=5)best_acc?=?0.0 save_path?=?'./densenet121.pth' for?epoch?in?range(12):#?trainnet.train()running_loss?=?0.0for?step,?data?in?enumerate(train_loader,?start=0):images,?labels?=?dataoptimizer.zero_grad()logits?=?net(images.to(device))loss?=?loss_function(logits,?labels.to(device))loss.backward()optimizer.step()#?print?statisticsrunning_loss?+=?loss.item()#?print?train?processrate?=?(step+1)/len(train_loader)a?=?"*"?*?int(rate?*?50)b?=?"."?*?int((1?-?rate)?*?50)print("\rtrain?loss:?{:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100),?a,?b,?loss),?end="")print()#?validatenet.eval()acc?=?0.0??#?accumulate?accurate?number?/?epochwith?torch.no_grad():for?val_data?in?validate_loader:val_images,?val_labels?=?val_dataoutputs?=?net(val_images.to(device))??#?eval?model?only?have?last?output?layer#?loss?=?loss_function(outputs,?test_labels)predict_y?=?torch.max(outputs,?dim=1)[1]acc?+=?(predict_y?==?val_labels.to(device)).sum().item()val_accurate?=?acc?/?val_numif?val_accurate?>?best_acc:best_acc?=?val_accuratetorch.save(net.state_dict(),?save_path)print('[epoch?%d]?train_loss:?%.3f??test_accuracy:?%.3f'?%(epoch?+?1,?running_loss?/?step,?val_accurate))print('Finished?Training')使用
注:部分圖片來自于莫凡python
總結
以上是生活随笔為你收集整理的【深度学习】一文看懂 (Transfer Learning)迁移学习(pytorch实现)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【职场】公司利益和个人利益,永远不可能完
- 下一篇: 学完可以解决90%以上的数据分析问题-利