CV:基于Keras利用CNN主流架构之mini_XCEPTION训练性别分类模型hdf5并保存到指定文件夹下
生活随笔
收集整理的這篇文章主要介紹了
CV:基于Keras利用CNN主流架构之mini_XCEPTION训练性别分类模型hdf5并保存到指定文件夹下
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
CV:基于Keras利用CNN主流架構之mini_XCEPTION訓練性別分類模型hdf5并保存到指定文件夾下
?
?
目錄
圖示過程
核心代碼
?
?
?
圖示過程
?
核心代碼
from keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping from keras.callbacks import ReduceLROnPlateau from models.cnn import mini_XCEPTION# parameters1、定義參數(shù):每個batch的采樣本數(shù)、訓練輪數(shù)、輸入shape、部分比例分離用于驗證、冗長參數(shù)、分類個數(shù)、patience、do_random_crop batch_size = 32 num_epochs = 1000 validation_split = .2 do_random_crop = False #random crop only works for classification since the current implementation does no transform bounding boxes patience = 100 num_classes = 2 dataset_name = 'imdb' input_shape = (64, 64, 1)#if判斷,然后指定圖像、log、loghdf5各自保存路徑 if input_shape[2] == 1:grayscale = True images_path = '../datasets/imdb_crop/' log_file_path = '../trained_models/gender_models/gender_training.log' trained_models_path = '../trained_models/gender_models/gender_mini_XCEPTION'# data loader data_loader = DataManager(dataset_name) #自定義DataManager函數(shù)實現(xiàn)根據(jù)數(shù)據(jù)集name進行加載 ground_truth_data = data_loader.get_data() #自定義get_data函數(shù)根據(jù)不同數(shù)據(jù)集name得到各自的ground truth data, train_keys, val_keys = split_imdb_data(ground_truth_data, validation_split) print('Number of training samples:', len(train_keys)) print('Number of validation samples:', len(val_keys))#調(diào)用ImageDataGenerator函數(shù)實現(xiàn)實時數(shù)據(jù)增強生成小批量的圖像數(shù)據(jù)。 image_generator = ImageGenerator(ground_truth_data, batch_size,input_shape[:2],train_keys, val_keys, None,path_prefix=images_path,vertical_flip_probability=0,grayscale=grayscale,do_random_crop=do_random_crop)# model parameters/compilation2、建立XCEPTION模型并compile編譯配置參數(shù),最后輸出網(wǎng)絡摘要 model = mini_XCEPTION(input_shape, num_classes) model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) model.summary()#3、指定要訓練的數(shù)據(jù)集(gender→imdb即男女數(shù)據(jù)集)# model callbacks # callbacks4、回調(diào):通過調(diào)用CSVLogger、EarlyStopping、ReduceLROnPlateau、ModelCheckpoint等函數(shù)得到訓練參數(shù)存到一個list內(nèi) early_stop = EarlyStopping('val_loss', patience=patience) reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1,patience=int(patience/2), verbose=1) csv_logger = CSVLogger(log_file_path, append=False) model_names = trained_models_path + '.{epoch:02d}-{val_acc:.2f}.hdf5' model_checkpoint = ModelCheckpoint(model_names,monitor='val_loss',verbose=1,save_best_only=True,save_weights_only=False) callbacks = [model_checkpoint, csv_logger, early_stop, reduce_lr]# training model5、調(diào)用fit_generator函數(shù)訓練模型 model.fit_generator(image_generator.flow(mode='train'),steps_per_epoch=int(len(train_keys) / batch_size),epochs=num_epochs, verbose=1,callbacks=callbacks,validation_data=image_generator.flow('val'),validation_steps=int(len(val_keys) / batch_size))?
?
總結
以上是生活随笔為你收集整理的CV:基于Keras利用CNN主流架构之mini_XCEPTION训练性别分类模型hdf5并保存到指定文件夹下的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: CV:基于Keras利用CNN主流架构之
- 下一篇: CV:基于keras利用cv2自带两步检