TensorFlow 笔记6--迁移学习
生活随笔
收集整理的這篇文章主要介紹了
TensorFlow 笔记6--迁移学习
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
TensorFlow 筆記6–遷移學(xué)習(xí)
參考文檔:https://github.com/ageron/handson-ml/blob/master/11_deep_learning.ipynb
一、凍結(jié)部分層權(quán)重
法一:
with tf.name_scope("train"): optimizer = tf.train.GradientDescentOptimizer(learning_rate)# 指定要訓(xùn)練的那部分層train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="hidden[34]|outputs")training_op = optimizer.minimize(loss, var_list=train_vars)# 恢復(fù)凍結(jié)層的數(shù)據(jù),其實(shí)也可以全部恢復(fù) reuse_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="hidden[12]") restore_saver = tf.train.Saver(reuse_vars) with tf.Session() as sess:restore_saver.restore(sess, "./my_model_final.ckpt")法二:
with tf.name_scope("dnn"):hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1") # reused frozenhidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2") # reused frozen# 在此之前的層不會(huì)進(jìn)行梯度更新hidden2_stop = tf.stop_gradient(hidden2)# 注意以下的層要相應(yīng)的修改為hidden2_stophidden3 = tf.layers.dense(hidden2_stop, n_hidden3, activation=tf.nn.relu, name="hidden3") # reused, not frozenhidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4") # new!logits = tf.layers.dense(hidden4, n_outputs, name="outputs") # new!# 剩下的和正常的一樣二、緩存凍結(jié)層結(jié)果
# 先設(shè)置凍結(jié)層,再進(jìn)行以下操作with tf.Session() as sess:init.run()restore_saver.restore(sess, "./my_model_final.ckpt")# 緩存凍結(jié)層的結(jié)果,即訓(xùn)練期間只計(jì)算一次h2_cache = sess.run(hidden2, feed_dict={X: X_train})h2_cache_valid = sess.run(hidden2, feed_dict={X: X_valid}) for epoch in range(n_epochs):shuffled_idx = np.random.permutation(len(X_train))# feed的數(shù)據(jù)應(yīng)該相應(yīng)的改為凍結(jié)層的結(jié)果hidden2_batches = np.array_split(h2_cache[shuffled_idx], n_batches)y_batches = np.array_split(y_train[shuffled_idx], n_batches)for hidden2_batch, y_batch in zip(hidden2_batches, y_batches):sess.run(training_op, feed_dict={hidden2:hidden2_batch, y:y_batch})accuracy_val = accuracy.eval(feed_dict={hidden2: h2_cache_valid, y: y_valid}) print(epoch, "Validation accuracy:", accuracy_val) save_path = saver.save(sess, "./my_new_model_final.ckpt")總結(jié)
以上是生活随笔為你收集整理的TensorFlow 笔记6--迁移学习的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: php正则检查QQ,PHP 正则匹配手机
- 下一篇: Scikit-Learn 机器学习笔记