Keras-数据准备
簡介
本文主要以 Caltech101 圖片數據集為例,講解 Keras 中的數據處理模塊(數據讀入、預處理、增強等)。
數據集構建
本文使用比較經典的 Caltech101 數據集,共含有 101 個類別,如下圖,其中BACKGROUND_Google子文件夾為雜項,無法分類,使用該數據集時刪除該文件夾即可。
這里不妨將數據集重構為常見數據集格式,這樣便于后面說明 Keras 的數據加載 API。具體重構數據集的代碼可在文末 Github 找到,這里不做贅述,最后生成數據集如下,分為訓練集、驗證集和測試集(比例 8:1:1),每個文件夾下有 101 個子文件夾代表 101 個類別的圖片。
數據劃分完成后就要制作相關的數據集說明文件,在很多大型的數據集中經??吹竭@種文件且一般是csv 格式的文件,該文件一般存放所有圖片的路徑及其標簽(包含的就是所有數據的說明)。生成了三個說明文件如下,圖中示例的是訓練集的說明文件。這部分的具體代碼也可以在文末 Github 找到。
Keras數據讀取API
上一節,構建了比較標準的數據集及數據集說明文件,Keras對于標準格式存儲的數據集封裝了非常合適的數據加載相關的API,這部分API都在Keras模塊下的preprocessing模塊中,主要封裝三種格式的數據,分別為圖像、序列、文本(對應模塊名為image,sequence,text),本系列文章均以圖像數據為主,其他類型數據加載可以查看TensorFlow官方文檔相關部分。
tf.keras.preprocessing.image下封裝了一些方法如img_to_array、array_to_img、load_img、save_img等,但是這些都是瑣碎的對具體圖片的處理,對整個數據集進行處理的關鍵是tf.keras.preprocessing.image.ImageDataGenerator這個類,我們通過該類實例化一個數據集生成器對象,該對象不包含具體數據集的數據,只含有對數據的處理手段。
具體參數如下,包含大部分常用的數據增強的方法如ZCA白化、圖像標準化、隨機旋轉、隨機平移、翻轉等,具體參數含義可以查看我關于Keras數據增強的文章,這里不多贅述。
tf.keras.preprocessing.image.ImageDataGenerator(featurewise_center=False, samplewise_center=False,featurewise_std_normalization=False, samplewise_std_normalization=False,zca_whitening=False, zca_epsilon=1e-06, rotation_range=0, width_shift_range=0.0,height_shift_range=0.0, brightness_range=None, shear_range=0.0, zoom_range=0.0,channel_shift_range=0.0, fill_mode='nearest', cval=0.0, horizontal_flip=False,vertical_flip=False, rescale=None, preprocessing_function=None,data_format=None, validation_split=0.0, dtype=None )這里構造三個生成器,對應訓練集、驗證集、測試集,由于訓練集用于訓練可以進行數據增強(簡單進行了翻轉、旋轉等預處理方法),驗證集和測試集為了驗證模型效果,不能進行數據增強。
train_gen = keras.preprocessing.image.ImageDataGenerator(rescale=1/255.,horizontal_flip=True,shear_range=0.2,width_shift_range=0.1 ) valid_gen = keras.preprocessing.image.ImageDataGenerator(rescale=1/255. ) test_gen = keras.preprocessing.image.ImageDataGenerator(rescale=1/255. )獲得設定了數據預處理方法的數據集生成器,那么具體的數據怎么讀取呢?事實上,ImageDataGenerator對象封裝了三個flow開頭的方法,分別為flow、flow_from_directory以及flow_from_dataframe。flow表示從張量中批量產生數據,會迭代返回直到取完整個張量,使用不多;flow_from_directory和flow_from_dataframe是很常用的數據加載方法,他們依據數據集文件夾或者數據集說明文件讀取Dataframe到本地進行數據讀取,每次讀取一個批次,占用內存和顯存較小,符合實際訓練需求。
flow(x,y=None,batch_size=32,shuffle=True,sample_weight=None,seed=None,save_to_dir=None,save_prefix='',save_format='png',subset=None ) flow_from_directory(directory,target_size=(256, 256),color_mode='rgb',classes=None,class_mode='categorical',batch_size=32,shuffle=True,seed=None,save_to_dir=None,save_prefix='',save_format='png',follow_links=False,subset=None,interpolation='nearest' ) flow_from_dataframe(dataframe,directory=None,x_col="filename",y_col="class",weight_col=None,target_size=(256, 256),color_mode='rgb',classes=None,class_mode='categorical',batch_size=32,shuffle=True,seed=None,save_to_dir=None,save_prefix='',save_format='png',subset=None,interpolation='nearest',validate_filenames=True,**kwargs )上述三個數據生成的方法具體參數在我的數據增強博文中解釋了常用的一些,其他的可以參考官方文檔。
例如,使用flow_from_directory讀取上一節生成數據集的訓練集,具體代碼和結果如下(第一行輸出是因為generator獲得具體數據后會進行一個默認信息的輸出,共6907張圖片,按照給定的32的批尺寸,需要迭代215步)。
train_generator = train_gen.flow_from_directory(directory="../data/Caltech101/train/",target_size=(224, 224),batch_size=32,class_mode='categorical' )print("class number", train_generator.classes) print("images number", train_generator.n) print("steps", train_generator.n // train_generator.batch_size) Found 6907 images belonging to 101 classes. class number [ 0 0 0 ... 100 100 100] images number 6907 steps 215再例如,使用flow_from_dataframe按照數據集說明文件讀取數據(DataFrame使用Pandas預先讀取),該方法實際上是上一種方法的變種,當數據集沒有按照文件夾劃分訓練和測試,而是由說明文件劃分時,該方法非常實用。示例代碼和運行結果如下(這里directory參數為空是因為說明文件給出的就是對于當前目錄的數據集目錄,而該方法是按照當前目錄+directory參數目錄+dataframe指定目錄進行索引,故此處為空即可)。
df_train = pd.read_csv('../data/desc_train.csv', encoding='utf8') df_train['class'] = df_train['class'].astype(str)train_generator = train_gen.flow_from_dataframe(dataframe=df_train,directory="",x_col='file_name',y_col='class',target_size=(224, 224),batch_size=32,class_mode='categorical' )print("class number", train_generator.classes) print("images number", train_generator.n) print("steps", train_generator.n // train_generator.batch_size) Found 6907 images belonging to 101 classes. class number [ 0 0 0 ... 100 100 100] images number 6907 steps 215數據使用
現在通過構造完整的數據生成器,有了獲得具體數據的途徑,事實上這個生成器就是一個數據迭代器而已,可以類似Pytorch動態圖那樣通過循環訪問每一批次的數據,代碼如下;也可以通過Keras對Generator封裝的訓練方法fit_generator一鍵實現訓練,這點我們后面的文章提到。
for step, (x, y) in enumerate(train_generator):print(x.shape)print(y.shape)補充說明
本文主要介紹使用Keras對圖像數據的加載、增廣、使用等,具體代碼可以查看我的Github,歡迎Star或者Fork。
總結
以上是生活随笔為你收集整理的Keras-数据准备的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Keras-简介
- 下一篇: Fastai-数据准备