UNet多类别分割的keras实现
生活随笔
收集整理的這篇文章主要介紹了
UNet多类别分割的keras实现
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
本文包含制作數據集、訓練、推理、測試圖像及結果四部分內容
目錄
制作數據集
訓練
推理
測試圖像及結果
制作數據集
該數據集包含420張224*400圖像,圖像由畫圖工具產生,包含圓形、矩形和背景三種類別,選用不同的顏色進行填充。部分訓練圖像和標簽圖像如下圖所示
（圖：部分訓練圖像與對應標簽圖像示例，原文此處為插圖，文字版中缺失）
訓練
根據所填充顏色,將每張標注圖像生成為rows*cols*class_nums的形式
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img, save_img
import numpy as np
import os

# Fill colours used in the label images for background / circle / rectangle.
colorDict = [[0, 0, 0], [34, 177, 76], [237, 28, 36]]
colorDict_RGB = np.array(colorDict)
# Only the R channel (values 0 / 34 / 237) is needed to tell classes apart,
# because the three fill colours already differ in that channel.
colorDict_GRAY = colorDict_RGB[:, 0]
num_classes = 3


def data_preprocess(label, class_num):
    """Convert a single-channel, colour-coded label image into a one-hot
    array of shape rows*cols*class_num with values 0/1.

    `label` is modified in place while remapping colour values to class
    indices; the returned array is a new float array.
    """
    for i in range(colorDict_GRAY.shape[0]):
        label[label == colorDict_GRAY[i]] = i
    new_label = np.zeros(label.shape + (class_num,))
    for i in range(class_num):
        new_label[label == i, i] = 1
    return new_label


def visual(array):
    """Show every class channel of a one-hot label as a grayscale image
    (debug helper; assumes 224x400 images)."""
    for j in range(num_classes):
        vis = array[:, :, j] * 255
        vis = vis.reshape(224, 400, 1)
        vis_out = array_to_img(vis)
        vis_out.show()


class dataProcess(object):
    """Converts folders of training / test / label images into .npy arrays
    and loads them back for training and inference."""

    def __init__(self, out_rows, out_cols, data_path="../train1", label_path="../label1",
                 test_path="../test1", npy_path="../npydata", img_type="bmp"):
        # Output size plus the folder layout used by this project.
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.data_path = data_path
        self.label_path = label_path
        self.img_type = img_type
        self.test_path = test_path
        self.npy_path = npy_path
        self.num_classes = num_classes

    def create_train_data(self):
        """Read grayscale training images and colour-coded label images
        (same file names) and save them as imgs_train.npy /
        imgs_mask_train.npy under npy_path."""
        print('-' * 30)
        print('Creating training images...')
        print('-' * 30)
        img_list = os.listdir(self.data_path)
        imgdatas = np.ndarray((len(img_list), self.out_rows, self.out_cols, 1), dtype=np.uint8)
        imglabels = np.ndarray((len(img_list), self.out_rows, self.out_cols, self.num_classes), dtype=np.uint8)
        for i in range(len(img_list)):
            img = load_img(self.data_path + "/" + img_list[i], color_mode="grayscale")
            imgdatas[i] = img_to_array(img)
            label = load_img(self.label_path + "/" + img_list[i])
            # Keep only the R channel; data_preprocess maps it to class ids.
            label = img_to_array(label)[:, :, 0]
            label = data_preprocess(label, num_classes)
            # visual(label)
            imglabels[i] = label
        np.save(self.npy_path + '/imgs_train.npy', imgdatas)
        np.save(self.npy_path + '/imgs_mask_train.npy', imglabels)
        print('Saving to .npy files done.')

    def load_train_data(self):
        """Load the training arrays; images are scaled to [0, 1]."""
        print('-' * 30)
        print('load train images...')
        print('-' * 30)
        imgs_train = np.load(self.npy_path + "/imgs_train.npy")
        imgs_mask_train = np.load(self.npy_path + "/imgs_mask_train.npy")
        imgs_train = imgs_train.astype('float32')
        imgs_mask_train = imgs_mask_train.astype('float32')
        imgs_train /= 255.0
        # BUG FIX: the masks are one-hot encoded (0/1 per channel), NOT
        # 0-255 images.  The original code also divided the masks by 255,
        # which shrank the targets to ~0.004 and broke
        # categorical_crossentropy training.  Only the images are scaled.
        return imgs_train, imgs_mask_train

    def create_test_data(self):
        """Read grayscale test images, save them as imgs_test.npy and
        return the list of file names (used later to name the outputs)."""
        test_list = []
        print('-' * 30)
        print('Creating test images...')
        print('-' * 30)
        img_list = os.listdir(self.test_path)
        testdatas = np.ndarray((len(img_list), self.out_rows, self.out_cols, 1), dtype=np.uint8)
        for i in range(len(img_list)):
            img = load_img(self.test_path + "/" + img_list[i], color_mode="grayscale")
            testdatas[i] = img_to_array(img)
            test_list.append(img_list[i])
        np.save(self.npy_path + '/imgs_test.npy', testdatas)
        print('Saving to .npy files done.')
        return test_list

    def load_test_data(self):
        """Load the test array saved by create_test_data, scaled to [0, 1]."""
        print('-' * 30)
        print('load test images...')
        print('-' * 30)
        imgs_test = np.load(self.npy_path + "/imgs_test.npy")
        imgs_test = imgs_test.astype('float32')
        imgs_test /= 255.0
        return imgs_test


if __name__ == "__main__":
    mydata = dataProcess(224, 400)
    mydata.create_train_data()
    imgs_train, imgs_mask_train = mydata.load_train_data()
    print(imgs_train.shape, imgs_mask_train.shape)

# (Article note: in the U-Net below, the number of classes in conv10 is set to
# class_nums, the final activation to softmax, and the loss to
# 'categorical_crossentropy'.)
import numpy as np
from keras.models import *
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint
from keras import backend as keras
from my_test.data import *
from keras.models import Model


class myUnet(object):
    """U-Net for multi-class segmentation: input rows*cols*1, output
    rows*cols*class_nums softmax, trained with categorical_crossentropy."""

    def __init__(self, img_rows=224, img_cols=400, class_nums=3):
        # BUG FIX: `class_nums` used to be read as a bare global that was
        # defined only under `if __name__ == '__main__'`, so get_unet()
        # raised NameError whenever this class was imported from another
        # module.  It is now a backward-compatible constructor argument.
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.class_nums = class_nums

    def load_data(self):
        """Load the preprocessed training arrays produced by data.py."""
        mydata = dataProcess(self.img_rows, self.img_cols)
        imgs_train, imgs_mask_train = mydata.load_train_data()
        return imgs_train, imgs_mask_train

    def get_unet(self):
        """Build and compile the 'same'-padding U-Net.

        (The original valid-padding / Cropping2D variant that was kept here
        as a large commented-out block has been removed for readability.)
        """
        inputs = Input((self.img_rows, self.img_cols, 1))

        # --- encoder -------------------------------------------------------
        conv1 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
        conv1 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
        pool1 = MaxPooling2D((2, 2))(conv1)
        conv2 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
        conv2 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
        pool2 = MaxPooling2D((2, 2))(conv2)
        conv3 = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
        conv3 = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        conv4 = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
        conv4 = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
        drop4 = Dropout(0.5)(conv4)
        pool4 = MaxPooling2D((2, 2))(drop4)

        # --- bottleneck ----------------------------------------------------
        conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
        conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
        drop5 = Dropout(0.5)(conv5)

        # --- decoder with skip connections --------------------------------
        up6 = Conv2D(512, (2, 2), activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(drop5))
        merge6 = concatenate([drop4, up6], axis=3)
        conv6 = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
        conv6 = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
        up7 = Conv2D(256, (2, 2), activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv6))
        merge7 = concatenate([conv3, up7], axis=3)
        conv7 = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
        conv7 = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
        up8 = Conv2D(128, (2, 2), activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv7))
        merge8 = concatenate([conv2, up8], axis=3)
        conv8 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
        conv8 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
        up9 = Conv2D(64, (2, 2), activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv8))
        merge9 = concatenate([conv1, up9], axis=3)
        conv9 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
        conv9 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
        conv9 = Conv2D(2, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv9)

        # Multi-class head: class_nums channels + softmax (was 1 channel +
        # sigmoid in the binary version of this network).
        conv10 = Conv2D(self.class_nums, (1, 1), activation='softmax')(conv9)

        model = Model(inputs=inputs, outputs=conv10)
        model.summary()
        model.compile(optimizer=Adam(lr=1e-4), loss='categorical_crossentropy', metrics=['accuracy'])
        return model

    def train(self):
        """Train the U-Net on the .npy data, checkpointing the best model
        (by training loss) to my_unet.hdf5."""
        print("loading data")
        imgs_train, imgs_mask_train = self.load_data()
        print("loading data done")
        model = self.get_unet()
        print("got unet")
        model_checkpoint = ModelCheckpoint('my_unet.hdf5', monitor='loss', verbose=1, save_best_only=True)
        print('Fitting model...')
        model.fit(imgs_train, imgs_mask_train, batch_size=4, epochs=10, verbose=1,
                  validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])


if __name__ == '__main__':
    class_nums = 3
    myunet = myUnet(class_nums=class_nums)
    myunet.train()

# Section: inference (推理) -- see the script below.
將推理結果拆分為3通道圖像,分別顯示各通道圖像
from my_test.data import *
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import array_to_img


# NOTE(review): this shadows the keras `save_img` pulled in by the wildcard
# import above; harmless here, but a rename would be safer.
def save_img(test_list):
    """Split each predicted mask (rows*cols*class_num) into its class
    channels and save every channel as <channel>_<original name>."""
    print("array to image")
    imgs = np.load('../11/imgs_mask_test.npy')
    for i in range(imgs.shape[0]):
        img = imgs[i]
        for j in range(class_num):
            out = img[:, :, j]
            out = out.reshape(224, 400, 1)
            out = array_to_img(out)
            out.save("../11/" + str(j) + '_' + test_list[i])


unet_model_path = 'my_unet.hdf5'
model = load_model(unet_model_path)
class_num = 3
mydata = dataProcess(224, 400)
# BUG FIX: create_test_data() must run first -- it writes imgs_test.npy,
# which load_test_data() then reads.  The original order tried to load the
# file before it existed.
test_list = mydata.create_test_data()
imgs_test = mydata.load_test_data()
imgs_mask_test = model.predict(imgs_test, batch_size=1, verbose=1)
np.save('../11/imgs_mask_test.npy', imgs_mask_test)
save_img(test_list)

# Section: test images and results (測試圖像及結果) -- see below.
網絡輸入為rows*cols*1,輸出為rows*cols*class_nums。在數據處理階段,將通道0中的背景設為mask區域,通道1中的圓形設置為mask區域,通道2中的矩形設置為mask區域,因此對輸出的三個通道進行拆分得到:通道0為背景的分割結果,通道1為圓形的分割結果,通道2為矩形的分割結果
總結
以上是生活随笔為你收集整理的UNet多类别分割的keras实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python贪婪算法
- 下一篇: squid mysql认证_Squid