004-2-拟合,drop-out
?
# Fully connected MNIST classifier with dropout (TensorFlow 1.x graph API).
#
# The network has three tanh hidden layers (2000, 2000 and 1000 units), each
# followed by dropout, and a 10-way output layer. After every epoch the script
# prints the accuracy on the test set and on the training set so the gap
# between the two (over-fitting) can be compared for different keep_prob
# values.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the data set with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Size of each mini-batch.
batch_size = 100
# Number of mini-batches per epoch.
n_batch = mnist.train.num_examples // batch_size

# Placeholders: flattened 784-value images, one-hot labels over 10 classes.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# Fraction of neurons kept active by the dropout layers.
keep_prob = tf.placeholder(tf.float32)

# Network. Weights are drawn from a truncated normal distribution with a
# standard deviation of 0.1; biases start at 0.1.
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop_out = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop_out, W2) + b2)
L2_drop_out = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop_out, W3) + b3)
L3_drop_out = tf.nn.dropout(L3, keep_prob)

W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
# Keep the raw (pre-softmax) scores separately: the cross-entropy op below
# applies softmax itself, so it must be fed these and not `prediction`.
logits = tf.matmul(L3_drop_out, W4) + b4
prediction = tf.nn.softmax(logits)

# Alternative quadratic cost, kept for reference:
#   loss = tf.reduce_mean(tf.square(y - prediction))
# Softmax cross-entropy loss, computed from the unscaled logits. Passing the
# already-softmaxed `prediction` here would apply softmax twice.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)
)
# Plain gradient descent with a learning rate of 0.2.
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

# Op that initialises all variables.
init = tf.global_variables_initializer()

# Accuracy: compare the index of the largest label entry with the index of
# the largest predicted probability, then average the matches.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)

    # Evaluation inputs never change between epochs; dropout is switched off
    # (keep_prob = 1.0) when measuring accuracy.
    test_feed = {x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0}
    train_feed = {x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0}

    for epoch in range(31):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Train with 70% of the neurons kept active.
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7})

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        train_acc = sess.run(accuracy, feed_dict=train_feed)
        print("Iter"+str(epoch+1)+",Testing accuracy-"+str(test_acc)+",Train accuracy-"+str(train_acc))
keep_prob = 1.0 时:
Iter1,Testing accuracy-0.944,Train accuracy-0.958545
Iter2,Testing accuracy-0.958,Train accuracy-0.974691
Iter3,Testing accuracy-0.9621,Train accuracy-0.982855
Iter4,Testing accuracy-0.9652,Train accuracy-0.986455
Iter5,Testing accuracy-0.968,Train accuracy-0.988364
Iter6,Testing accuracy-0.9683,Train accuracy-0.989855
Iter7,Testing accuracy-0.9694,Train accuracy-0.990982
Iter8,Testing accuracy-0.9687,Train accuracy-0.991636
Iter9,Testing accuracy-0.9691,Train accuracy-0.992255
Iter10,Testing accuracy-0.9697,Train accuracy-0.9926
Iter11,Testing accuracy-0.9697,Train accuracy-0.992909
Iter12,Testing accuracy-0.9705,Train accuracy-0.993236
Iter13,Testing accuracy-0.9701,Train accuracy-0.993309
Iter14,Testing accuracy-0.9707,Train accuracy-0.993527
Iter15,Testing accuracy-0.9705,Train accuracy-0.993691
Iter16,Testing accuracy-0.9709,Train accuracy-0.993891
Iter17,Testing accuracy-0.9707,Train accuracy-0.993982
Iter18,Testing accuracy-0.9716,Train accuracy-0.994036
Iter19,Testing accuracy-0.9717,Train accuracy-0.994236
Iter20,Testing accuracy-0.9722,Train accuracy-0.994364
Iter21,Testing accuracy-0.9716,Train accuracy-0.994436
Iter22,Testing accuracy-0.972,Train accuracy-0.994509
Iter23,Testing accuracy-0.9722,Train accuracy-0.9946
Iter24,Testing accuracy-0.972,Train accuracy-0.994636
Iter25,Testing accuracy-0.9723,Train accuracy-0.994709
Iter26,Testing accuracy-0.9723,Train accuracy-0.994836
Iter27,Testing accuracy-0.9722,Train accuracy-0.994891
Iter28,Testing accuracy-0.9727,Train accuracy-0.994964
Iter29,Testing accuracy-0.9724,Train accuracy-0.995091
Iter30,Testing accuracy-0.9725,Train accuracy-0.995164
Iter31,Testing accuracy-0.9725,Train accuracy-0.995182
?
keep_prob = 0.7 时:
Iter1,Testing accuracy-0.9187,Train accuracy-0.912709
Iter2,Testing accuracy-0.9281,Train accuracy-0.923782
Iter3,Testing accuracy-0.9357,Train accuracy-0.935236
Iter4,Testing accuracy-0.9379,Train accuracy-0.940855
Iter5,Testing accuracy-0.9441,Train accuracy-0.944564
Iter6,Testing accuracy-0.9463,Train accuracy-0.948164
Iter7,Testing accuracy-0.9472,Train accuracy-0.950182
Iter8,Testing accuracy-0.9515,Train accuracy-0.9544
Iter9,Testing accuracy-0.9548,Train accuracy-0.956455
Iter10,Testing accuracy-0.9551,Train accuracy-0.959091
Iter11,Testing accuracy-0.9566,Train accuracy-0.959891
Iter12,Testing accuracy-0.9594,Train accuracy-0.962036
Iter13,Testing accuracy-0.9592,Train accuracy-0.964236
Iter14,Testing accuracy-0.9585,Train accuracy-0.964818
Iter15,Testing accuracy-0.9607,Train accuracy-0.966
Iter16,Testing accuracy-0.961,Train accuracy-0.9668
Iter17,Testing accuracy-0.9612,Train accuracy-0.967891
Iter18,Testing accuracy-0.9643,Train accuracy-0.969236
Iter19,Testing accuracy-0.9646,Train accuracy-0.969945
Iter20,Testing accuracy-0.9655,Train accuracy-0.970909
Iter21,Testing accuracy-0.9656,Train accuracy-0.971509
Iter22,Testing accuracy-0.9668,Train accuracy-0.972891
Iter23,Testing accuracy-0.9665,Train accuracy-0.972982
Iter24,Testing accuracy-0.9687,Train accuracy-0.974091
Iter25,Testing accuracy-0.9673,Train accuracy-0.974782
Iter26,Testing accuracy-0.9682,Train accuracy-0.975127
Iter27,Testing accuracy-0.9682,Train accuracy-0.976055
Iter28,Testing accuracy-0.9703,Train accuracy-0.976582
Iter29,Testing accuracy-0.9692,Train accuracy-0.976982
Iter30,Testing accuracy-0.9707,Train accuracy-0.977891
Iter31,Testing accuracy-0.9703,Train accuracy-0.978091
?
可以看到,当 drop-out 了 30% 的时候,训练集与测试集的准确率相差比全连接(keep_prob = 1.0)要小一些,可以防止过拟合的情况出现。
?
?
转载于:https://www.cnblogs.com/Mjerry/p/9828102.html
总结
以上是生活随笔为你收集整理的004-2-拟合,drop-out的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: git 命令使用技巧
- 下一篇: mysql 修改字符集