迁移学习--Xception
生活随笔
收集整理的這篇文章主要介紹了
迁移学习--Xception
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
一.Xception的概述
Xception是inception處于極端假設的一種網絡結構。當卷積層試圖在三維空間(兩個空間維度和一個通道維度)進行卷積過程時,一個卷積核需要同時繪制跨通道相關性和空間相關性。
前面分享的inception模塊的思想就是將這一卷積過程分解成一系列相互獨立的操作,使其更為便捷有效。典型的inception模塊假設通道相關性和空間相關性的繪制有效脫鉤,而Xception的思想則是inception模塊思想的一種極端情況,即卷積神經網絡的特征圖中的跨通道相關性和空間相關性的繪制可以完全脫鉤。
Xception實現遷移學習也是基于微調的方式,和InceptionV3實現遷移學習一樣,在獲取基于imageNet預訓練完畢的Xception模型后,用自己搭建的全連接層(包括輸出層)代替xception模型的全連接層和輸出層,進而得到一個新的網絡模型,固定新網絡模型的部分參數,使其不參與訓練,基于mnist數據集訓練余下未固定的參數。
二.Xception實現遷移學習
代碼實現:
from keras.applications.xception import Xception from keras.datasets import mnist from keras.utils import np_utils from keras.layers import Dense,GlobalAveragePooling2D,Dropout,Input,UpSampling3D from keras.models import Model from matplotlib import pyplot as plt import numpy as np(X_train,Y_train),(X_test,Y_test)=mnist.load_data() X_test1=X_test Y_test1=Y_test X_train=X_train.reshape(-1,28,28,1).astype("float32")/255.0 X_test=X_test.reshape(-1,28,28,1).astype("float32")/255.0 Y_test=np_utils.to_categorical(Y_test,10) Y_train=np_utils.to_categorical(Y_train,10)#搭建xception模型 #weight="imagenet",xcception權重使用基于imagenet訓練獲得的權重,include_to=false代表不包含頂層的全連接層 base_model=Xception(weights="imagenet",include_top=False) input_xception=Input(shape=(28,28,1),dtype="float32",name="xception imput") #對數據進行上采樣,沿著數據的3個維度分別重復size[0],size[1],size[2] x=UpSampling3D(size=(3,3,3),data_format="channels_last")(input_xception) #將數據送入網絡 x=base_model(x) #此時模型沒有全連接層,需要自己搭建全連接層 #通過GlobalAveragePooling2D對每張二維特征圖進行全局平均池化,輸出對應一維數值 x=GlobalAveragePooling2D()(x) x=Dense(1024,activation="relu")(x) x=Dropout(0.5)(x) pre=Dense(10,activation="softmax")(x) #調用Model,定義一個新的模型Xception_model xception_model=Model(inputs=input_xception,outputs=pre) #查看每一層的名稱和對應的層數 for i,layer in enumerate(base_model.layers):print(i,layer.name) #固定base_model中前36層的參數,使其不參與訓練 for layer in base_model.layers[:36]:layer.trainable=False #查看模型的摘要 xception_model.summary() #編譯 xception_model.compile(loss="categorical_crossentropy",optimizer="adam",metrics=["accuracy"] ) #訓練 training=xception_model.fit(X_train,Y_train,epochs=5,batch_size=64,validation_split=0.2,verbose=1 )test=xception_model.evaluate(X_test,Y_test) print("誤差:",test[0]) print("準確值:",test[1])#畫出訓練集和驗證集的隨著時期的變化曲線 def plot_history(training_history,train,validation):plt.plot(training.history[train],linestyle="-",color="b")plt.plot(training.history[validation],linestyle="--",color="r")plt.title("xception_model accuracy")plt.xlabel("epochs")plt.ylabel("accuracy")plt.legend(["train","validation"],loc="lower right")plt.show() plot_history(training,"accuracy","val_accuracy") def plot_history1(training_history,train,validation):plt.plot(training.history[train],linestyle="-",color="b")plt.plot(training.history[validation],linestyle="--",color="r")plt.title("xception_model accuracy")plt.xlabel("epochs")plt.ylabel("loss")plt.legend(["train","validation"],loc="upper right")plt.show() plot_history1(training,"loss","val_loss")#預測值 prediction=xception_model.predict(X_test) #打印圖片 def plot_image(image):fig=plt.gcf()fig.set_size_inches(2,2)plt.imshow(image,cmap="binary")plt.show() def result(i):plot_image(X_test1[i])print("真實值:",Y_test1[i])print("預測值:",np.argmax(prediction[i])) result(0) result(1)?
?
總結
以上是生活随笔為你收集整理的迁移学习--Xception的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 迁移学习---inceptionV3
- 下一篇: Packet Tracer 思科模拟器之