Recognizing Handwriting in Java: Introduction to Neural Networks, Chapter 6: Recognizing Handwritten Digits
Preface
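Neural networks are a distinctive way of solving problems. This book starts from the simplest cases and works step by step toward the basic algorithms of neural networks, in the plainest language it can. It avoids intimidating terminology and mathematical concepts wherever possible, and puts each algorithm into practice by building runnable Java programs.
Follow the WeChat account "邏輯編程" (Logic Programming) for more information about this book.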
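In this chapter we tackle a real problem: handwriting recognition. We will recognize handwritten digits like those in the figure below.
Before we start, we need suitable data. We cannot simply generate it ourselves as in the earlier chapters, because we do not know how to write an algorithm that produces handwriting. Training to a useful accuracy also takes a fairly large amount of data. Fortunately, earlier researchers have already collected valuable training material. MNIST is one widely used dataset. Besides being free to use, its website lists the recognition accuracy that others have achieved, which gives us a good benchmark. MNIST contains a training set and a test set, written by different groups of people.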
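MNIST website: http://yann.lecun.com/exdb/mnist/
The dataset is stored in a dedicated binary file format rather than an ordinary image format. Each image consists of 28*28 pixels, and each pixel is one byte giving its grayscale level. The MNIST website describes the format in detail.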
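To make the binary layout concrete, here is a minimal sketch of reading the header of an image file and converting one pixel byte to a gray level. The class name IdxHeaderSketch and the helper fakeImageFile are hypothetical and exist only for illustration; the in-memory byte array stands in for a real, already decompressed MNIST file. The magic number 0x00000803 and the order of the header fields match the loader used in this chapter.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

class IdxHeaderSketch {

    /** Reads the 16-byte image-file header and returns {count, rows, cols}. */
    static int[] readImageHeader(byte[] file) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(file))) {
            int magic = in.readInt(); // DataInputStream reads big-endian 32-bit ints
            if (magic != 0x00000803) {
                throw new IllegalArgumentException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            return new int[] {count, rows, cols};
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
    }

    /** Java bytes are signed, so mask to 0..255 before scaling to [0, 1]. */
    static double toGray(byte pixel) {
        return (pixel & 0xff) / 255.0;
    }

    /** Hypothetical stand-in for a real file: a header followed by all-zero pixels. */
    static byte[] fakeImageFile(int count) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            out.writeInt(0x00000803);
            out.writeInt(count);
            out.writeInt(28);
            out.writeInt(28);
            out.write(new byte[count * 28 * 28]);
            out.flush();
            return bytes.toByteArray();
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
    }

    public static void main(String[] args) {
        int[] header = readImageHeader(fakeImageFile(2));
        System.out.println(header[0] + " images of " + header[1] + "x" + header[2] + " pixels");
        System.out.println("Gray level of byte 0xff: " + toGray((byte) 0xff));
    }
}
```

The mask in toGray matters because a raw pixel value above 127 would otherwise be read as a negative number.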
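We write a class that reads the dataset and provides methods returning the requested training or test data. We will not analyze that code; it is simply listed further below for readers to use. Before running it, download the data files and keep them in GZIP format. When run, it randomly picks 20 samples and writes them out as PNG images so that readers can inspect and verify the data themselves.
Next we write a test class to recognize the handwriting. We train our neural network repeatedly on the 60000 MNIST training samples, and after each epoch we measure the recognition rate on the 10000 MNIST test samples.
Here is the code: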
package com.luoxq.ann;

import java.util.Arrays;

public class MnistTest {

    public static void main(String... args) {
        // 784 input pixels feeding 10 output neurons, one per digit.
        int[] shape = {28 * 28, 10};
        NeuralNetwork nn = new NeuralNetwork(shape);
        Mnist mnist = new Mnist();
        mnist.load();
        mnist.shuffle();
        System.out.println("Shape: " + Arrays.toString(shape));
        System.out.println("Initial correct rate: " + test(nn, mnist));
        int epochs = 1000;
        double rate = 0.5;
        System.out.println("Learning rate: " + rate);
        System.out.println("Epoch,Time,Correctness\n----------------------");
        long time = System.currentTimeMillis();
        Mnist.Data[] data = mnist.getTrainingSlice(0, 60000);
        for (int epoch = 1; epoch <= epochs; epoch++) {
            // One epoch: train once on every training sample.
            for (int sample = 0; sample < data.length; sample++) {
                nn.train(data[sample].input, data[sample].output, rate);
            }
            long seconds = (System.currentTimeMillis() - time) / 1000;
            System.out.println(epoch + ", " + seconds + ", " + test(nn, mnist));
        }
    }

    // Returns how many of the 10000 test samples are classified correctly.
    private static int test(NeuralNetwork nn, Mnist mnist) {
        int correct = 0;
        Mnist.Data[] data = mnist.getTestSlice(0, 10000);
        for (int sample = 0; sample < data.length; sample++) {
            if (max(nn.f(data[sample].input)) == data[sample].label) {
                correct++;
            }
        }
        return correct;
    }

    // Returns the index of the largest value, i.e. the predicted digit.
    private static int max(double[] d) {
        double max = d[0];
        int idx = 0;
        for (int i = 1; i < d.length; i++) {
            if (max < d[i]) {
                max = d[i];
                idx = i;
            }
        }
        return idx;
    }
}
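Let's first try a single-layer network with 10 neurons. The result is surprisingly good: by the fourth epoch we exceed 90% accuracy (9029 of 10000 test samples). A single-layer network does little more than gather simple statistics on the pixel distribution of each digit, so such a high recognition rate is remarkable. Beyond 90%, further training has little effect: accuracy fluctuates between roughly 90.2% and 90.8%, so the network is saturated. We need a different approach.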
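To illustrate why a single layer amounts to per-digit pixel statistics, here is a minimal sketch of what such a layer might compute. The NeuralNetwork class itself is not shown in this chapter, so the sigmoid activation and the weights[neuron][input] layout are assumptions, and the class name SingleLayerSketch is hypothetical. Each output neuron is simply a weighted sum of the input pixels, and the predicted digit is the neuron with the largest output.

```java
class SingleLayerSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** One fully connected layer: out[n] = sigmoid(bias[n] + sum_i w[n][i] * x[i]). */
    static double[] forward(double[][] weights, double[] biases, double[] input) {
        double[] out = new double[weights.length];
        for (int n = 0; n < weights.length; n++) {
            double z = biases[n];
            for (int i = 0; i < input.length; i++) {
                z += weights[n][i] * input[i];
            }
            out[n] = sigmoid(z);
        }
        return out;
    }

    /** Index of the largest output, interpreted as the predicted class. */
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Toy example with 2 inputs and 3 output neurons instead of 784 and 10.
        double[][] weights = {{1, 0}, {0, 1}, {-1, -1}};
        double[] biases = {0, 0, 0};
        double[] input = {0.2, 0.9};
        System.out.println("Predicted class: " + argmax(forward(weights, biases, input)));
    }
}
```

With no hidden layer, each neuron can only weigh individual pixels for or against its digit; it cannot combine pixels into strokes or shapes, which is one plausible reason the accuracy levels off.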
Shape: [784, 10]
Initial correct rate: 1373
Learning rate: 0.5
Epoch,Time,Correctness
----------------------
1, 4, 6429
2, 8, 7663
3, 13, 8963
4, 17, 9029
5, 22, 9016
6, 27, 9062
7, 31, 9063
8, 36, 9066
9, 41, 9072
10, 45, 9057
11, 50, 9084
12, 55, 9072
13, 61, 9062
14, 66, 9050
15, 70, 9077
16, 75, 9052
17, 79, 9068
18, 84, 9055
19, 88, 9060
20, 93, 9064
So let's try a three-layer network. After trying several hidden-layer sizes and learning rates, I found the following fairly good combination:
Shape: [784, 50, 10]
Initial correct rate: 944
Learning rate: 1.0
Epoch,Time,Correctness
----------------------
1, 24, 7459
2, 59, 9232
3, 99, 9313
4, 131, 9379
5, 153, 9412
6, 176, 9443
7, 200, 9412
8, 226, 9447
9, 248, 9462
10, 269, 9461
11, 290, 9465
12, 314, 9493
13, 343, 9477
14, 368, 9499
15, 392, 9502
16, 420, 9509
17, 447, 9482
18, 472, 9508
19, 496, 9491
20, 518, 9536
21, 545, 9523
22, 569, 9549
23, 593, 9527
24, 618, 9527
25, 643, 9520
26, 667, 9513
27, 689, 9507
28, 712, 9527
29, 734, 9501
30, 758, 9521
31, 781, 9508
32, 804, 9534
33, 827, 9534
34, 850, 9550
35, 875, 9569
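We quickly pass 95% accuracy: the network first exceeds it at epoch 15 (9502 of 10000) and reaches 95.69% at epoch 35. Multi-layer networks clearly have an advantage over single-layer ones. This accuracy is still short of production quality, but it is a good result for a first attempt.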
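The extra accuracy comes at a cost that can be estimated by counting trainable parameters. The sketch below assumes fully connected layers with one weight per connection and one bias per non-input neuron. The NeuralNetwork class is not shown in this chapter, so that layout is an assumption, and the class name ParamCountSketch is hypothetical.

```java
class ParamCountSketch {

    /**
     * Counts weights and biases of a fully connected network whose layer
     * sizes are given by shape, e.g. {784, 50, 10}.
     */
    static int countParameters(int[] shape) {
        int total = 0;
        for (int layer = 1; layer < shape.length; layer++) {
            int weights = shape[layer - 1] * shape[layer]; // one per connection
            int biases = shape[layer];                     // one per neuron
            total += weights + biases;
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println("[784, 10]:     " + countParameters(new int[] {784, 10}));
        System.out.println("[784, 50, 10]: " + countParameters(new int[] {784, 50, 10}));
    }
}
```

Under these assumptions the [784, 10] network has 7850 parameters and the [784, 50, 10] network has 39760, about five times as many. That is broadly in line with the logs above, where an epoch takes roughly 4 to 5 seconds for the single-layer network and roughly 25 seconds for the three-layer one.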
Below is the source code for reading the MNIST files:
package com.luoxq.ann;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Random;
import java.util.zip.GZIPInputStream;

/**
 * Created by luoxq on 17/4/15.
 */
public class Mnist {

    static class Data {
        public byte[] data;
        public int label;
        public double[] input;
        public double[] output;
    }

    public static void main(String... args) throws Exception {
        Mnist mnist = new Mnist();
        mnist.load();
        System.out.println("Data loaded.");
        Random rand = new Random(System.nanoTime());
        for (int i = 0; i < 20; i++) {
            int idx = rand.nextInt(60000);
            Data d = mnist.getTrainingData(idx);
            BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    img.setRGB(x, y, toRgb(d.data[y * 28 + x]));
                }
            }
            File output = new File(i + "_" + d.label + ".png");
            if (!output.exists()) {
                output.createNewFile();
            }
            ImageIO.write(img, "png", output);
        }
    }

    static int toRgb(byte bb) {
        int b = (255 - (0xff & bb));
        return (b << 16 | b << 8 | b) & 0xffffff;
    }

    Data[] trainingSet;
    Data[] testSet;

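    public void shuffle() {
        Random rand = new Random();
        // Fisher-Yates shuffle: swap each element with a random one at or before it.
        for (int i = trainingSet.length - 1; i > 0; i--) {
            int x = rand.nextInt(i + 1);
            Data d = trainingSet[i];
            trainingSet[i] = trainingSet[x];
            trainingSet[x] = d; // the original wrote trainingSet[i] here, which lost elements
        }
    }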

    public Data getTrainingData(int idx) {
        return trainingSet[idx];
    }

    public Data[] getTrainingSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(trainingSet, start, ret, 0, count);
        return ret;
    }

    public Data getTestData(int idx) {
        return testSet[idx];
    }

    public Data[] getTestSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(testSet, start, ret, 0, count);
        return ret;
    }

    public void load() {
        trainingSet = load("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz");
        testSet = load("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz");
        if (trainingSet.length != 60000 || testSet.length != 10000) {
            throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);
        }
    }

    private Data[] load(String imgFile, String labelFile) {
        byte[][] images = loadImages(imgFile);
        byte[] labels = loadLabels(labelFile);
        if (images.length != labels.length) {
            throw new RuntimeException("Images and labels don't match: " + imgFile + " " + labelFile);
        }
        int len = images.length;
        Data[] data = new Data[len];
        for (int i = 0; i < len; i++) {
            data[i] = new Data();
            data[i].data = images[i];
            data[i].label = 0xff & labels[i];
            data[i].input = dataToInput(images[i]);
            data[i].output = labelToOutput(labels[i]);
        }
        return data;
    }

    // One-hot encoding: the element at index `label` is 1, all others are 0.
    private double[] labelToOutput(byte label) {
        double[] o = new double[10];
        o[label] = 1;
        return o;
    }

    // Scales unsigned pixel bytes (0..255) to doubles in [0, 1].
    private double[] dataToInput(byte[] b) {
        double[] d = new double[b.length];
        for (int i = 0; i < b.length; i++) {
            d[i] = (b[i] & 0xff) / 255.0;
        }
        return d;
    }

    private byte[][] loadImages(String imgFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000803) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            if (rows != 28 || cols != 28) {
                throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);
            }
            byte[][] data = new byte[count][rows * cols];
            for (int i = 0; i < count; i++) {
                in.readFully(data[i]);
            }
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + imgFile, ex);
        }
    }

    private byte[] loadLabels(String labelFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000801) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            byte[] data = new byte[count];
            in.readFully(data);
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + labelFile, ex);
        }
    }
}