matlab卷积神经网络的创建与图片识别
1、Deep Network Designer工具箱使用介紹
2、神經網絡的GPU訓練
3、預測與分類
一、Deep Network Designer工具箱使用介紹
相比BP、GRNN、RBF、NARX神經網絡的簡單結構,深度神經網絡結構更加復雜,比如卷積神經網絡CNN,長短時序神經網絡LSTM等,matlab集成了深度學習工具箱,可輸入如下指令調用:
Deep Network Designer
可以使用別人的網絡架構也可以自己創建,點擊“空白網絡”創建。如下圖最左側是常用的各種網絡層,可根據文獻上的網絡結構或者自己設計的結構任意組合,具體模塊參數雙擊進行設計,前提是網絡數據維度沒有錯誤。如圖所示,為作者創建的用于RGB圖像分類的卷積神經網絡CNN結構,具體設計過程后續出。構建完成,點擊“分析”可查看是否有錯誤,無錯誤之后可通過“導出”得到網絡架構的代碼即layers。
layers = [imageInputLayer([120 160 3],"Name","imageinput") %輸入相機幀convolution2dLayer([3 3],15,"Name","conv_1","Padding","same") %卷積層reluLayer("Name","relu_1")averagePooling2dLayer([2 2],"Name","avgpool2d_1","Stride",[2 2]) %平均池化層convolution2dLayer([3 3],15,"Name","conv_2","Padding","same") %卷積層reluLayer("Name","relu_2")averagePooling2dLayer([2 2],"Name","avgpool2d_2","Stride",[2 2]) %平均池化層convolution2dLayer([3 3],12,"Name","conv_3","Padding","same") %卷積層reluLayer("Name","relu_3")averagePooling2dLayer([2 2],"Name","avgpool2d_3","Stride",[2 2]) %平均池化層dropoutLayer(0.3,"Name","dropout_2") %隨機失活,失活率為30%fullyConnectedLayer(256,"Name","fc_1","WeightL2Factor",6) %全連接層reluLayer("Name","relu_4")dropoutLayer(0.3,"Name","dropout_1") %隨機失活,失活率為30%fullyConnectedLayer(20,"Name","fc_2","WeightL2Factor",6) %全連接層softmaxLayer("Name","softmax") classificationLayer("Name","classoutput")]; %輸出層創建一個m程序,將此代碼復制進去。
二、神經網絡的GPU訓練
網絡構建好以后,就是編寫訓練的代碼,主要過程分為:讀取數據集、歸一化(可有可無)、劃分訓練集與測試集、反歸一化(可有可無)、訓練配置與訓練。作者此處給出圖像分類的代碼,詳細過程可見代碼注釋。
%% 工具箱導出的網絡結構 layers = [imageInputLayer([120 160 3],"Name","imageinput")convolution2dLayer([3 3],15,"Name","conv_1","Padding","same")reluLayer("Name","relu_1")averagePooling2dLayer([2 2],"Name","avgpool2d_1","Stride",[2 2])convolution2dLayer([3 3],15,"Name","conv_2","Padding","same")reluLayer("Name","relu_2")averagePooling2dLayer([2 2],"Name","avgpool2d_2","Stride",[2 2])convolution2dLayer([3 3],12,"Name","conv_3","Padding","same")reluLayer("Name","relu_3")averagePooling2dLayer([2 2],"Name","avgpool2d_3","Stride",[2 2])dropoutLayer(0.3,"Name","dropout_2")fullyConnectedLayer(256,"Name","fc_1","WeightL2Factor",6)reluLayer("Name","relu_4")dropoutLayer(0.3,"Name","dropout_1")fullyConnectedLayer(20,"Name","fc_2","WeightL2Factor",6)softmaxLayer("Name","softmax")classificationLayer("Name","classoutput")]; %% 讀取數據集 digitDatasetPath=fullfile('.\'); %打開數據集文件夾路徑 % 注釋:此路徑下放有30個文件夾,每個文件夾為一個類別,每個文件夾里面有等數量的圖片,這些圖片都已經預處理。 imds=imageDatastore(digitDatasetPath,...'IncludeSubfolders',true,'LabelSource','foldernames'); %讀取圖片數據集,標簽Label設置為文件名。 % 注釋:每個文件夾的名字即為分類的類別標簽 %% 劃分數據集(訓練集和驗證集) numTrainFiles=round(2/3*30); % 20為類別文件夾數量,測試集作者放在另外的地方,訓練時候只需要訓練集和驗證集。 [imdsTrain,imdsValidation]=splitEachLabel(imds,numTrainFiles,'randomize'); % 隨機劃分每類文件夾下的訓練集和驗證集%若數據圖片大小與網絡輸入不一樣,可通過下面三行代碼處理。若相同可去掉此三行代碼 inputSize=layers(1).InputSize; %讀取網絡輸入層的輸入圖像的大小尺寸 imdsTrain=augmentedImageDatastore(inputSize(1:2),imdsTrain); %整合訓練集的尺寸1與inputSize的第一二個維度相同。 augimdsValidation=augmentedImageDatastore(inputSize(1:2),imdsValidation); %% 訓練配置 ExecutionEnvironment='gpu'; %此處設置用GPU或者CPU訓練,建議GPU快 %具體一些需要改動的配置說明,可以上matlab官網查看trainingOptions函數文檔 options_train=trainingOptions('sgdm',...'MaxEpochs',100,... % 訓練輪數為65次'InitialLearnRate',0.0001,... %初始學習率'Verbose',true,'MiniBatchSize',10,... 'LearnRateSchedule','piecewise',...'LearnRateDropFactor',0.6,...'LearnRateDropPeriod',5,...'Plots','training-progress',...'ValidationData',augimdsValidation,...'ValidationFrequency',10,...'ExecutionEnvironment',ExecutionEnvironment); net=trainNetwork(imdsTrain,layers,options_train); %開始訓練 save('train.mat'); %保存訓練完的網絡模型為train.mat。三、預測與分類
此處我們是屬于分類任務,所以在第一步創建網絡最后一層模塊是分類塊,如果是數據回歸即數據預測則不同,本文不詳細說明。下面給出利用已訓練好的網絡模型進行分類的代碼。再創建一個m程序用來放分類的代碼:
load('train.mat'); %先下載同一文件夾下之前訓練好的模型 x=imread('1.jpg'); %讀取一張事先準備好的圖片1,命名為x YPred=classify(net,x); %用訓練好的網絡net對x進行分類識別 ,分類結果為YPred sprintf('測試結果為%s',YPred) 將結果YPred顯示。注意這個YPred是一個奇怪的數據類型categorical %為了后續GUI界面的方便使用,作者的數據集名字即類別lable都是數字哦 %下面就是將categorical數據類型轉化為矩陣mat類型,命名為nn。 M=string(YPred); nn=double(M);結語
讀者可能需要一些圖片的預處理和數據增強,視頻幀讀取,GUI的網絡嵌入與端到端識別等程序,可以參考其他博主的文章,作者后續閑暇之余有可能會出相關博客。本文神經網絡和識別的一些原理算法,后續博客直接給出本科畢設論文以供參考。
總結
以上是生活随笔為你收集整理的matlab卷积神经网络的创建与图片识别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 最常用的两种C++序列化方案的使用心得(
- 下一篇: go 打印当前时间