No module named MNIST_写给小白的用fashion-mnist入门机器学习和深度学习的简单项目(非常全面!!!)...
這是一個入門機器學習和深度學習的小項目,以fashion-mnist數據為基礎。分別利用機器學習(隨機森林)和深度學習(多層感知機/卷積神經網絡)方法進行訓練。完整的包含數據讀取,數據處理,訓練, 驗證,loss曲線的繪制,訓練過程的可視化,模型推理,混淆矩陣的計算,特征圖可視化等。因為最近總是會帶一些學弟學妹入門,前段時間了解到這個數據集,感覺麻雀雖小五臟俱全,真的是一個很好的入門項目,把這一套搞明白了,做復雜的項目也都是水到渠成。于是乎花了點時間,找了些代碼然后修改了一些,完善了很多的構件,這里當做一個記錄吧!關于代碼已經盡我所能的注釋得很清楚了。完整的代碼我已經上傳github上面,并且readme里面會有使用方法:
DLLXW/Fashion-MNIST?github.com數據集介紹
Fashion-mnist可以看作經典MNIST數據的加強版,號稱計算機視覺領域的Hello, World,這里暫不作過多介紹。下面是數據集的github鏈接:
zalandoresearch/fashion-mnist?github.com用于訓練的圖片有6w張,用于驗證的圖片有1w張,每一個樣本是一張28x28像素的圖像。總共10中服飾類別。所以問題是一個10分類的問題
訓練隨機森林模型
首先利用機器學習方法來進行一下該分類任務,這里選取隨機森林.
需要提前安裝的庫為:sklearn
準備數據:
事實上很多機器學習庫都已經集成了該數據集,也就是說可以在代碼里面直接導入,但這里推薦自己手動下載下來。到上面的github鏈接頁面,找到:
下載數據集并且解壓到自己的目錄下
譬如我的數據格式組織如下:
數據集都解壓到了raw文件夾。下面是實現用隨機森林訓練,驗證,并且打印多分類混淆矩陣和分類信息的代碼。
import這里稍微解釋下多分類的混淆矩陣,對于二分類很好理解,多分類的混淆矩陣其實也是在二分類的基礎上進行的,基本思想是:當研究其中的一類時,其余的各個類別都當做負類。
上面隨機森林是一個經典的機器學習算法,以此作為例子對fashion-mnist數據集進行了分類。如果想換成其它的模型,也很容易,只需要簡單的修改上面的幾行代碼。
Pytorch構建多層感知機和卷積神經網絡進行分類
下面介紹一下如何構建一個多層的神經網絡(多層感知機MLP)已經卷積神經網絡來進行分類。這里用pytorch進行構建。這里給出的網絡的構建代碼,完整代碼請參考github
#這里其實是構建一個最簡單經典的神經網絡模型訓練/驗證
下面給出訓練/驗證部分代碼,主要分為訓練的參數設置;數據集加載;訓練過程和驗證過程;利用tensorboard實現訓練過程的可視化,網絡結構可視化等。
def模型推理
首先說明說明叫推理(infer):前面我們已經訓練好了模型,同時也已經保存好了我們的模型(xxx.pt)。同時我們還在訓練的過程中就驗證(測試)了我們的模型訓練效果;現在我們有了一張新圖片,需要送給模型,讓模型判斷是屬于哪個類別,這個過程就叫做模型的infer。
所以推理的前提是需要從保存的模型里面加載模型,同時要注意的是我們之前保存的只是模型的權重,并未保存模型的結構,所以還得導入前面定義的模型結構,然后將這些權重附著在網絡的結構(骨架)之上。
#最后貼上一些模型訓練可視化的效果圖
完整項目已經上傳github
DLLXW/Fashion-MNIST?github.com在readme里面有使用說明!!!希望大家素質三連
總結
以上是生活随笔為你收集整理的No module named MNIST_写给小白的用fashion-mnist入门机器学习和深度学习的简单项目(非常全面!!!)...的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: centos8 配置 dns_Linux
- 下一篇: scihub只能用doi查吗_同步带轮齿