KNN识别手写数字MNIST
生活随笔
收集整理的這篇文章主要介紹了
KNN识别手写数字MNIST
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
使用sklearn的KNN實現類,neighbors.KNeighborsClassifier,模型精度達到96.7%
數據集可以在線下載,也可以手動下載:
mnist數據集地址:https://www.lanzouw.com/iXDefxnl3fa
import torch, torchvision from sklearn import neighbors#加載mnist數據集 train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, download=True) test_dataset = torchvision.datasets.MNIST(root='./data/', train=False, download=False)#獲取mnist數據集, 并進行歸一化,然后將(28*,28)的圖片轉成(1, 784)向量 train_data = (train_dataset.data/255).view(-1, 784) train_label = train_dataset.targets#加載測試集 test_data = (test_dataset.data/255).view(-1, 784) test_label = test_dataset.targets#訓練模型 model = neighbors.KNeighborsClassifier(n_neighbors=8) model.fit(train_data, train_label)#模型預測 predict = model.predict(test_data)#使用sklearn的score函數算精度, acc = model.score(test_data, test_label) print(acc)?
總結
以上是生活随笔為你收集整理的KNN识别手写数字MNIST的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Pytorch和Numpy的默认类型
- 下一篇: seaborn常用图