生活随笔
收集整理的這篇文章主要介紹了
使用VGG16网络结构训练自己的图像分类模型
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
配置
tensorflow2.4.0
python3.6
貓狗大戰數據集
代碼
VGG16網絡很著名,這里不再介紹。
keras里有預訓練好的VGG16,tensorflow2.0以后的版本中已經集成了keras。
解釋在代碼中。
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras import applications
from tensorflow.keras.layers import Dropout, Flatten, Dense
from tensorflow.keras.optimizers import SGD
import pickle
import numpy as np# 開啟GPU加速
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)OUT_CATEGORIES = 2 # 分類數
batch_size = 2 # 批量大小
epochs = 50 # 迭代次數
imgSize = 256
def model():img_shape = (imgSize, imgSize, 3)# 加載不包含全連接層的VGG16網絡base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=img_shape)base_model.summary()# 根據分類數目增加自定義的全連接層,并與VGG16連接model_out = Sequential()model_out.add(Flatten(input_shape=base_model.output_shape[1:]))model_out.add(Dense(256, activation='relu'))model_out.add(Dropout(0.5))model_out.add(Dense(OUT_CATEGORIES, activation='sigmoid'))model = Model(inputs=base_model.input, outputs=model_out(base_model.output))model.compile(loss='binary_crossentropy', optimizer=SGD(lr=0.0001, momentum=0.9),metrics=['accuracy']) # 損失函數為二進制交叉熵,優化器為SGDreturn modelpickle_in = open("x.pickle", "rb")
x = pickle.load(pickle_in)pickle_in = open("y.pickle", "rb")
y = pickle.load(pickle_in)
y = np.array(y)
# 數據集分割為訓練集和測試集
train_num = int(x.shape[0] * 0.7)
test_num = x.shape[0] - train_num
# x歸一化,制作數據集時沒有歸一化
x = x/255
# 打亂
state = np.random.get_state()
np.random.shuffle(x)
np.random.set_state(state)
np.random.shuffle(y)train_x = x[0:train_num, :, :, :]
test_x = x[train_num:train_num+test_num, :, :, :]
train_y = y[0:train_num, :]
test_y = y[train_num:train_num+test_num, :]
# 將label轉換
train_label = keras.utils.to_categorical(train_y, OUT_CATEGORIES)
test_label = keras.utils.to_categorical(test_y, OUT_CATEGORIES)model = model()
model.fit(train_x, train_label, batch_size=batch_size, epochs=epochs, validation_data=(test_x, test_label), shuffle=True)
model.save("catDog.h5")
總結
以上是生活随笔為你收集整理的使用VGG16网络结构训练自己的图像分类模型的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。