Keras U-Net Semantic Segmentation for Remote Sensing Images (Multi-band, Multi-class Support)
Preface

There are plenty of U-Net tutorials online, but most of them don't support multi-band imagery (remote sensing images carry bands beyond RGB, such as infrared), and for multi-class segmentation you have to type in the class colors by hand when one-hot encoding the labels. This post addresses both problems.

Questions are welcome in the comments; if you find this useful, a like is appreciated.
If you haven't set up a Keras environment yet, start with this post:

馨意: Setting up a Win10 Keras/TensorFlow-GPU deep learning environment (zhuanlan.zhihu.com)

If you haven't made labels yet, start with this post:

馨意: Making deep learning semantic segmentation labels for remote sensing images with ArcGIS (zhuanlan.zhihu.com)

If you haven't cropped your imagery into deep learning samples yet, start with this post:

馨意: Cropping remote sensing images into deep learning samples with Python (multi-band support) (zhuanlan.zhihu.com)

If you haven't augmented your samples yet, start with this post:

馨意: Deep learning data augmentation for remote sensing semantic segmentation with NumPy (multi-band support) (zhuanlan.zhihu.com)
Main Text
1 Reading Image Data
First we need to read the image's pixel matrix. To support multiple bands, we read with GDAL:
import gdal   #  on newer GDAL installs: from osgeo import gdal

#  Read an image's pixel matrix
#  fileName  image file name
def readTif(fileName):
    dataset = gdal.Open(fileName)
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    GdalImg_data = dataset.ReadAsArray(0, 0, width, height)
    return GdalImg_data
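A quick sanity check of the array layout GDAL returns (the file path below is a hypothetical example):

img = readTif("Data\\train\\image\\sample.tif")   #  hypothetical sample tile
print(img.shape)   #  (bands, height, width) for a multi-band image, e.g. (4, 512, 512)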
2 Image Preprocessing

Once the image is read, it needs preprocessing:
- Normalize the image. We use max normalization, i.e. divide by the maximum value, which is 255 for 8-bit data.
- One-hot encode the label, i.e. turn each class in the flat label into its own layer of 0s and 1s.

Both steps live in a dataPreprocess function, sketched right after this list.
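The dataPreprocess code itself did not survive in this copy of the post; below is a minimal sketch, reconstructed from the two steps above and from how trainGenerator (next section) calls it. The signature comes from that call; the body is my assumption, not necessarily the author's exact code.

import numpy as np

#  Preprocessing sketch (reconstruction): normalize images, one-hot encode labels
#  img             image batch, shape (batch, height, width, bands)
#  label           grayscale label batch, shape (batch, height, width)
#  classNum        total number of classes (incl. background)
#  colorDict_GRAY  gray value of each class, as returned by color_dict below
def dataPreprocess(img, label, classNum, colorDict_GRAY):
    #  Max normalization: divide by 255 (for 8-bit data)
    img = np.array(img) / 255.0
    #  One-hot encoding: one 0/1 layer per class
    new_label = np.zeros(label.shape + (classNum,))
    for i in range(classNum):
        new_label[label == colorDict_GRAY[i][0], i] = 1
    return img, new_label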
Notice that the function takes a colorDict_GRAY parameter: the color dictionary of the classes. With only two classes, for example, colorDict_GRAY = (0, 255). Typing this in by hand gets tedious once there are many classes, so we have the program build colorDict_GRAY automatically:
import os
import cv2
import numpy as np

#  Build the color dictionary
#  labelFolder  label folder; we traverse the whole folder because a single
#               label image may not contain every class color
#  classNum     total number of classes (incl. background)
def color_dict(labelFolder, classNum):
    colorDict = []
    #  Get the file names in the folder
    ImageNameList = os.listdir(labelFolder)
    for i in range(len(ImageNameList)):
        ImagePath = labelFolder + "/" + ImageNameList[i]
        #  Note: cv2.imread returns channels in BGR order
        img = cv2.imread(ImagePath).astype(np.uint32)
        #  If grayscale, convert to 3 channels
        if(len(img.shape) == 2):
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB).astype(np.uint32)
        #  Pack the three channels into a single number so unique colors can be extracted
        img_new = img[:,:,0] * 1000000 + img[:,:,1] * 1000 + img[:,:,2]
        unique = np.unique(img_new)
        #  Add the unique values of the i-th image to colorDict
        for j in range(unique.shape[0]):
            colorDict.append(unique[j])
        #  Deduplicate the values collected from the first i images
        colorDict = sorted(set(colorDict))
        #  Once the number of unique values equals the total class count
        #  (incl. background), stop traversing the remaining images
        if(len(colorDict) == classNum):
            break
    #  Three-channel color dictionary, used to render prediction results
    #  (stored in the same BGR order as cv2.imread, which keeps it consistent
    #  with cv2.imwrite at prediction time)
    colorDict_RGB = []
    for k in range(len(colorDict)):
        #  Left-pad to nine digits (e.g. 5,201,111 -> 005201111)
        color = str(colorDict[k]).rjust(9, '0')
        #  First three digits, middle three, last three are the three channels
        color_RGB = [int(color[0 : 3]), int(color[3 : 6]), int(color[6 : 9])]
        colorDict_RGB.append(color_RGB)
    #  Convert to a numpy array
    colorDict_RGB = np.array(colorDict_RGB)
    #  Grayscale color dictionary, used for one-hot encoding during preprocessing
    colorDict_GRAY = colorDict_RGB.reshape((colorDict_RGB.shape[0], 1, colorDict_RGB.shape[1])).astype(np.uint8)
    colorDict_GRAY = cv2.cvtColor(colorDict_GRAY, cv2.COLOR_BGR2GRAY)
    return colorDict_RGB, colorDict_GRAY

Besides colorDict_GRAY, color_dict also returns colorDict_RGB, which is used to render the results in color at prediction time.
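For example, with two classes whose label colors are pure black and white, the call returns (values here assume 0/255 labels):

colorDict_RGB, colorDict_GRAY = color_dict("Data\\train\\label", 2)
print(colorDict_RGB)    #  [[  0   0   0]
                        #   [255 255 255]]
print(colorDict_GRAY)   #  [[  0]
                        #   [255]]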
We train with keras.Model.fit_generator(), so we need a training data generator. The generator built into Keras doesn't support multiple bands, so we write our own:
import random

#  Training data generator
#  batch_size        batch size
#  train_image_path  training image folder
#  train_label_path  training label folder
#  classNum          total number of classes (incl. background)
#  colorDict_GRAY    gray-value color dictionary
#  resize_shape      target size, or None to keep the original size
def trainGenerator(batch_size, train_image_path, train_label_path, classNum, colorDict_GRAY, resize_shape = None):
    imageList = os.listdir(train_image_path)
    labelList = os.listdir(train_label_path)
    img = readTif(train_image_path + "\\" + imageList[0])
    #  GDAL reads as (bands, height, width); convert to (height, width, bands)
    img = img.swapaxes(1, 0)
    img = img.swapaxes(1, 2)
    #  Generate data indefinitely
    while(True):
        img_generator = np.zeros((batch_size, img.shape[0], img.shape[1], img.shape[2]), np.uint8)
        label_generator = np.zeros((batch_size, img.shape[0], img.shape[1]), np.uint8)
        if(resize_shape != None):
            img_generator = np.zeros((batch_size, resize_shape[0], resize_shape[1], resize_shape[2]), np.uint8)
            label_generator = np.zeros((batch_size, resize_shape[0], resize_shape[1]), np.uint8)
        #  Pick a random starting index for the batch
        rand = random.randint(0, len(imageList) - batch_size)
        for j in range(batch_size):
            img = readTif(train_image_path + "\\" + imageList[rand + j])
            img = img.swapaxes(1, 0)
            img = img.swapaxes(1, 2)
            #  Resize to the target size; OpenCV is used because resizing is
            #  rarely needed here, but it cannot handle arbitrary band counts --
            #  use numpy if you need that
            if(resize_shape != None):
                img = cv2.resize(img, (resize_shape[0], resize_shape[1]))
            img_generator[j] = img
            label = readTif(train_label_path + "\\" + labelList[rand + j]).astype(np.uint8)
            #  If the label is in color, convert it to grayscale
            if(len(label.shape) == 3):
                label = label.swapaxes(1, 0)
                label = label.swapaxes(1, 2)
                label = cv2.cvtColor(label, cv2.COLOR_RGB2GRAY)
            if(resize_shape != None):
                label = cv2.resize(label, (resize_shape[0], resize_shape[1]))
            label_generator[j] = label
        img_generator, label_generator = dataPreprocess(img_generator, label_generator, classNum, colorDict_GRAY)
        yield (img_generator, label_generator)
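A quick shape check of what the generator yields (assuming the Data folder layout from the training section below and the dataPreprocess sketch above):

colorDict_RGB, colorDict_GRAY = color_dict("Data\\train\\label", 2)
gen = trainGenerator(2, "Data\\train\\image", "Data\\train\\label",
                     2, colorDict_GRAY, (512, 512, 3))
imgs, labels = next(gen)
print(imgs.shape)    #  (2, 512, 512, 3)
print(labels.shape)  #  (2, 512, 512, 2) -- one layer per class after one-hot encoding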
3 Writing the U-Net Model

We add BN and Dropout layers to U-Net, use Adam as the optimizer, and cross-entropy as the loss function. Implemented with Keras:
from keras.models import Model
from keras.layers import Input, BatchNormalization, Conv2D, MaxPooling2D, Dropout, concatenate, merge, UpSampling2D
from keras.optimizers import Adam

def unet(pretrained_weights = None, input_size = (256, 256, 4), classNum = 2, learning_rate = 1e-5):
    inputs = Input(input_size)
    #  Contracting path: 2D convolutions
    conv1 = BatchNormalization()(Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs))
    conv1 = BatchNormalization()(Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1))
    #  Max pooling over the spatial dimensions
    pool1 = MaxPooling2D(pool_size = (2, 2))(conv1)
    conv2 = BatchNormalization()(Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1))
    conv2 = BatchNormalization()(Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2))
    pool2 = MaxPooling2D(pool_size = (2, 2))(conv2)
    conv3 = BatchNormalization()(Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2))
    conv3 = BatchNormalization()(Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3))
    pool3 = MaxPooling2D(pool_size = (2, 2))(conv3)
    conv4 = BatchNormalization()(Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3))
    conv4 = BatchNormalization()(Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4))
    #  Dropout regularization against overfitting
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size = (2, 2))(drop4)
    conv5 = BatchNormalization()(Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4))
    conv5 = BatchNormalization()(Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5))
    drop5 = Dropout(0.5)(conv5)
    #  Expanding path: upsample then convolve, similar to a transposed convolution
    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(drop5))
    #  concatenate is the Keras 2 API; merge is the Keras 1 fallback
    try:
        merge6 = concatenate([drop4, up6], axis = 3)
    except:
        merge6 = merge([drop4, up6], mode = 'concat', concat_axis = 3)
    conv6 = BatchNormalization()(Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6))
    conv6 = BatchNormalization()(Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6))
    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(conv6))
    try:
        merge7 = concatenate([conv3, up7], axis = 3)
    except:
        merge7 = merge([conv3, up7], mode = 'concat', concat_axis = 3)
    conv7 = BatchNormalization()(Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7))
    conv7 = BatchNormalization()(Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7))
    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(conv7))
    try:
        merge8 = concatenate([conv2, up8], axis = 3)
    except:
        merge8 = merge([conv2, up8], mode = 'concat', concat_axis = 3)
    conv8 = BatchNormalization()(Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8))
    conv8 = BatchNormalization()(Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8))
    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(conv8))
    try:
        merge9 = concatenate([conv1, up9], axis = 3)
    except:
        merge9 = merge([conv1, up9], mode = 'concat', concat_axis = 3)
    conv9 = BatchNormalization()(Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9))
    conv9 = BatchNormalization()(Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9))
    #  The extra 2-filter convolution is kept from the classic 2-class Keras U-Net implementation
    conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv10 = Conv2D(classNum, 1, activation = 'softmax')(conv9)
    model = Model(inputs = inputs, outputs = conv10)
    #  Configure training: optimizer, loss function, metrics
    model.compile(optimizer = Adam(lr = learning_rate), loss = 'categorical_crossentropy', metrics = ['accuracy'])
    #  Load pretrained weights if provided
    if(pretrained_weights):
        model.load_weights(pretrained_weights)
    return model

(Figure: the improved U-Net architecture)
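A quick way to sanity-check the builder, using a hypothetical 4-band, 2-class configuration:

model = unet(input_size = (256, 256, 4), classNum = 2)
print(model.output_shape)   #  (None, 256, 256, 2): per-pixel softmax over 2 classes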
4 Model Training

With all the pieces in place, we can now write the training code:
import os
import datetime
from keras.callbacks import ModelCheckpoint
#  Assumption: module paths follow the file listing in section 9;
#  adjust the imports to wherever these functions live in your project
from dataProcess import trainGenerator, color_dict
from seg_unet import unet

'''
Dataset parameters
'''
#  Training image folder
train_image_path = "Data\\train\\image"
#  Training label folder
train_label_path = "Data\\train\\label"
#  Validation image folder
validation_image_path = "Data\\validation\\image"
#  Validation label folder
validation_label_path = "Data\\validation\\label"

'''
Model parameters
'''
#  Batch size
batch_size = 2
#  Number of classes (incl. background)
classNum = 2
#  Model input size
input_size = (512, 512, 3)
#  Total number of training epochs
epochs = 100
#  Initial learning rate
learning_rate = 1e-4
#  Pretrained model path
premodel_path = None
#  Where to save the trained model
model_path = "Model\\unet_model.hdf5"

#  Number of training samples
train_num = len(os.listdir(train_image_path))
#  Number of validation samples
validation_num = len(os.listdir(validation_image_path))
#  Batches per training epoch
steps_per_epoch = train_num / batch_size
#  Batches per validation epoch
validation_steps = validation_num / batch_size
#  Label color dictionary, used for one-hot encoding
colorDict_RGB, colorDict_GRAY = color_dict(train_label_path, classNum)

#  Generator producing training data batch by batch
train_Generator = trainGenerator(batch_size, train_image_path, train_label_path, classNum, colorDict_GRAY, input_size)
#  Generator producing validation data batch by batch
validation_data = trainGenerator(batch_size, validation_image_path, validation_label_path, classNum, colorDict_GRAY, input_size)
#  Build the model
model = unet(pretrained_weights = premodel_path, input_size = input_size, classNum = classNum, learning_rate = learning_rate)
#  Print the model structure
model.summary()
#  Callback: save the best model by training loss
model_checkpoint = ModelCheckpoint(model_path,
                                   monitor = 'loss',
                                   #  Verbosity: 0 silent, 1 progress bar, 2 one line per epoch
                                   verbose = 1,
                                   save_best_only = True)

#  Record the start time
start_time = datetime.datetime.now()

#  Train the model
history = model.fit_generator(train_Generator,
                              steps_per_epoch = steps_per_epoch,
                              epochs = epochs,
                              callbacks = [model_checkpoint],
                              validation_data = validation_data,
                              validation_steps = validation_steps)

#  Total training time
end_time = datetime.datetime.now()
log_time = "Total training time: " + str((end_time - start_time).seconds / 60) + "m"
print(log_time)
with open('TrainTime.txt', 'w') as f:
    f.write(log_time)
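If training is interrupted, the saved checkpoint can be fed back in through the pretrained_weights parameter; a hypothetical resume snippet:

#  Resume training from the saved checkpoint
model = unet(pretrained_weights = "Model\\unet_model.hdf5",
             input_size = (512, 512, 3),
             classNum = 2,
             learning_rate = 1e-4)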
5 Plotting loss/acc Curves

model.fit_generator() returns the loss and acc history, which we save and then plot with matplotlib:
import xlwt
import matplotlib.pyplot as plt

#  Save and plot loss/acc
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

#  Save the per-epoch metrics to an Excel sheet
book = xlwt.Workbook(encoding = 'utf-8', style_compression = 0)
sheet = book.add_sheet('test', cell_overwrite_ok = True)
for i in range(len(acc)):
    sheet.write(i, 0, acc[i])
    sheet.write(i, 1, val_acc[i])
    sheet.write(i, 2, loss[i])
    sheet.write(i, 3, val_loss[i])
book.save(r'AccAndLoss.xls')

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'r', label = 'Training acc')
plt.plot(epochs, val_acc, 'b', label = 'Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.savefig("accuracy.png", dpi = 300)
plt.figure()
plt.plot(epochs, loss, 'r', label = 'Training loss')
plt.plot(epochs, val_loss, 'b', label = 'Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.savefig("loss.png", dpi = 300)
plt.show()

(Figures: training and validation loss / accuracy curves)
6 Model Prediction
At prediction time the data format must match what the model saw during training, so we use a generator here as well:
#  Test data generator
#  test_iamge_path  test image folder
#  resize_shape     target size, or None to keep the original size
def testGenerator(test_iamge_path, resize_shape = None):
    imageList = os.listdir(test_iamge_path)
    for i in range(len(imageList)):
        img = readTif(test_iamge_path + "\\" + imageList[i])
        img = img.swapaxes(1, 0)
        img = img.swapaxes(1, 2)
        #  Normalize
        img = img / 255.0
        if(resize_shape != None):
            #  Resize to the target size
            img = cv2.resize(img, (resize_shape[0], resize_shape[1]))
        #  Add a batch dimension to match the [batch_size, ...] input used in training
        img = np.reshape(img, (1, ) + img.shape)
        yield img
The model's predictions then need to be one-hot decoded, rendered with the color dictionary, and saved:

#  Save prediction results
#  test_image_path    test image folder
#  test_predict_path  output folder for the predictions
#  model_predict      the model's predictions
#  color_dict         color dictionary (colorDict_RGB; stored in cv2's BGR
#                     order, so cv2.imwrite reproduces the original colors)
#  output_size        output image size
def saveResult(test_image_path, test_predict_path, model_predict, color_dict, output_size):
    imageList = os.listdir(test_image_path)
    for i, img in enumerate(model_predict):
        #  One-hot decoding: take the highest-scoring class per pixel
        channel_max = np.argmax(img, axis = -1)
        img_out = np.uint8(color_dict[channel_max.astype(np.uint8)])
        #  Nearest-neighbor interpolation, so no new colors are introduced
        img_out = cv2.resize(img_out, (output_size[0], output_size[1]), interpolation = cv2.INTER_NEAREST)
        #  Save as lossless PNG
        cv2.imwrite(test_predict_path + "\\" + imageList[i][:-4] + ".png", img_out)
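The post doesn't show how these two pieces are wired together; a minimal driver sketch under the folder layout from section 9 (the paths and the steps count are assumptions):

import os

test_image_path = "Data\\test\\image"      #  hypothetical paths matching the layout below
test_predict_path = "Data\\test\\predict"
test_num = len(os.listdir(test_image_path))

testGene = testGenerator(test_image_path, resize_shape = (512, 512, 3))
#  One step per test image, since the generator yields batches of one
results = model.predict_generator(testGene, steps = test_num, verbose = 1)
#  Decode and render with the color dictionary, saving at the original tile size
saveResult(test_image_path, test_predict_path, results, colorDict_RGB, (512, 512))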
7 Predicting Large Remote Sensing Images

Feeding a large image to be classified straight into the network would exhaust memory, so the usual approach is to crop it into a series of smaller tiles, predict each tile, and stitch the predictions back together, in cropping order, into one final result image. For details, see this post:
馨意: Sliding-window prediction of large remote sensing images with edge discarding (zhuanlan.zhihu.com)
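For orientation, a bare-bones tiling sketch; it assumes the image dimensions are exact multiples of the tile size and skips the edge-effect handling that the linked post is actually about:

import numpy as np

#  Minimal tiling sketch: predict a large (height, width, bands) array tile by tile
def predict_big_image(model, img, tile = 512):
    h, w, _ = img.shape
    out = np.zeros((h, w), np.uint8)
    for r in range(0, h, tile):
        for c in range(0, w, tile):
            patch = img[r:r + tile, c:c + tile] / 255.0
            pred = model.predict(patch[np.newaxis, ...])[0]
            #  One-hot decode each tile and write it into the mosaic
            out[r:r + tile, c:c + tile] = np.argmax(pred, axis = -1)
    return out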
8 Accuracy Assessment
We assess accuracy with precision, recall, F1-score, intersection-over-union (IoU), mean IoU (mIoU), frequency-weighted IoU (FWIoU), and similar metrics. For details, see this post:
馨意: Common accuracy metrics for remote sensing semantic segmentation and their Python implementation (zhuanlan.zhihu.com)
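All of these metrics fall out of a confusion matrix; a compact sketch of the standard formulas (the linked post has the full implementation):

import numpy as np

#  Confusion matrix: rows = ground truth, columns = prediction
def confusion_matrix(gt, pred, classNum):
    mask = (gt >= 0) & (gt < classNum)
    return np.bincount(classNum * gt[mask].astype(int) + pred[mask],
                       minlength = classNum ** 2).reshape(classNum, classNum)

def seg_metrics(cm):
    tp = np.diag(cm)
    precision = tp / cm.sum(axis = 0)                        #  TP / (TP + FP)
    recall = tp / cm.sum(axis = 1)                           #  TP / (TP + FN)
    f1 = 2 * precision * recall / (precision + recall)
    iou = tp / (cm.sum(axis = 0) + cm.sum(axis = 1) - tp)    #  TP / (TP + FP + FN)
    miou = np.nanmean(iou)
    #  FWIoU weights each class IoU by its ground-truth frequency
    freq = cm.sum(axis = 1) / cm.sum()
    fwiou = (freq[freq > 0] * iou[freq > 0]).sum()
    return precision, recall, f1, iou, miou, fwiou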
9 Full Code

The project files are organized as follows:
Data-----train
----------image
----------label
-----validation
----------image
----------label
-----test
----------image
----------label
----------predict
Model
-----seg_unet.py
-----unet_model
dataProcess.py
train.py
test.py
seg_metrics.py

Original article: https://zhuanlan.zhihu.com/p/161925744