How to save a Keras model
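Use model.save(filepath) to save a Keras model, including its weights, in a single HDF5 file. The file contains:

- the model architecture, so the model can be rebuilt;
- the model weights;
- the training configuration (loss function, optimizer, etc.);
- the optimizer state, so training can resume exactly where it left off.

Use keras.models.load_model(filepath) to re-instantiate the model. If the file contains the training configuration, this function also compiles the model.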
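The following is a minimal round-trip sketch. It assumes Keras 2.x or 3.x with h5py installed; the file name my_model.h5, the tiny two-layer model and the random data are illustrative only. The model trains for one epoch so the optimizer has state, then everything is saved to one HDF5 file. The file is reloaded, and the script checks that both models predict the same values.

```python
import numpy as np
import keras
from keras import layers

# Illustrative model: sizes and layer names are arbitrary.
model = keras.Sequential([
    keras.Input(shape=(3,)),
    layers.Dense(4, activation="relu", name="dense_1"),
    layers.Dense(1, name="dense_2"),
])
model.compile(optimizer="adam", loss="mean_squared_error")

x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)  # one epoch so the optimizer has state to save

# Architecture, weights, training config and optimizer state in one HDF5 file.
# (Keras 3 warns that .h5 is a legacy format but still writes it.)
model.save("my_model.h5")

# Rebuilt from the file and compiled with the saved training configuration.
restored = keras.models.load_model("my_model.h5")

same_predictions = bool(np.allclose(model.predict(x, verbose=0),
                                    restored.predict(x, verbose=0)))
print(same_predictions)  # True
```

Since the weights are restored exactly, the reloaded model can be used for prediction or for further calls to fit() without recompiling.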
Saving only the model architecture (without weights or training configuration)
    # Save the architecture as a JSON file
    json_string = model.to_json()
    with open('my_model_architecture.json', 'w') as f:
        f.write(json_string)

    from keras.models import model_from_json
    with open('my_model_architecture.json') as f:
        model = model_from_json(f.read())

    # Save the architecture as a YAML file
    # (to_yaml/model_from_yaml were removed in TensorFlow 2.6+ for security reasons; use JSON there)
    yaml_string = model.to_yaml()
    with open('my_model_architecture.yaml', 'w') as f:
        f.write(yaml_string)

    from keras.models import model_from_yaml
    with open('my_model_architecture.yaml') as f:
        model = model_from_yaml(f.read())

This serializes the model architecture to a JSON or YAML file. These files are human-readable, and you can even open and edit them by hand if needed. The model can also be rebuilt directly from the JSON or YAML string:

    # model reconstruction from JSON:
    from keras.models import model_from_json
    model = model_from_json(json_string)

    # model reconstruction from YAML:
    from keras.models import model_from_yaml
    model = model_from_yaml(yaml_string)
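Here is a short sketch of the JSON round trip, assuming a recent Keras 2.x or 3.x; the model itself is illustrative. to_json() stores only the architecture. The rebuilt model therefore has the same layers and parameter count, but freshly initialized weights. YAML is left out because to_yaml() was removed in newer releases.

```python
import json
import keras
from keras import layers

model = keras.Sequential([
    keras.Input(shape=(3,)),
    layers.Dense(2, name="dense_1"),
    layers.Dense(3, name="dense_2"),
])

json_string = model.to_json()        # architecture only: no weights, no compile settings
config = json.loads(json_string)     # ordinary JSON, so it can be inspected or hand-edited
print(config["class_name"])          # Sequential

rebuilt = keras.models.model_from_json(json_string)
print([layer.name for layer in rebuilt.layers])  # ['dense_1', 'dense_2']
print(rebuilt.count_params())                    # 17 = (3*2 + 2) + (2*3 + 3)
```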
Saving only the model weights
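    from keras.models import model_from_json

    # Save the weights only
    model.save_weights('my_model_weights.h5')

    # Load them into a model with exactly the same architecture, which must first be built in code
    model.load_weights('my_model_weights.h5')

    # Load them into a different architecture that shares some layers (e.g. fine-tuning or
    # transfer learning): weights are matched to layers by layer name
    model.load_weights('my_model_weights.h5', by_name=True)

    # Save the architecture (JSON) and the weights (HDF5) separately, then restore both
    json_string = model.to_json()
    with open('my_model_architecture.json', 'w') as f:
        f.write(json_string)
    model.save_weights('my_model_weights.h5')

    with open('my_model_architecture.json') as f:
        model = model_from_json(f.read())
    model.load_weights('my_model_weights.h5')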
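The next sketch saves the architecture and weights separately and then restores both. It assumes Keras 2.x or 3.x with h5py installed. The weight file is named my_model.weights.h5 rather than my_model_weights.h5 for two reasons. First, Keras 3 requires save_weights() file names to end in .weights.h5. Second, Keras 2 still treats that name as HDF5 because it ends in .h5.

```python
import numpy as np
import keras
from keras import layers

model = keras.Sequential([
    keras.Input(shape=(3,)),
    layers.Dense(2, name="dense_1"),
    layers.Dense(3, name="dense_2"),
])

# Architecture -> JSON file, weights -> HDF5 file.
with open("my_model_architecture.json", "w") as f:
    f.write(model.to_json())
model.save_weights("my_model.weights.h5")

# Rebuild the identical architecture first, then load the weights into it.
with open("my_model_architecture.json") as f:
    restored = keras.models.model_from_json(f.read())
restored.load_weights("my_model.weights.h5")

x = np.random.rand(4, 3).astype("float32")
same_predictions = bool(np.allclose(model.predict(x, verbose=0),
                                    restored.predict(x, verbose=0)))
print(same_predictions)  # True
```

Unlike model.save(), this pair of files carries no training configuration or optimizer state. The restored model must therefore be compiled again before it is trained or evaluated.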
Checkpointing the model architecture, trained weights and optimizer state during training, and loading them
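Keras callbacks are passed to fit() through its callbacks argument and are invoked at defined points during training, for example at the end of every epoch. ModelCheckpoint uses this mechanism to save the model and its parameters while training is still running.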
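    keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False,
                                    save_weights_only=False, mode='auto', period=1)

1. filepath: string, the path where the model is saved. It may contain named formatting placeholders such as {epoch:02d} or {val_loss:.2f}, which are filled in from the epoch number and the metrics logged at the end of that epoch.
2. monitor: the quantity to monitor.
3. verbose: verbosity mode, 0 or 1.
4. save_best_only: if True, a checkpoint is written only when the monitored quantity improves on the best value seen so far. With a fixed filepath, the file therefore always holds the best model.
5. mode: one of 'auto', 'min', 'max'. When save_best_only=True, it decides what counts as an improvement: for val_acc the mode should be 'max', for val_loss it should be 'min'. In 'auto' mode the direction is inferred from the name of the monitored quantity.
6. save_weights_only: if True, only the model weights are saved; otherwise the full model is saved (architecture, training configuration, etc.).
7. period: number of epochs between checkpoints. Newer versions deprecate it in favor of save_freq.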
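The plain-Python sketch below (no Keras needed) shows how monitor, mode and save_best_only combine in the per-epoch decision. The class BestOnlyCheckpointSketch is hypothetical and much simpler than the real callback. Its "auto" rule maximizes names containing "acc" and minimizes everything else. That rule is an assumption modeled on the behaviour described above, not the exact Keras source. The filename placeholders are filled with str.format, using the 1-based epoch number and the logged metrics.

```python
import math


class BestOnlyCheckpointSketch:
    """Hypothetical, simplified model of ModelCheckpoint(save_best_only=True).

    It only records the file names a real callback would write; it never
    touches a model or the disk.
    """

    def __init__(self, filepath, monitor="val_loss", mode="auto"):
        self.filepath = filepath
        self.monitor = monitor
        if mode == "auto":
            # Simplified guess from the metric name: accuracy-like -> max, else -> min.
            mode = "max" if "acc" in monitor else "min"
        if mode not in ("min", "max"):
            raise ValueError(f"unknown mode: {mode!r}")
        self.mode = mode
        self.best = -math.inf if mode == "max" else math.inf
        self.saved = []

    def on_epoch_end(self, epoch, logs):
        """epoch is 0-based as in Keras callbacks; logs maps metric names to values."""
        current = logs.get(self.monitor)
        if current is None:
            return None  # monitored value not logged this epoch: nothing is saved
        improved = current > self.best if self.mode == "max" else current < self.best
        if not improved:
            return None
        self.best = current
        # Placeholders such as {epoch:02d} or {val_acc:.2f} are filled from the
        # 1-based epoch number and the logged metrics.
        path = self.filepath.format(epoch=epoch + 1, **logs)
        self.saved.append(path)  # a real callback would write the model to `path` here
        return path


# Usage: monitoring validation accuracy, so "auto" resolves to "max".
ckpt = BestOnlyCheckpointSketch("weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5",
                                monitor="val_acc")
for epoch, acc in enumerate([0.70, 0.76, 0.74, 0.78]):
    ckpt.on_epoch_end(epoch, {"val_acc": acc, "val_loss": 1.0 - acc})
print(ckpt.mode)   # max
print(ckpt.saved)  # ['weights-improvement-01-0.70.hdf5', 'weights-improvement-02-0.76.hdf5', 'weights-improvement-04-0.78.hdf5']
```

The third epoch (0.74) writes nothing because it does not beat 0.76. With a fixed filepath such as "weights.best.hdf5", the same decisions would overwrite one file instead of creating three.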
Example: loading weights by layer name into a model whose architecture differs from the saved one.

    """
    Suppose the original model was:
        model = Sequential()
        model.add(Dense(2, input_dim=3, name="dense_1"))
        model.add(Dense(3, name="dense_2"))
        ...
        model.save_weights(fname)
    """
    from keras.models import Sequential
    from keras.layers import Dense

    # new model
    model = Sequential()
    model.add(Dense(2, input_dim=3, name="dense_1"))  # will be loaded
    model.add(Dense(10, name="new_dense"))            # will not be loaded

    # load weights from the first model; this only affects the first layer, dense_1
    model.load_weights(fname, by_name=True)
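The name matching behind by_name=True can be illustrated without Keras. In the hypothetical sketch below, layers are plain dicts mapping a layer name to a flat list of weight values. For example, a Dense(2) layer on 3 inputs has 3*2 kernel values plus 2 biases, giving 8 values. In this sketch a size mismatch simply raises ValueError. It mirrors the example above: only dense_1 exists in both models, so only dense_1 receives saved weights.

```python
def load_weights_by_name(model_layers, saved_layers):
    """Hypothetical sketch of by-name weight loading.

    model_layers: dict mapping layer name -> flat list of weight values in the new model.
    saved_layers: dict mapping layer name -> flat list of saved weight values.
    Layers whose names are absent from the saved data keep their current weights.
    Returns the names of the layers that received saved weights.
    """
    loaded = []
    for name, current in model_layers.items():
        if name not in saved_layers:
            continue  # e.g. "new_dense": not in the saved data, left untouched
        saved = saved_layers[name]
        if len(saved) != len(current):
            # the sketch refuses to copy weights whose size does not match
            raise ValueError(f"layer {name!r}: expected {len(current)} values, got {len(saved)}")
        model_layers[name] = list(saved)
        loaded.append(name)
    return loaded


# Saved from the original model: Dense(2) on 3 inputs (8 values), Dense(3) on 2 inputs (9 values).
saved = {"dense_1": [0.1] * 8, "dense_2": [0.2] * 9}
# New model: same dense_1, plus a new Dense(10) on 2 inputs (30 values).
new_model = {"dense_1": [0.0] * 8, "new_dense": [0.0] * 30}

print(load_weights_by_name(new_model, saved))  # ['dense_1']
print(new_model["new_dense"][:3])              # [0.0, 0.0, 0.0]
```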
How to Check-Point Deep Learning Models in Keras
Checkpoint Neural Network Model Improvements
    # Checkpoint the weights when validation accuracy improves
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.callbacks import ModelCheckpoint
    import numpy
    # fix random seed for reproducibility
    seed = 7
    numpy.random.seed(seed)
    # load pima indians dataset
    dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
    # split into input (X) and output (Y) variables
    X = dataset[:,0:8]
    Y = dataset[:,8]
    # create model
    model = Sequential()
    model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
    # Compile model; metrics=['acc'] keeps the logged key 'val_acc' on Keras 2.3+,
    # where metrics=['accuracy'] would be logged as 'val_accuracy'
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
    # checkpoint: a new file each time val_acc improves
    filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
    callbacks_list = [checkpoint]
    # Fit the model
    model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)

Running the example produces the following output (truncated for brevity):
    ...
    Epoch 00134: val_acc did not improve
    Epoch 00135: val_acc did not improve
    Epoch 00136: val_acc did not improve
    Epoch 00137: val_acc did not improve
    Epoch 00138: val_acc did not improve
    Epoch 00139: val_acc did not improve
    Epoch 00140: val_acc improved from 0.83465 to 0.83858, saving model to weights-improvement-140-0.84.hdf5
    Epoch 00141: val_acc did not improve
    Epoch 00142: val_acc did not improve
    Epoch 00143: val_acc did not improve
    Epoch 00144: val_acc did not improve
    Epoch 00145: val_acc did not improve
    Epoch 00146: val_acc improved from 0.83858 to 0.84252, saving model to weights-improvement-146-0.84.hdf5
    Epoch 00147: val_acc did not improve
    Epoch 00148: val_acc improved from 0.84252 to 0.84252, saving model to weights-improvement-148-0.84.hdf5
    Epoch 00149: val_acc did not improve

You will see a number of files in your working directory containing the network weights in HDF5 format. For example:
    ...
    weights-improvement-53-0.76.hdf5
    weights-improvement-71-0.76.hdf5
    weights-improvement-77-0.78.hdf5
    weights-improvement-99-0.78.hdf5

Checkpoint Best Neural Network Model Only
    # Checkpoint the weights for best model on validation accuracy
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.callbacks import ModelCheckpoint
    import numpy
    # fix random seed for reproducibility
    seed = 7
    numpy.random.seed(seed)
    # load pima indians dataset
    dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
    # split into input (X) and output (Y) variables
    X = dataset[:,0:8]
    Y = dataset[:,8]
    # create model
    model = Sequential()
    model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
    # Compile model; metrics=['acc'] keeps the logged key 'val_acc' on Keras 2.3+,
    # where metrics=['accuracy'] would be logged as 'val_accuracy'
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
    # checkpoint: one fixed file, overwritten each time val_acc improves
    filepath="weights.best.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
    callbacks_list = [checkpoint]
    # Fit the model
    model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)

Running this example produces the following output (truncated for brevity):
    ...
    Epoch 00139: val_acc improved from 0.79134 to 0.79134, saving model to weights.best.hdf5
    Epoch 00140: val_acc did not improve
    Epoch 00141: val_acc did not improve
    Epoch 00142: val_acc did not improve
    Epoch 00143: val_acc did not improve
    Epoch 00144: val_acc improved from 0.79134 to 0.79528, saving model to weights.best.hdf5
    Epoch 00145: val_acc improved from 0.79528 to 0.79528, saving model to weights.best.hdf5
    Epoch 00146: val_acc did not improve
    Epoch 00147: val_acc did not improve
    Epoch 00148: val_acc did not improve
    Epoch 00149: val_acc did not improve

You should see the weight file in your local directory:
    weights.best.hdf5

Loading a Check-Pointed Neural Network Model
    # How to load and use weights from a checkpoint
    from keras.models import Sequential
    from keras.layers import Dense
    import numpy
    # fix random seed for reproducibility
    seed = 7
    numpy.random.seed(seed)
    # create model (must match the architecture the checkpoint was saved from)
    model = Sequential()
    model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
    # load weights
    model.load_weights("weights.best.hdf5")
    # Compile model (required before calling evaluate); 'acc' matches the metric name used in training
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
    print("Created model and loaded weights from file")
    # load pima indians dataset
    dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
    # split into input (X) and output (Y) variables
    X = dataset[:,0:8]
    Y = dataset[:,8]
    # estimate accuracy on whole dataset using loaded weights
    scores = model.evaluate(X, Y, verbose=0)
    print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

Running the example produces the following output:
    Created model and loaded weights from file
    acc: 77.73%

References
How to Check-Point Deep Learning Models in Keras
http://blog.csdn.net/u010159842/article/details/54602217
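Building a Reading-Comprehension Bot with Keras
Keras Chinese Documentation
How to Save a Keras Model
Artificial Neural Networks (3): Saving and Using Keras Models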