Low-Light Image Enhancement Network RetinexNet (model.py Walkthrough [2])
Paper: https://arxiv.org/pdf/1808.04560.pdf
Code: https://github.com/weichen582/RetinexNet
Series index: https://zhuanlan.zhihu.com/p/88761829
The entire model architecture is implemented as a single class:

    class lowlight_enhance(object):

Its constructor builds the network, defines the loss functions, configures training, and initializes the parameters, as detailed below.
Building the network (this part defines the low-/normal-light image inputs and wires together Decom-Net, Enhance-Net, and the reconstruction step; note that no denoising of R_low is performed here):
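    # build the model
    self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
    self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')

    # decompose both inputs into reflectance R and illumination I
    [R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
    [R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)

    # predict the enhanced illumination from the low-light decomposition
    I_delta = RelightNet(I_low, R_low)

    # repeat the single-channel illumination maps to three channels
    I_low_3 = concat([I_low, I_low, I_low])
    I_high_3 = concat([I_high, I_high, I_high])
    I_delta_3 = concat([I_delta, I_delta, I_delta])

    self.output_R_low = R_low
    self.output_I_low = I_low_3
    self.output_I_delta = I_delta_3
    self.output_S = R_low * I_delta_3

Defining the loss functions (this part covers the reconstruction losses for the low-/normal-light images, the reflectance consistency loss, the illumination smoothness loss, and finally the total losses of Decom-Net and Enhance-Net, which are computed separately):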
    # loss
    self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.input_low))
    self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high))
    self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low))
    self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high))
    self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
    self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high))

    self.Ismooth_loss_low = self.smooth(I_low, R_low)
    self.Ismooth_loss_high = self.smooth(I_high, R_high)
    self.Ismooth_loss_delta = self.smooth(I_delta, R_low)

    # total losses; the parentheses only wrap the long sum across lines
    self.loss_Decom = (self.recon_loss_low + self.recon_loss_high
                       + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high
                       + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high
                       + 0.01 * self.equal_R_loss)
    self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta

Configuring training (this part sets up the learning rate and the optimizers for Decom-Net and Enhance-Net):
    self.lr = tf.placeholder(tf.float32, name='learning_rate')
    optimizer = tf.train.AdamOptimizer(self.lr, name='AdamOptimizer')

    # select each sub-network's variables by name so they can be trained separately
    self.var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
    self.var_Relight = [var for var in tf.trainable_variables() if 'RelightNet' in var.name]

    self.train_op_Decom = optimizer.minimize(self.loss_Decom, var_list=self.var_Decom)
    self.train_op_Relight = optimizer.minimize(self.loss_Relight, var_list=self.var_Relight)

Initializing the training parameters:
    self.sess.run(tf.global_variables_initializer())

    # one saver per sub-network, so each is checkpointed independently
    self.saver_Decom = tf.train.Saver(var_list=self.var_Decom)
    self.saver_Relight = tf.train.Saver(var_list=self.var_Relight)

    print("[*] Initialize model successfully...")

Next come some of the class's member functions.
    def gradient(self, input_tensor, direction):
        self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])
        self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])

        if direction == "x":
            kernel = self.smooth_kernel_x
        elif direction == "y":
            kernel = self.smooth_kernel_y
        return tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))

This function computes the (absolute) horizontal or vertical gradient map of an image by convolving it with the corresponding gradient kernel.
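To make the kernel and padding behaviour concrete, the following is a minimal pure-Python sketch of the same 2x2 difference kernels applied to a single-channel image stored as a nested list. It assumes that the one row and one column of zero padding go on the bottom and right, which is how `padding='SAME'` is generally understood to treat an even-sized kernel at stride 1. The helper name `gradient_2d` is hypothetical and not part of the repository.

```python
def gradient_2d(img, direction):
    """Absolute 2x2 difference of a 2-D list (hypothetical helper, not repo code)."""
    kernel_x = [[0.0, 0.0], [-1.0, 1.0]]
    kernel_y = [[0.0, -1.0], [0.0, 1.0]]  # transpose of kernel_x
    kernel = kernel_x if direction == "x" else kernel_y
    h, w = len(img), len(img[0])

    def px(i, j):
        # Assumption: out-of-range pixels (bottom/right padding) read as zero.
        return img[i][j] if i < h and j < w else 0.0

    out = []
    for i in range(h):
        row = []
        for j in range(w):
            acc = sum(kernel[di][dj] * px(i + di, j + dj)
                      for di in range(2) for dj in range(2))
            row.append(abs(acc))
        out.append(row)
    return out


img = [[1, 2, 4],
       [1, 3, 9],
       [0, 0, 0]]
print(gradient_2d(img, "x"))
print(gradient_2d(img, "y"))
```

Under this padding assumption, the "x" output at (i, j) is |img[i+1][j+1] - img[i+1][j]| and the "y" output is |img[i+1][j+1] - img[i][j+1]|, so the last row of the "x" map and the last column of the "y" map are always zero.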
    def ave_gradient(self, input_tensor, direction):
        return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')

This function smooths the horizontal or vertical gradient map of an image with 3x3 average pooling.
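As an illustration of what the pooling step does, here is a small pure-Python sketch of a 3x3, stride-1 mean filter over a nested list. One detail is an assumption rather than something stated in the article: border windows here average only the pixels that lie inside the image, instead of counting padded zeros. The helper name `average_pool_3x3` is hypothetical.

```python
def average_pool_3x3(img):
    """3x3 mean filter, stride 1 (hypothetical helper, not repo code).

    Assumption: border windows average only the in-bounds pixels.
    """
    h, w = len(img), len(img[0])
    out = []
    for i in range(h):
        row = []
        for j in range(w):
            window = [
                img[a][b]
                for a in range(max(i - 1, 0), min(i + 2, h))
                for b in range(max(j - 1, 0), min(j + 2, w))
            ]
            row.append(sum(window) / len(window))
        out.append(row)
    return out


print(average_pool_3x3([[3, 6, 9]]))
```

The output has the same size as the input, which is what allows the smoothed reflectance gradient to be multiplied element-wise with the illumination gradient later on.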
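    def smooth(self, input_I, input_R):
        input_R = tf.image.rgb_to_grayscale(input_R)
        return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))

This function is the concrete implementation of the illumination smoothness loss (compare it with the corresponding formula in the original paper): the illumination gradient is weighted by exp(-10 * smoothed reflectance gradient), so illumination changes are penalized less where the reflectance has strong edges.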
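The edge-aware weighting is easier to see in one dimension. The sketch below is a simplified 1-D analogue, not the repository's code: it uses plain forward differences, skips the average pooling, and the function name `smooth_1d` is hypothetical. The factor 10 matches the constant in the code above.

```python
import math


def smooth_1d(illum, refl_gray, lam=10.0):
    """Mean of |dI| * exp(-lam * |dR|) over neighbouring samples (1-D sketch)."""
    terms = []
    for k in range(len(illum) - 1):
        grad_i = abs(illum[k + 1] - illum[k])
        grad_r = abs(refl_gray[k + 1] - refl_gray[k])
        terms.append(grad_i * math.exp(-lam * grad_r))
    return sum(terms) / len(terms)


illum = [0.2, 0.2, 0.8, 0.8]           # illumination with one jump
flat_refl = [0.5, 0.5, 0.5, 0.5]       # no reflectance edge at the jump
edge_refl = [0.1, 0.1, 0.9, 0.9]       # reflectance edge aligned with the jump

print(smooth_1d(illum, flat_refl))
print(smooth_1d(illum, edge_refl))
```

The same illumination jump costs far less when it coincides with a reflectance edge, which is the behaviour the loss is designed to encourage.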
    def evaluate(self, epoch_num, eval_low_data, sample_dir, train_phase):
        print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch_num))

        for idx in range(len(eval_low_data)):
            input_low_eval = np.expand_dims(eval_low_data[idx], axis=0)

            if train_phase == "Decom":
                result_1, result_2 = self.sess.run([self.output_R_low, self.output_I_low], feed_dict={self.input_low: input_low_eval})
            if train_phase == "Relight":
                result_1, result_2 = self.sess.run([self.output_S, self.output_I_delta], feed_dict={self.input_low: input_low_eval})

            save_images(os.path.join(sample_dir, 'eval_%s_%d_%d.png' % (train_phase, idx + 1, epoch_num)), result_1, result_2)

This function evaluates the Decom-Net or Enhance-Net model after epoch_num epochs of training and saves the resulting images.
Next comes model training:
    def train(self, train_low_data, train_high_data, eval_low_data, batch_size, patch_size, epoch, lr, sample_dir, ckpt_dir, eval_every_epoch, train_phase):

This function covers loading a pretrained model, reading and preprocessing the data, and training, evaluating, and saving the model.
    assert len(train_low_data) == len(train_high_data)
    numBatch = len(train_low_data) // int(batch_size)

This checks that the numbers of low-light and normal-light training samples match and, if they do, computes how many full batches the training set contains.
    # load pretrained model
    if train_phase == "Decom":
        train_op = self.train_op_Decom
        train_loss = self.loss_Decom
        saver = self.saver_Decom
    elif train_phase == "Relight":
        train_op = self.train_op_Relight
        train_loss = self.loss_Relight
        saver = self.saver_Relight

    load_model_status, global_step = self.load(saver, ckpt_dir)
    if load_model_status:
        # resume from the epoch and step implied by the saved iteration count
        iter_num = global_step
        start_epoch = global_step // numBatch
        start_step = global_step % numBatch
        print("[*] Model restore success!")
    else:
        iter_num = 0
        start_epoch = 0
        start_step = 0
        print("[*] Not find pretrained model!")

If a pretrained model exists for the selected phase (Decom-Net or Enhance-Net), it is loaded and training resumes from it; otherwise training starts from scratch.
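The resume arithmetic can be checked in isolation. The sketch below mirrors the two expressions above in a hypothetical helper `resume_position`; the sample counts are made-up values chosen only for illustration.

```python
def resume_position(global_step, num_batch):
    """Split a saved iteration count into (epoch, step within epoch)."""
    start_epoch = global_step // num_batch
    start_step = global_step % num_batch
    return start_epoch, start_step


# Hypothetical numbers: 10 image pairs with batch size 4 give 2 full batches;
# the remaining 2 pairs are dropped by the floor division.
num_batch = 10 // 4
print(num_batch)
print(resume_position(5, num_batch))
```

Because the batch count uses floor division, a checkpoint saved at iteration 5 with 2 batches per epoch resumes at epoch index 2, step 1.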
    # generate data for a batch
    batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
    batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
    for patch_id in range(batch_size):
        h, w, _ = train_low_data[image_id].shape
        x = random.randint(0, h - patch_size)
        y = random.randint(0, w - patch_size)

        # same crop location and same augmentation mode for both images of a pair
        rand_mode = random.randint(0, 7)
        batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)
        batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)

        image_id = (image_id + 1) % len(train_low_data)
        if image_id == 0:
            tmp = list(zip(train_low_data, train_high_data))
            # The code as quoted calls random.shuffle(list(tmp)), which shuffles a
            # throwaway copy and leaves tmp unchanged; shuffling tmp itself takes effect.
            random.shuffle(tmp)
            train_low_data, train_high_data = zip(*tmp)

Training images are read in order. For each low-/normal-light pair, a patch is cropped at the same random location from both images and augmented with the same mode (see the description of the function data_augmentation elsewhere in this series). Note that the training set is reshuffled whenever image_id wraps around to 0, that is, after each full pass over the training images, not after each batch.
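The two ideas in this step, a crop origin that always fits inside the image and a shuffle that keeps pairs aligned, can be sketched without NumPy or TensorFlow. The helper names `crop_origin` and `shuffle_pairs` and the string stand-ins for images are hypothetical.

```python
import random


def crop_origin(h, w, patch_size, rng):
    # randint includes both endpoints, so x + patch_size <= h and y + patch_size <= w
    x = rng.randint(0, h - patch_size)
    y = rng.randint(0, w - patch_size)
    return x, y


def shuffle_pairs(low, high, rng):
    """Shuffle two parallel sequences with one permutation so pairs stay aligned."""
    pairs = list(zip(low, high))
    rng.shuffle(pairs)  # shuffles the list object itself, in place
    low_shuffled, high_shuffled = zip(*pairs)
    return list(low_shuffled), list(high_shuffled)


rng = random.Random(0)
low = ["low_%d" % k for k in range(5)]     # stand-ins for low-light images
high = ["high_%d" % k for k in range(5)]   # stand-ins for normal-light images
low_s, high_s = shuffle_pairs(low, high, rng)
print(list(zip(low_s, high_s)))
print(crop_origin(400, 600, 48, rng))
```

Zipping before shuffling is what guarantees that every low-light image keeps its normal-light partner; shuffling the two lists independently would break the pairing.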
    # train
    _, loss = self.sess.run([train_op, train_loss], feed_dict={self.input_low: batch_input_low,
                                                               self.input_high: batch_input_high,
                                                               self.lr: lr[epoch]})

    print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f"
          % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
    iter_num += 1

This runs one training iteration, using the learning rate for the current epoch, and prints the progress and loss.
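Since the learning rate is looked up as `lr[epoch]`, the `lr` argument has to be a sequence with one entry per epoch. The sketch below builds such a sequence with a single step decay; the schedule shape, the helper name `step_schedule`, and all numeric values are assumptions for illustration and are not taken from the article.

```python
def step_schedule(num_epochs, start_lr, drop_epoch, factor=0.1):
    """One learning rate per epoch: start_lr, then start_lr * factor from drop_epoch on."""
    return [start_lr if e < drop_epoch else start_lr * factor
            for e in range(num_epochs)]


# Hypothetical values: 5 epochs, decay after the third epoch.
lr = step_schedule(num_epochs=5, start_lr=0.001, drop_epoch=3)
for epoch in range(5):
    print(epoch + 1, lr[epoch])
```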
    # evaluate the model and save a checkpoint file for it
    if (epoch + 1) % eval_every_epoch == 0:
        self.evaluate(epoch + 1, eval_low_data, sample_dir=sample_dir, train_phase=train_phase)
        self.save(saver, iter_num, ckpt_dir, "RetinexNet-%s" % train_phase)

The model is evaluated and saved once every eval_every_epoch epochs.
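Because the condition tests `epoch + 1`, checkpoints land on 1-based epoch numbers that are multiples of `eval_every_epoch`. A quick sketch with a hypothetical helper `eval_epochs` lists them:

```python
def eval_epochs(num_epochs, eval_every_epoch):
    """1-based epoch numbers at which evaluation and saving are triggered."""
    return [epoch + 1 for epoch in range(num_epochs)
            if (epoch + 1) % eval_every_epoch == 0]


print(eval_epochs(100, 20))
```

If the total number of epochs is not a multiple of `eval_every_epoch`, the final epochs after the last multiple are trained but never saved by this check.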
Saving the model at a given iteration:
    def save(self, saver, iter_num, ckpt_dir, model_name):
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        print("[*] Saving model %s" % model_name)
        saver.save(self.sess, os.path.join(ckpt_dir, model_name), global_step=iter_num)

Loading the latest model:
    def load(self, saver, ckpt_dir):
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            full_path = tf.train.latest_checkpoint(ckpt_dir)
            try:
                # the iteration count is the number after the last '-' in the file name
                global_step = int(full_path.split('/')[-1].split('-')[-1])
            except ValueError:
                global_step = None
            saver.restore(self.sess, full_path)
            return True, global_step
        else:
            print("[*] Failed to load model from %s" % ckpt_dir)
            return False, 0

Finally, testing the model (test_high_data is not actually used):
    def test(self, test_low_data, test_high_data, test_low_data_names, save_dir, decom_flag):

This function covers loading the model, running the test, and saving the result images.
    tf.global_variables_initializer().run()

    print("[*] Reading checkpoint...")
    load_model_status_Decom, _ = self.load(self.saver_Decom, './model/Decom')
    load_model_status_Relight, _ = self.load(self.saver_Relight, './model/Relight')
    if load_model_status_Decom and load_model_status_Relight:
        print("[*] Load weights successfully...")

This initializes all parameters and then loads the latest Decom-Net and Enhance-Net checkpoints.
    print("[*] Testing...")
    for idx in range(len(test_low_data)):
        print(test_low_data_names[idx])
        # split the file name into base name and extension at the first dot
        [_, name] = os.path.split(test_low_data_names[idx])
        suffix = name[name.find('.') + 1:]
        name = name[:name.find('.')]

        input_low_test = np.expand_dims(test_low_data[idx], axis=0)
        [R_low, I_low, I_delta, S] = self.sess.run([self.output_R_low, self.output_I_low, self.output_I_delta, self.output_S], feed_dict={self.input_low: input_low_test})

        if decom_flag == 1:
            save_images(os.path.join(save_dir, name + "_R_low." + suffix), R_low)
            save_images(os.path.join(save_dir, name + "_I_low." + suffix), I_low)
            save_images(os.path.join(save_dir, name + "_I_delta." + suffix), I_delta)
        save_images(os.path.join(save_dir, name + "_S." + suffix), S)

This loops over the test samples, runs the model on each, and saves the final enhanced image; decom_flag controls whether the Decom-Net decomposition results are saved as well.
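The file-name handling in this loop splits at the first dot, which behaves differently from `os.path.splitext` when a name contains more than one dot. The sketch below reproduces the same slicing in a hypothetical helper `split_first_dot`; the example paths are made up.

```python
import os


def split_first_dot(path):
    """Split a file name at its first dot, as the test loop does."""
    _, name = os.path.split(path)
    suffix = name[name.find('.') + 1:]
    stem = name[:name.find('.')]
    return stem, suffix


print(split_first_dot("data/test/low/1.png"))
print(split_first_dot("data/test/low/scene.v2.png"))
print(os.path.splitext("scene.v2.png"))
```

For simple names such as `1.png` both approaches agree. With `scene.v2.png`, splitting at the first dot yields the stem `scene` and the suffix `v2.png`, so the saved file would be named `scene_S.v2.png`. A name with no dot at all loses its last character in the stem, because `find` returns -1.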