U-Net Structure and Annotated Code
I read through the U-Net code a while ago, though without actually running it, and from the related blog posts I also picked up the questions beginners typically have about U-Net:
1. U-Net's usual structure, and the structure in the paper
2. U-Net's data augmentation
3. How U-Net is implemented in code
4. U-Net's loss function
If you have the habit of reading papers, the first things to look at are where the method applies and what advantages it has over previous work.
This blog post explains what U-Net does and what makes it distinctive: https://blog.csdn.net/u012931582/article/details/70215756
U-Net belongs to the FCN family: input and output are both images, and there are no fully connected layers. The shallower, high-resolution layers solve the pixel-localization problem, while the deeper layers solve the pixel-classification problem. It is end-to-end learning; image style transfer and image super-resolution are built on the same kind of framework.
Having roughly outlined the point of the U-Net structure, let's work through the questions raised at the start:
1. U-Net's usual structure:
If you have read the paper or browsed posts about it online, you will have seen its U-shaped architecture figure.
First layer: you can see the input is a 572×572×1 image, while the original input image is actually 512×512×1. During the 3×3 convolutions the feature maps keep shrinking, because the paper uses 'VALID' convolutions rather than 'SAME'. If we want the pixels at the image border to be segmented accurately as well, U-Net solves this with mirroring (the overlap-tile strategy): a mirrored border is added around the input image. How wide should this border be? A good strategy is to determine it from the receptive field. Valid convolutions keep lowering the feature-map resolution, but we want the border pixels of the 512×512 image to be preserved all the way to the last feature map, so we raise the input resolution by padding; the added size is the receptive field, i.e. each side gains half the receptive field as a mirrored border. (The figure illustrating this is quoted from Zhihu.)
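The mirroring itself is easy to sketch. Here is a minimal NumPy example (the mirror_pad helper is my own illustration, not part of the U-Net code); a 30-pixel mirrored border per side is exactly what turns a 512×512 tile into a 572×572 input:

import numpy as np

def mirror_pad(img, border):
    # 'reflect' mirrors the content across each edge (the overlap-tile border)
    return np.pad(img, pad_width=border, mode='reflect')

tile = np.random.rand(512, 512)   # stand-in for a 512x512 EM image
padded = mirror_pad(tile, 30)     # 512 + 2*30 = 572 per side
print(padded.shape)               # (572, 572)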
From the contracting-path architecture shown in the figure we can compute the receptive field layer by layer, and from that the width of the mirrored border. This is also why U-Net's input is 572×572. Another benefit of 572 is that the feature-map size before every downsampling step is even, a property that is closely tied to the network structure. Related post: https://zhuanlan.zhihu.com/p/43927696
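To check the "even before every pooling" property, here is a small sketch that walks one side length through the paper's valid-convolution contracting path; the printed sizes match Figure 1 of the paper:

side = 572
for stage in range(4):             # four downsampling stages
    side -= 4                      # two VALID 3x3 convs, each trims 2 px
    assert side % 2 == 0           # must be even before 2x2 max-pooling
    print('before pool:', side)    # 568, 280, 136, 64
    side //= 2                     # 2x2 max-pool halves the side
side -= 4                          # the two convs at the bottom
print('bottom:', side)             # 28

Starting from 570 instead already trips the assertion at the second stage, so 572 is not an arbitrary choice.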
Take the left half as an example: this is an ordinary CNN structure, except that the basic unit here is a pair of convolutions: the first conv+ReLU doubles the channel count, and the second conv+ReLU leaves it unchanged.
The bottom layer: the same as before, two convolutions, the first doubling the channels and the second keeping them constant.
The right half:
The blue boxes are produced by 2×2 up-convolutions, whose characteristic is that the side length doubles while the channel count halves; this is the deconvolution operation (really a transposed convolution; see the papers on visualizing neural-network features). But there is also the white-box part: the white boxes and the gray arrows mean that earlier features are stitched together with the current ones (the operation is concat; DenseNet does something similar). The reason is that while repeated downsampling makes the information more and more abstract, information is also steadily lost; bringing information from earlier layers back in lets the network judge the segmentation better. For example, concatenating the original-resolution features with the last layer: upsampling alone might only recover a rough region (the maps were, after all, downsampled to a very small size), whereas combining the original-resolution features localizes the segmentation well.
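In TensorFlow 1.x terms (the style of the code later in this post), one such expanding-path step looks roughly like the sketch below; the two placeholders stand in for the real feature maps, with shapes taken from the layer-5-to-6 transition of the 512-input variant:

import tensorflow as tf  # TF 1.x

batch_size = 1
bottom = tf.placeholder(tf.float32, [batch_size, 32, 32, 1024])  # bottom of the U
skip = tf.placeholder(tf.float32, [batch_size, 64, 64, 512])     # saved contracting-path features

# conv2d_transpose filters are [height, width, out_channels, in_channels]
up_w = tf.Variable(tf.truncated_normal([2, 2, 512, 1024], stddev=0.01))
up = tf.nn.conv2d_transpose(
    value=bottom, filter=up_w,
    output_shape=[batch_size, 64, 64, 512],  # side doubled, channels halved
    strides=[1, 2, 2, 1], padding='VALID')
merged = tf.concat([skip, up], axis=-1)      # channel-wise concat -> 1024 channels
print(merged.shape)                          # (1, 64, 64, 1024)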
The difference between the input and output sizes here is purely a consequence of the 'VALID' convolutions. If you switch to 'SAME', the input and output sizes match, and the code at the end of this post takes exactly that approach. I haven't run experiments on the same-size variant, so how well it performs is uncertain, but most implementations nowadays do configure the network with matching input and output sizes.
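The two padding modes differ only in their output-size arithmetic; a one-line check for a stride-1 convolution:

def conv_out(n, k=3, padding='VALID'):
    # 'VALID' trims k-1 pixels; 'SAME' pads so the size is preserved at stride 1
    return n - k + 1 if padding == 'VALID' else n

print(conv_out(572))                  # 570
print(conv_out(572, padding='SAME'))  # 572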
2. Data augmentation: because of the special nature of biological images, a deformed piece of tissue is still consistent with the characteristics of that tissue.
The image-warping paper (moving least squares): http://faculty.cs.tamu.edu/schaefer/research/mls.pdf
Noise can of course be added as well.
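Here is a minimal elastic-deformation sketch in the Simard style (a dense random displacement field smoothed with a Gaussian); the paper itself uses random displacements on a coarse grid, so treat this as an approximation, and the alpha/sigma values are illustrative. The same displacement field must be applied to the image and its label, the label with order=0 (nearest neighbor):

import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates

def elastic_deform(image, alpha=34.0, sigma=4.0, order=1, seed=None):
    rng = np.random.RandomState(seed)
    # random displacements in [-1, 1], smoothed and scaled by alpha
    dx = gaussian_filter(rng.rand(*image.shape) * 2 - 1, sigma) * alpha
    dy = gaussian_filter(rng.rand(*image.shape) * 2 - 1, sigma) * alpha
    ys, xs = np.meshgrid(np.arange(image.shape[0]), np.arange(image.shape[1]), indexing='ij')
    return map_coordinates(image, [ys + dy, xs + dx], order=order, mode='reflect')

image = np.random.rand(512, 512)
warped = elastic_deform(image, seed=0)  # reuse the same seed for the label mask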
There is one more detail worth noting (see https://zhuanlan.zhihu.com/p/43927696): U-Net's loss function. Sometimes the image to be segmented looks like this: the cells are packed tightly against one another, so the borders between them are very hard to detect, and the loss function has to be designed accordingly.
So how do we design a loss that gives the model the ability to separate touching borders? U-Net uses a cross-entropy loss with border weights (the formulas below are reconstructed from the paper):

$$E = \sum_{x \in \Omega} w(x)\,\log\bigl(p_{\ell(x)}(x)\bigr)$$

where $p_{\ell(x)}(x)$ is the softmax probability the network assigns to the true class, $\ell(x)$ is the label of pixel $x$, and $w(x)$ is the pixel's weight, whose purpose is to give pixels close to boundaries a higher weight:

$$w(x) = w_c(x) + w_0 \cdot \exp\!\left(-\frac{\bigl(d_1(x) + d_2(x)\bigr)^2}{2\sigma^2}\right)$$

where $w_c(x)$ is a weight that balances the class proportions, $d_1(x)$ is the distance from pixel $x$ to the border of the nearest cell, and $d_2(x)$ is the distance to the border of the second-nearest cell. $w_0$ and $\sigma$ are constants; in the paper's experiments $w_0 = 10$ and $\sigma \approx 5$ pixels.
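A sketch of computing this weight map, assuming the ground truth is available as one boolean mask per cell instance (the unet_weight_map helper and its instance_masks input are my own framing, not part of the code below); distance_transform_edt of a cell's complement gives every pixel's distance to that cell:

import numpy as np
from scipy.ndimage import distance_transform_edt

def unet_weight_map(label, instance_masks, w0=10.0, sigma=5.0):
    # class-balancing term w_c: uniform here; swap in inverse class frequencies
    w_c = np.ones_like(label, dtype=np.float64)
    if len(instance_masks) < 2:
        return w_c
    # distance of every pixel to each cell = EDT of the cell's complement
    dists = np.stack([distance_transform_edt(~m) for m in instance_masks])
    dists.sort(axis=0)
    d1, d2 = dists[0], dists[1]  # nearest and second-nearest cell
    border = w0 * np.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))
    return w_c + border * (label == 0)  # emphasize the background gaps between cells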
Of course, for some images this is not necessary.
The code, with annotations:
# The code assumes TensorFlow 1.x. Module-level constants (INPUT_IMG_WIDE,
# INPUT_IMG_CHANNEL, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT, CLASS_NUM, EPOCH_NUM,
# TRAIN_BATCH_SIZE, VALIDATION_BATCH_SIZE, TEST_BATCH_SIZE, TEST_SET_SIZE,
# TRAIN_SET_NAME, VALIDATION_SET_NAME, TEST_SET_NAME, CHECK_POINT_PATH,
# ORIGIN_PREDICT_DIRECTORY, PREDICT_SAVED_DIRECTORY), FLAGS and the helper
# read_image_batch() are defined elsewhere in the original script.
import os
import tensorflow as tf


class Unet:
    def __init__(self):
        print('New U-net Network')
        self.input_image = None
        self.input_label = None
        self.cast_image = None
        self.cast_label = None
        self.keep_prob = None
        self.lamb = None
        self.result_expand = None
        self.loss, self.loss_mean, self.loss_all, self.train_step = [None] * 4
        self.prediction, self.correct_prediction, self.accuracy = [None] * 3
        self.result_conv = {}
        self.result_relu = {}
        self.result_maxpool = {}
        self.result_from_contract_layer = {}  # feature maps saved for the skip connections
        self.w = {}
        self.b = {}

    def init_w(self, shape, name):
        # He initialization: stddev = sqrt(2 / fan_in); an L2 penalty on every
        # weight is accumulated in the 'loss' collection
        with tf.name_scope('init_w'):
            stddev = tf.sqrt(x=2 / (shape[0] * shape[1] * shape[2]))
            # stddev = 0.01
            w = tf.Variable(initial_value=tf.truncated_normal(shape=shape, stddev=stddev, dtype=tf.float32), name=name)
            tf.add_to_collection(name='loss', value=tf.contrib.layers.l2_regularizer(self.lamb)(w))
            return w

    @staticmethod
    def init_b(shape, name):
        with tf.name_scope('init_b'):
            return tf.Variable(initial_value=tf.random_normal(shape=shape, dtype=tf.float32), name=name)

    @staticmethod
    def copy_and_crop_and_merge(result_from_contract_layer, result_from_upsampling):
        # With 'SAME' padding both feature maps already have the same spatial size,
        # so the paper's center crop of the contracting-path feature map (a tf.slice,
        # commented out in the original) is unnecessary; just concatenate on channels.
        result_from_contract_layer_crop = result_from_contract_layer
        return tf.concat(values=[result_from_contract_layer_crop, result_from_upsampling], axis=-1)

    def set_up_unet(self, batch_size):
        # input
        with tf.name_scope('input'):
            self.input_image = tf.placeholder(
                dtype=tf.float32, shape=[batch_size, INPUT_IMG_WIDE, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL],
                name='input_images')
            # labels are plain class indices (not one-hot), to be used with
            # sparse_softmax_cross_entropy_with_logits
            self.input_label = tf.placeholder(
                dtype=tf.int32, shape=[batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_WIDE], name='input_labels')
            self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
            self.lamb = tf.placeholder(dtype=tf.float32, name='lambda')

        # layer 1
        with tf.name_scope('layer_1'):
            # conv_1: 1 -> 64 channels
            self.w[1] = self.init_w(shape=[3, 3, INPUT_IMG_CHANNEL, 64], name='w_1')
            self.b[1] = self.init_b(shape=[64], name='b_1')
            result_conv_1 = tf.nn.conv2d(
                input=self.input_image, filter=self.w[1],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[1], name='add_bias'), name='relu_1')
            # conv_2: 64 -> 64 channels
            self.w[2] = self.init_w(shape=[3, 3, 64, 64], name='w_2')
            self.b[2] = self.init_b(shape=[64], name='b_2')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[2],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[2], name='add_bias'), name='relu_2')
            self.result_from_contract_layer[1] = result_relu_2  # saved for the expanding path
            # maxpool
            result_maxpool = tf.nn.max_pool(
                value=result_relu_2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='VALID', name='maxpool')
            # dropout
            result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

        # layer 2
        with tf.name_scope('layer_2'):
            # conv_1: 64 -> 128 channels
            self.w[3] = self.init_w(shape=[3, 3, 64, 128], name='w_3')
            self.b[3] = self.init_b(shape=[128], name='b_3')
            result_conv_1 = tf.nn.conv2d(
                input=result_dropout, filter=self.w[3],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[3], name='add_bias'), name='relu_1')
            # conv_2: 128 -> 128 channels
            self.w[4] = self.init_w(shape=[3, 3, 128, 128], name='w_4')
            self.b[4] = self.init_b(shape=[128], name='b_4')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[4],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[4], name='add_bias'), name='relu_2')
            self.result_from_contract_layer[2] = result_relu_2  # saved for the expanding path
            # maxpool
            result_maxpool = tf.nn.max_pool(
                value=result_relu_2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='VALID', name='maxpool')
            # dropout
            result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

        # layer 3
        with tf.name_scope('layer_3'):
            # conv_1: 128 -> 256 channels
            self.w[5] = self.init_w(shape=[3, 3, 128, 256], name='w_5')
            self.b[5] = self.init_b(shape=[256], name='b_5')
            result_conv_1 = tf.nn.conv2d(
                input=result_dropout, filter=self.w[5],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[5], name='add_bias'), name='relu_1')
            # conv_2: 256 -> 256 channels
            self.w[6] = self.init_w(shape=[3, 3, 256, 256], name='w_6')
            self.b[6] = self.init_b(shape=[256], name='b_6')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[6],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[6], name='add_bias'), name='relu_2')
            self.result_from_contract_layer[3] = result_relu_2  # saved for the expanding path
            # maxpool
            result_maxpool = tf.nn.max_pool(
                value=result_relu_2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='VALID', name='maxpool')
            # dropout
            result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

        # layer 4
        with tf.name_scope('layer_4'):
            # conv_1: 256 -> 512 channels
            self.w[7] = self.init_w(shape=[3, 3, 256, 512], name='w_7')
            self.b[7] = self.init_b(shape=[512], name='b_7')
            result_conv_1 = tf.nn.conv2d(
                input=result_dropout, filter=self.w[7],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[7], name='add_bias'), name='relu_1')
            # conv_2: 512 -> 512 channels
            self.w[8] = self.init_w(shape=[3, 3, 512, 512], name='w_8')
            self.b[8] = self.init_b(shape=[512], name='b_8')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[8],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[8], name='add_bias'), name='relu_2')
            self.result_from_contract_layer[4] = result_relu_2  # saved for the expanding path
            # maxpool
            result_maxpool = tf.nn.max_pool(
                value=result_relu_2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='VALID', name='maxpool')
            # dropout
            result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

        # layer 5 (bottom of the U)
        with tf.name_scope('layer_5'):
            # conv_1: 512 -> 1024 channels
            self.w[9] = self.init_w(shape=[3, 3, 512, 1024], name='w_9')
            self.b[9] = self.init_b(shape=[1024], name='b_9')
            result_conv_1 = tf.nn.conv2d(
                input=result_dropout, filter=self.w[9],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[9], name='add_bias'), name='relu_1')
            # conv_2: 1024 -> 1024 channels
            self.w[10] = self.init_w(shape=[3, 3, 1024, 1024], name='w_10')
            self.b[10] = self.init_b(shape=[1024], name='b_10')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[10],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[10], name='add_bias'), name='relu_2')
            # up sample: side x2, channels 1024 -> 512
            self.w[11] = self.init_w(shape=[2, 2, 512, 1024], name='w_11')
            self.b[11] = self.init_b(shape=[512], name='b_11')
            result_up = tf.nn.conv2d_transpose(
                value=result_relu_2, filter=self.w[11],
                output_shape=[batch_size, 64, 64, 512],
                strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
            result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[11], name='add_bias'), name='relu_3')
            # dropout
            result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

        # layer 6
        with tf.name_scope('layer_6'):
            # copy, crop and merge
            result_merge = self.copy_and_crop_and_merge(
                result_from_contract_layer=self.result_from_contract_layer[4], result_from_upsampling=result_dropout)
            # conv_1: 1024 -> 512 channels
            self.w[12] = self.init_w(shape=[3, 3, 1024, 512], name='w_12')
            self.b[12] = self.init_b(shape=[512], name='b_12')
            result_conv_1 = tf.nn.conv2d(
                input=result_merge, filter=self.w[12],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[12], name='add_bias'), name='relu_1')
            # conv_2: 512 -> 512 channels
            self.w[13] = self.init_w(shape=[3, 3, 512, 512], name='w_13')
            self.b[13] = self.init_b(shape=[512], name='b_13')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[13],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[13], name='add_bias'), name='relu_2')
            # up sample: side x2, channels 512 -> 256
            self.w[14] = self.init_w(shape=[2, 2, 256, 512], name='w_14')
            self.b[14] = self.init_b(shape=[256], name='b_14')
            result_up = tf.nn.conv2d_transpose(
                value=result_relu_2, filter=self.w[14],
                output_shape=[batch_size, 128, 128, 256],
                strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
            result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[14], name='add_bias'), name='relu_3')
            # dropout
            result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

        # layer 7
        with tf.name_scope('layer_7'):
            # copy, crop and merge
            result_merge = self.copy_and_crop_and_merge(
                result_from_contract_layer=self.result_from_contract_layer[3], result_from_upsampling=result_dropout)
            # conv_1: 512 -> 256 channels
            self.w[15] = self.init_w(shape=[3, 3, 512, 256], name='w_15')
            self.b[15] = self.init_b(shape=[256], name='b_15')
            result_conv_1 = tf.nn.conv2d(
                input=result_merge, filter=self.w[15],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[15], name='add_bias'), name='relu_1')
            # conv_2: 256 -> 256 channels
            self.w[16] = self.init_w(shape=[3, 3, 256, 256], name='w_16')
            self.b[16] = self.init_b(shape=[256], name='b_16')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[16],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[16], name='add_bias'), name='relu_2')
            # up sample: side x2, channels 256 -> 128
            self.w[17] = self.init_w(shape=[2, 2, 128, 256], name='w_17')
            self.b[17] = self.init_b(shape=[128], name='b_17')
            result_up = tf.nn.conv2d_transpose(
                value=result_relu_2, filter=self.w[17],
                output_shape=[batch_size, 256, 256, 128],
                strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
            result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[17], name='add_bias'), name='relu_3')
            # dropout
            result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

        # layer 8
        with tf.name_scope('layer_8'):
            # copy, crop and merge
            result_merge = self.copy_and_crop_and_merge(
                result_from_contract_layer=self.result_from_contract_layer[2], result_from_upsampling=result_dropout)
            # conv_1: 256 -> 128 channels
            self.w[18] = self.init_w(shape=[3, 3, 256, 128], name='w_18')
            self.b[18] = self.init_b(shape=[128], name='b_18')
            result_conv_1 = tf.nn.conv2d(
                input=result_merge, filter=self.w[18],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[18], name='add_bias'), name='relu_1')
            # conv_2: 128 -> 128 channels
            self.w[19] = self.init_w(shape=[3, 3, 128, 128], name='w_19')
            self.b[19] = self.init_b(shape=[128], name='b_19')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[19],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[19], name='add_bias'), name='relu_2')
            # up sample: side x2, channels 128 -> 64
            self.w[20] = self.init_w(shape=[2, 2, 64, 128], name='w_20')
            self.b[20] = self.init_b(shape=[64], name='b_20')
            result_up = tf.nn.conv2d_transpose(
                value=result_relu_2, filter=self.w[20],
                output_shape=[batch_size, 512, 512, 64],
                strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
            result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[20], name='add_bias'), name='relu_3')
            # dropout
            result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

        # layer 9
        with tf.name_scope('layer_9'):
            # copy, crop and merge
            result_merge = self.copy_and_crop_and_merge(
                result_from_contract_layer=self.result_from_contract_layer[1], result_from_upsampling=result_dropout)
            # conv_1: 128 -> 64 channels
            self.w[21] = self.init_w(shape=[3, 3, 128, 64], name='w_21')
            self.b[21] = self.init_b(shape=[64], name='b_21')
            result_conv_1 = tf.nn.conv2d(
                input=result_merge, filter=self.w[21],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[21], name='add_bias'), name='relu_1')
            # conv_2: 64 -> 64 channels
            self.w[22] = self.init_w(shape=[3, 3, 64, 64], name='w_22')
            self.b[22] = self.init_b(shape=[64], name='b_22')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[22],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[22], name='add_bias'), name='relu_2')
            # 1x1 convolution to [batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT, CLASS_NUM]
            self.w[23] = self.init_w(shape=[1, 1, 64, CLASS_NUM], name='w_23')
            self.b[23] = self.init_b(shape=[CLASS_NUM], name='b_23')
            result_conv_3 = tf.nn.conv2d(
                input=result_relu_2, filter=self.w[23],
                strides=[1, 1, 1, 1], padding='VALID', name='conv_3')
            # raw logits; the softmax is applied inside the loss below
            # (relu / sigmoid variants here were commented out in the original)
            self.prediction = tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias')

        # softmax loss
        with tf.name_scope('softmax_loss'):
            # one-hot variant: tf.nn.softmax_cross_entropy_with_logits(labels=..., logits=...)
            self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.input_label, logits=self.prediction, name='loss')
            self.loss_mean = tf.reduce_mean(self.loss)
            tf.add_to_collection(name='loss', value=self.loss_mean)
            # total loss = mean cross entropy + the L2 terms collected in init_w
            self.loss_all = tf.add_n(inputs=tf.get_collection(key='loss'))

        # accuracy
        with tf.name_scope('accuracy'):
            # one-hot variant: compare argmax of prediction with argmax of cast_label
            self.correct_prediction = tf.equal(
                tf.argmax(input=self.prediction, axis=3, output_type=tf.int32), self.input_label)
            self.correct_prediction = tf.cast(self.correct_prediction, tf.float32)
            self.accuracy = tf.reduce_mean(self.correct_prediction)

        # Gradient Descent
        with tf.name_scope('Gradient_Descent'):
            self.train_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(self.loss_all)

    def train(self):
        # (an earlier, commented-out version overfit a single train.tif / label.tif
        # pair loaded with OpenCV; the queue-based pipeline below replaced it)
        train_file_path = os.path.join(FLAGS.data_dir, TRAIN_SET_NAME)
        train_image_filename_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(train_file_path), num_epochs=EPOCH_NUM, shuffle=True)
        ckpt_path = CHECK_POINT_PATH
        train_images, train_labels = read_image_batch(train_image_filename_queue, TRAIN_BATCH_SIZE)
        tf.summary.scalar("loss", self.loss_mean)
        tf.summary.scalar('accuracy', self.accuracy)
        merged_summary = tf.summary.merge_all()
        all_parameters_saver = tf.train.Saver()
        with tf.Session() as sess:  # start a session
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
            tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                epoch = 1
                while not coord.should_stop():
                    # fetch an image / label batch from the queue
                    example, label = sess.run([train_images, train_labels])
                    lo, acc, summary_str = sess.run(
                        [self.loss_mean, self.accuracy, merged_summary],
                        feed_dict={
                            self.input_image: example, self.input_label: label,
                            self.keep_prob: 1.0, self.lamb: 0.004})
                    summary_writer.add_summary(summary_str, epoch)
                    if epoch % 10 == 0:
                        print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
                    sess.run(
                        [self.train_step],
                        feed_dict={
                            self.input_image: example, self.input_label: label,
                            self.keep_prob: 0.6, self.lamb: 0.004})
                    epoch += 1
            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')
            finally:
                # when done, save the model and ask the threads to stop
                all_parameters_saver.save(sess=sess, save_path=ckpt_path)
                coord.request_stop()
            coord.join(threads)
        print("Done training")

    def validate(self):
        # (the same commented-out single-image OpenCV evaluation stood here)
        validation_file_path = os.path.join(FLAGS.data_dir, VALIDATION_SET_NAME)
        validation_image_filename_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(validation_file_path), num_epochs=1, shuffle=True)
        ckpt_path = CHECK_POINT_PATH
        validation_images, validation_labels = read_image_batch(validation_image_filename_queue, VALIDATION_BATCH_SIZE)
        all_parameters_saver = tf.train.Saver()
        with tf.Session() as sess:  # start a session
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                epoch = 1
                while not coord.should_stop():
                    # fetch an image / label batch from the queue
                    example, label = sess.run([validation_images, validation_labels])
                    lo, acc = sess.run(
                        [self.loss_mean, self.accuracy],
                        feed_dict={
                            self.input_image: example, self.input_label: label,
                            self.keep_prob: 1.0, self.lamb: 0.004})
                    if epoch % 1 == 0:
                        print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
                    epoch += 1
            except tf.errors.OutOfRangeError:
                print('Done validating -- epoch limit reached')
            finally:
                # when done, ask the threads to stop
                coord.request_stop()
            coord.join(threads)
        print('Done validating')

    def test(self):
        import cv2
        test_file_path = os.path.join(FLAGS.data_dir, TEST_SET_NAME)
        test_image_filename_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(test_file_path), num_epochs=1, shuffle=True)
        ckpt_path = CHECK_POINT_PATH
        test_images, test_labels = read_image_batch(test_image_filename_queue, TEST_BATCH_SIZE)
        all_parameters_saver = tf.train.Saver()
        with tf.Session() as sess:  # start a session
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            sum_acc = 0.0
            try:
                epoch = 0
                while not coord.should_stop():
                    # fetch an image / label batch from the queue
                    example, label = sess.run([test_images, test_labels])
                    image, acc = sess.run(
                        [tf.argmax(input=self.prediction, axis=3), self.accuracy],
                        feed_dict={
                            self.input_image: example, self.input_label: label,
                            self.keep_prob: 1.0, self.lamb: 0.004})
                    sum_acc += acc
                    epoch += 1
                    cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % epoch), image[0] * 255)
                    if epoch % 1 == 0:
                        print('num %d accuracy: %.6f' % (epoch, acc))
            except tf.errors.OutOfRangeError:
                print('Done testing -- epoch limit reached \n Average accuracy: %.2f%%' % (sum_acc / TEST_SET_SIZE * 100))
            finally:
                # when done, ask the threads to stop
                coord.request_stop()
            coord.join(threads)
        print('Done testing')

    def predict(self):
        import cv2
        import glob
        import numpy as np
        # TODO: read the images directly for prediction instead of going through the
        # tfrecord file; the tfrecord shuffles the order, so outputs cannot be
        # matched back to the input files
        predict_file_path = glob.glob(os.path.join(ORIGIN_PREDICT_DIRECTORY, '*.tif'))
        print(len(predict_file_path))
        if not os.path.lexists(PREDICT_SAVED_DIRECTORY):
            os.mkdir(PREDICT_SAVED_DIRECTORY)
        ckpt_path = CHECK_POINT_PATH
        all_parameters_saver = tf.train.Saver()
        with tf.Session() as sess:  # start a session
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
            for index, image_path in enumerate(predict_file_path):
                image = np.reshape(
                    a=cv2.imread(image_path, flags=0),
                    newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
                predict_image = sess.run(
                    tf.argmax(input=self.prediction, axis=3),
                    feed_dict={self.input_image: image, self.keep_prob: 1.0, self.lamb: 0.004})
                cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % index), predict_image[0] * 255)
        print('Done prediction')
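The excerpt does not show the original script's entry point; a hypothetical driver, assuming the same module-level constants, would look something like:

def main():
    net = Unet()
    net.set_up_unet(TRAIN_BATCH_SIZE)  # build the graph for the training batch size
    net.train()

if __name__ == '__main__':
    main()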