tensorflow综合示例5:图象分割
本文主要內容來自: https://www.tensorflow.org/tutorials/images/segmentation?hl=zh-cn
圖像分割
這篇教程將重點討論圖像分割任務,使用的是改進版的 U-Net。
什么是圖像分割?
目前你已經了解在圖像分類中,神經網絡的任務是給每張輸入圖像分配一個標簽或者類別。但是,有時你想知道一個物體在一張圖像中的位置、這個物體的形狀、以及哪個像素屬于哪個物體等等。**這種情況下你會希望分割圖像,也就是給圖像中的每個像素各分配一個標簽。因此,圖像分割的任務是訓練一個神經網絡來輸出該圖像對每一個像素的掩碼。**這對從更底層,即像素層級,來理解圖像很有幫助。圖像分割在例如醫療圖像、自動駕駛車輛以及衛星圖像等領域有很多應用。
本教程將使用的數據集是 Oxford-IIIT Pet 數據集,由 Parkhi et al. 創建。該數據集由圖像、圖像所對應的標簽、以及對像素逐一標記的掩碼組成。掩碼其實就是給每個像素的標簽。每個像素分別屬于以下三個類別中的一個:
- 類別 1:像素是寵物的一部分。
- 類別 2:像素是寵物的輪廓。
- 類別 3:以上都不是/外圍像素。
下載 Oxford-IIIT Pets 數據集
這個數據集已經集成在 Tensorflow datasets 中,只需下載即可。圖像分割掩碼在版本 3.0.0 中才被加入,因此我們特別選用這個版本。
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)下面的代碼進行了一個簡單的圖像翻轉擴充。然后,將圖像標準化到 [0,1]。最后,如上文提到的,像素點在圖像分割掩碼中被標記為 {1, 2, 3} 中的一個。為了方便起見,我們將分割掩碼都減 1,得到了以下的標簽:{0, 1, 2}。
def normalize(input_image, input_mask):input_image = tf.cast(input_image, tf.float32) / 255.0input_mask -= 1return input_image, input_mask @tf.function def load_image_train(datapoint):input_image = tf.image.resize(datapoint['image'], (128, 128))input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))if tf.random.uniform(()) > 0.5:input_image = tf.image.flip_left_right(input_image)input_mask = tf.image.flip_left_right(input_mask)input_image, input_mask = normalize(input_image, input_mask)return input_image, input_mask def load_image_test(datapoint):input_image = tf.image.resize(datapoint['image'], (128, 128))input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))input_image, input_mask = normalize(input_image, input_mask)return input_image, input_mask數據集已經包含了所需的測試集和訓練集劃分,所以我們也延續使用相同的劃分。
TRAIN_LENGTH = info.splits['train'].num_examples BATCH_SIZE = 64 BUFFER_SIZE = 1000 STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) test = dataset['test'].map(load_image_test) train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) test_dataset = test.batch(BATCH_SIZE)我們來看一下數據集中的一例圖像以及它所對應的掩碼。
def display(display_list):plt.figure(figsize=(15, 15))title = ['Input Image', 'True Mask', 'Predicted Mask']for i in range(len(display_list)):plt.subplot(1, len(display_list), i+1)plt.title(title[i])plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))plt.axis('off')plt.show() for image, mask in train.take(1):sample_image, sample_mask = image, mask display([sample_image, sample_mask])定義模型
這里用到的模型是一個改版的 U-Net。U-Net 由一個編碼器(下采樣器(downsampler))和一個解碼器(上采樣器(upsampler))組成。為了學習到魯棒的特征,同時減少可訓練參數的數量,這里可以使用一個預訓練模型作為編碼器。因此,這項任務中的編碼器將使用一個預訓練的 MobileNetV2 模型,它的中間輸出值將被使用。解碼器將使用在 TensorFlow Examples 中的 Pix2pix tutorial 里實施過的升頻取樣模塊。
輸出信道數量為 3 是因為每個像素有三種可能的標簽。把這想象成一個多類別分類,每個像素都將被分到三個類別當中。
OUTPUT_CHANNELS = 3如之前提到的,編碼器是一個預訓練的 MobileNetV2 模型,它在 tf.keras.applications 中已被準備好并可以直接使用。編碼器中包含模型中間層的一些特定輸出。注意編碼器在模型的訓練過程中是不會被訓練的。
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)# 使用這些層的激活設置 layer_names = ['block_1_expand_relu', # 64x64'block_3_expand_relu', # 32x32'block_6_expand_relu', # 16x16'block_13_expand_relu', # 8x8'block_16_project', # 4x4 ] layers = [base_model.get_layer(name).output for name in layer_names]# 創建特征提取模型 down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)down_stack.trainable = False Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5 9412608/9406464 [==============================] - 0s 0us/step解碼器/升頻取樣器是簡單的一系列升頻取樣模塊,在 TensorFlow examples 中曾被實施過。
up_stack = [pix2pix.upsample(512, 3), # 4x4 -> 8x8pix2pix.upsample(256, 3), # 8x8 -> 16x16pix2pix.upsample(128, 3), # 16x16 -> 32x32pix2pix.upsample(64, 3), # 32x32 -> 64x64 ] def unet_model(output_channels):inputs = tf.keras.layers.Input(shape=[128, 128, 3])x = inputs# 在模型中降頻取樣skips = down_stack(x)x = skips[-1]skips = reversed(skips[:-1])# 升頻取樣然后建立跳躍連接for up, skip in zip(up_stack, skips):x = up(x)concat = tf.keras.layers.Concatenate()x = concat([x, skip])# 這是模型的最后一層last = tf.keras.layers.Conv2DTranspose(output_channels, 3, strides=2,padding='same') #64x64 -> 128x128x = last(x)return tf.keras.Model(inputs=inputs, outputs=x)訓練模型
現在,要做的只剩下編譯和訓練模型了。這里用到的損失函數是 losses.sparse_categorical_crossentropy。使用這個損失函數是因為神經網絡試圖給每一個像素分配一個標簽,和多類別預測是一樣的。在正確的分割掩碼中,每個像素點的值是 {0,1,2} 中的一個。同時神經網絡也輸出三個信道。本質上,每個信道都在嘗試學習預測一個類別,而 losses.sparse_categorical_crossentropy 正是這一情形下推薦使用的損失函數。根據神經網絡的輸出值,分配給每個像素的標簽為輸出值最高的信道所表示的那一類。這就是 create_mask 函數所做的工作。
model = unet_model(OUTPUT_CHANNELS) model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])快速瀏覽一下最終的模型架構:
tf.keras.utils.plot_model(model, show_shapes=True)我們試著運行一下模型,看看它在訓練之前給出的預測值。
def create_mask(pred_mask):pred_mask = tf.argmax(pred_mask, axis=-1)pred_mask = pred_mask[..., tf.newaxis]return pred_mask[0] def show_predictions(dataset=None, num=1):if dataset:for image, mask in dataset.take(num):pred_mask = model.predict(image)display([image[0], mask[0], create_mask(pred_mask)])else:display([sample_image, sample_mask,create_mask(model.predict(sample_image[tf.newaxis, ...]))]) show_predictions()我們來觀察模型是怎樣隨著訓練而改善的。為達成這一目的,下面將定義一個 callback 函數。
class DisplayCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):clear_output(wait=True)show_predictions()print ('\nSample Prediction after epoch {}\n'.format(epoch+1)) EPOCHS = 20 VAL_SUBSPLITS = 5 VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITSmodel_history = model.fit(train_dataset, epochs=EPOCHS,steps_per_epoch=STEPS_PER_EPOCH,validation_steps=VALIDATION_STEPS,validation_data=test_dataset,callbacks=[DisplayCallback()]) Sample Prediction after epoch 2057/57 [==============================] - 3s 54ms/step - loss: 0.1308 - accuracy: 0.9401 - val_loss: 0.3246 - val_accuracy: 0.8903 loss = model_history.history['loss'] val_loss = model_history.history['val_loss']epochs = range(EPOCHS)plt.figure() plt.plot(epochs, loss, 'r', label='Training loss') plt.plot(epochs, val_loss, 'bo', label='Validation loss') plt.title('Training and Validation Loss') plt.xlabel('Epoch') plt.ylabel('Loss Value') plt.ylim([0, 1]) plt.legend() plt.show()做出預測
我們來做幾個預測。為了節省時間,這里只使用很少的周期(epoch)數,但是你可以設置更多的數量以獲得更準確的結果。
show_predictions(test_dataset, 3)接下來
現在你已經對圖像分割是什么以及它的工作原理有所了解。你可以在本教程里嘗試使用不同的中間層輸出值,或者甚至使用不同的預訓練模型。你也可以去 Kaggle 舉辦的 Carvana 圖像分割挑戰賽上挑戰自己。
你也可以看看 Tensorflow Object Detection API 上面其他的你可以使用自己數據進行再訓練的模型。
總結
以上是生活随笔為你收集整理的tensorflow综合示例5:图象分割的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tensorflow综合示例4:逻辑回归
- 下一篇: tensorflow综合示例1:tens