tensorflow-TFRecord 文件详解
TFRecord 是 tensorflow 內(nèi)置的文件格式,它是一種二進(jìn)制文件,具有以下優(yōu)點(diǎn):
1. 統(tǒng)一各種輸入文件的操作
2. 更好的利用內(nèi)存,方便復(fù)制和移動(dòng)
3. 將二進(jìn)制數(shù)據(jù)和標(biāo)簽(label)存儲(chǔ)在同一個(gè)文件中
引言
我們先不講 TFRecord,因?yàn)橹v了你也不懂,認(rèn)識(shí)幾個(gè)操作吧
tf.train.Int64List(value=list_data)
它的作用是 把 list 中每個(gè)元素轉(zhuǎn)換成 key-value 形式,
注意,輸入必須是 list,且 list 中元素類型要相同,且與 Int 保持一致;
# value = tf.constant([1, 2]) ### 這會(huì)報(bào)錯(cuò)的 ss = 1 ### Int64List 對(duì)應(yīng)的元素只能是 int long,其他同理 tt = 2 out1 = tf.train.Int64List(value = [ss, tt]) print(out1) # value: 1 # value: 2 ss = [1 ,2] out2 = tf.train.Int64List(value = ss) print(out2) # value: 1 # value: 2
同類型的 方法還有 2 個(gè)
tf.train.FloatList tf.train.BytesList
tf.train.Feature(int64_list=)
它的作用是 構(gòu)建 一種類型的特征集,比如 整型
out = tf.train.Feature(int64_list=tf.train.Int64List(value=[33, 22]))
print(out)
# int64_list {
# value: 33
# value: 22
# }
也可以是其他類型
tf.train.Feature(float_list=tf.train.FloatList()) tf.train.Feature(bytes_list=tf.train.BytesList())
tf.train.Features(feature=dict_data)
它的作用是 構(gòu)建 多種類型 的特征集,可以 dict 格式表達(dá) 多種類型
ut = tf.train.Features(feature={
"suibian": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 4])),
"a": tf.train.Feature(float_list=tf.train.FloatList(value=[5., 7.]))
})
print(out)
# feature {
# key: "a"
# value {
# float_list {
# value: 5.0
# value: 7.0
# }
# }
# }
# feature {
# key: "suibian"
# value {
# int64_list {
# value: 1
# value: 2
# value: 4
# }
# }
# }
tf.train.Example(features=tf.train.Features())
它的作用是創(chuàng)建一個(gè) 樣本,Example 對(duì)應(yīng)一個(gè)樣本
example = tf.train.Example(features=
tf.train.Features(feature={
'a': tf.train.Feature(int64_list=tf.train.Int64List(value=range(2))),
'b': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'm',b'n']))
}))
print(example)
# features {
# feature {
# key: "a"
# value {
# int64_list {
# value: 0
# value: 1
# }
# }
# }
# feature {
# key: "b"
# value {
# bytes_list {
# value: "m"
# value: "n"
# }
# }
# }
# }
一幅圖總結(jié)一下上面的代碼
Example 協(xié)議塊
它其實(shí)是一種 數(shù)據(jù)存儲(chǔ)的 格式,類似于 xml、json 等;
用上述方法實(shí)現(xiàn)該格式;
一個(gè) Example 協(xié)議塊對(duì)應(yīng)一個(gè)樣本,一個(gè)樣本有多種特征,每種特征下有多個(gè)元素,可參看上圖;
message Example{
Features features = 1;
}
message Features{
map<string,Features> feature = 1;
}
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloateList float_list = 2;
Int64List int64_list = 3;
}
}
TFRecord 文件就是以 Example協(xié)議塊 格式 存儲(chǔ)的;
TFRecord 文件
該類文件具有寫功能,且可以把其他類型的文件轉(zhuǎn)換成該類型文件,其實(shí)相當(dāng)于先讀取其他文件,再寫入 TFRecord 文件;
該類文件也具有讀功能;
TFRecord 存儲(chǔ)
存儲(chǔ)分兩步:
1.建立存儲(chǔ)器
2. 構(gòu)造每個(gè)樣本的 Example 協(xié)議塊
tf.python_io.TFRecordWriter(file_name)
構(gòu)造存儲(chǔ)器,存儲(chǔ)器有兩個(gè)常用方法
write(record):向文件中寫入一個(gè)樣本
close():關(guān)閉存儲(chǔ)器
注意:此處的 record 為一個(gè)序列化的 Example,通過(guò)Example.SerializeToString()來(lái)實(shí)現(xiàn),它的作用是將 Example 中的 map 壓縮為二進(jìn)制,節(jié)約大量空間
示例代碼1:將 MNIST 數(shù)據(jù)集保存成 TFRecord 文件
import tensorflow as tf
import numpy as np
import input_data
# 生成整數(shù)型的屬性
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
# 生成字符串類型的屬性,也就是圖像的內(nèi)容
def _string_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
# 讀取圖像數(shù)據(jù) 和一些屬性
mniset = input_data.read_data_sets('../../../data/MNIST_data',dtype=tf.uint8, one_hot=True)
images = mniset.train.images
labels = mniset.train.labels
pixels = images.shape[1] # (55000, 784)
num_examples = mniset.train.num_examples # 55000
file_name = 'output.tfrecords' ### 文件名
writer = tf.python_io.TFRecordWriter(file_name) ### 寫入器
for index in range(num_examples):
### 遍歷樣本
image_raw = images[index].tostring() ### 圖片轉(zhuǎn)成 字符型
example = tf.train.Example(features = tf.train.Features(feature = {
'pixel': _int64_feature(pixels),
'label': _int64_feature(np.argmax(labels[index])),
'image_raw': _string_feature(image_raw)
}))
writer.write(example.SerializeToString()) ### 寫入 TFRecord
writer.close()
示例代碼2:將 csv 保存成 TFRecord 文件
train_frame = pd.read_csv("../myfiles/xx3.csv")
train_labels_frame = train_frame.pop(item="label")
train_values = train_frame.values
train_labels = train_labels_frame.values
print("values shape: ", train_values.shape) # values shape: (2, 3)
print("labels shape:", train_labels.shape) # labels shape: (2,)
writer = tf.python_io.TFRecordWriter("xx3.tfrecords")
for i in range(train_values.shape[0]):
image_raw = train_values[i].tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
"image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels[i]]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()
示例3:將 png 文件保存成 TFRecord 文件
# filenames = tf.train.match_filenames_once('../myfiles/*.png')
filenames = glob.iglob('..myfiles*.png')
writer = tf.python_io.TFRecordWriter('png.tfrecords')
for filename in filenames:
img = Image.open(filename)
img_raw = img.tobytes()
label = 1
example = tf.train.Example(
features=tf.train.Features(
feature={
"image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()
TFRecord 讀取
讀取文件 和 tensorflow 讀取數(shù)據(jù)方法類似,參考我的博客讀取數(shù)據(jù)
tf.TFRecordReader()
建立讀取器,有 read 和 close 方法
tf.parse_single_example(serialized,features=None,name= None)
解析單個(gè) Example 協(xié)議塊
serialized : 標(biāo)量字符串的Tensor,一個(gè)序列化的Example,文件經(jīng)過(guò)文件閱讀器之后的value
features :字典數(shù)據(jù),key為讀取的名字,value為FixedLenFeature
return : 一個(gè)鍵值對(duì)組成的字典,鍵為讀取的名字
features中的value還可以為tf.VarLenFeature(),但是這種方式用的比較少,它返回的是SparseTensor數(shù)據(jù),這是一種只存儲(chǔ)非零部分的數(shù)據(jù)格式,了解即可。
tf.FixedLenFeature(shape,dtype)
shape : 輸入數(shù)據(jù)的形狀,一般不指定,為空列表
dtype : 輸入數(shù)據(jù)類型,與存儲(chǔ)進(jìn)文件的類型要一致,類型只能是float32,int 64, string
return :返回一個(gè)定長(zhǎng)的 Tensor (即使有零的部分也存儲(chǔ))
示例代碼
filename = 'png.tfrecords'
file_queue = tf.train.string_input_producer([filename], shuffle=True)
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
### features 的 key 必須和 寫入時(shí) 一致,數(shù)據(jù)類型也必須一致,shape 可為 空
dict_data= tf.parse_single_example(value, features={'label': tf.FixedLenFeature(shape=(1,1), dtype=tf.int64),
'image_raw': tf.FixedLenFeature(shape=(), dtype=tf.string)})
label = tf.cast(dict_data['label'], tf.int32)
img = tf.decode_raw(dict_data['image_raw'], tf.uint8) ### 將 string、bytes 轉(zhuǎn)換成 int、float
image_tensor = tf.reshape(img, [500, 500, -1])
sess = tf.Session()
sess.run(tf.local_variables_initializer())
tf.train.start_queue_runners(sess=sess)
while 1:
# print(sess.run(key)) # b'png.tfrecords:0'
image = sess.run(image_tensor)
img_PIL = Image.fromarray(image)
img_PIL.show()
參考資料:
https://blog.csdn.net/chengshuhao1991/article/details/78656724 TensorFlow基礎(chǔ)5:TFRecords文件的存儲(chǔ)與讀取講解及代碼實(shí)現(xiàn)
總結(jié)
以上是生活随笔為你收集整理的tensorflow-TFRecord 文件详解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 深度学习之卷积神经网络(11)卷积层变种
- 下一篇: 深度学习之卷积神经网络(12)深度残差网