TensorFlow(4)-TFRecord
TFRecord
- 1. tf.train.Example
- 1.1 tfrecord 數據范式轉化
- 1.2 demo 數據集構建
- 2. TFRecord 讀寫
- 2.1 寫入1-tf.io.TFRecordWriter()
- 2.3 讀取-tf.data.TFRecordDataset()
- 2.3 data -> dataset -> 存儲-tf.data.experimental.TFRecordWriter()
tfrecord 用于存儲二進制序列數據的一種范式,按順序存,按順序取。里面存的每一條數據都是一個 byte-string, 最常用的轉byte-string的方式是tf.train.Example 。tf.train.Example (or protobuf) 以字典{“string”: value}的形式存儲消息,這種消息存儲機制可讀性高。
demo1–tfrecord存儲
value can be a num / list / array pybyte_value = np.array(value).tobytes() # 0.轉Python字節數據 tfbyte_value = tf.train.BytesList(value=[pybyte_value]) # 1.轉tf.train 字節數據 feature_dict[key] = tf.train.Feature(bytes_list=tfbyte_value)# 2.轉tf.train.Feature()注意是tf.train.Feature()沒有s .......... feature_example = tf.train.Example(features=tf.train.Features(feature=tffeature_dict))# 3.轉tf.train.Example() 注意tf.train.Features()s exmp_serial = feature_example.SerializeToString() # 序列化feature_example tf_writer = tf.python_io.TFRecordWriter(tfrecord_path) # 構建tf寫句柄 tf_writer.write(exmp_serial) # 寫入tf文件 tf_writer.close() # 關閉句柄np.array().tobytes()構造包含數組中原始數據的Python字節數據
1. tf.train.Example
須將用戶數據轉化為tfrecord 約定的格式,才能使用tfrecord 格式存儲數據。
1.1 tfrecord 數據范式轉化
1-> tfrecord支持寫入三種格式的數據:string,int64,float32,分別通過tf.train.BytesList、tf.train.Int64List、tf.train.FloatList寫入tf.train.Feature中?!揪褪钦f數據要寫入tf.train.Feature前必須使用tf.train.BytesList,tf.train.Int64List,tf.train.FloatList必須使用強制類型轉換】
# python 數據類型轉tf.train.BytesList、tf.train.Int64List、tf.train.FloatList # tf.train.BytesList:string、byte # tf.train.FloatList:float (float32)、double (float64) # tf.train.Int64List :bool、enum、int32、uint32、int64、uint64 # 強制類型轉換 value = 1 value_ed = tf.train.Int64List(value=[value])2-> tf.train.Feature 接受tf.train.BytesList、tf.train.Int64List、tf.train.FloatList 類型的數據。以下為scalar 轉 tf.train.Feature 的快捷函數。 not scalar 的數據只需要用np.array().tobytes()/tf.io.serialize_tensor 轉換成binary-strings,然后使用以下借口函數封裝成 tf.train.Feature 即可。
# input : a scalar input # output: tf.train.Feature def _bytes_feature(value): """Returns a bytes_list from a string / byte."""if isinstance(value, type(tf.constant(0))):value = value.numpy() # BytesList won't unpack a string from an EagerTensor.return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def _float_feature(value):"""Returns a float_list from a float / double."""return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))def _int64_feature(value):"""Returns an int64_list from a bool / enum / int / uint."""return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))3->tf.train.Feature 構成特征字典 -> 特征字典 轉 Features message -> Features message 轉 tf.train.Example -> tf.train.Example 序列化后可以存入tfrecord 文件?!?Note that the tf.train.Example message is just a wrapper around the Features message:】
1.2 demo 數據集構建
構建一個包含10000個觀測數據的數據集,每條數據包含4個特征:[bool, label_index, lable_string, random_score]
n_observations = int(1e4) # The number of observations in the dataset. feature0 = np.random.choice([False, True], n_observations) #Boolean feature, encoded as False or True. feature1 = np.random.randint(0, 5, n_observations) # Integer feature, random from 0 to 4. strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat']) # String feature. feature2 = strings[feature1] feature3 = np.random.randn(n_observations) # Float feature, from a standard normal distribution.單個樣本轉tf.train.Feature-> tf.train.Features -> tf.train.Example()->SerializeToString() 接口函數
def serialize_example(feature0, feature1, feature2, feature3):# Create a Feature dict : {key: tf.train.Feature}feature = {'feature0': _int64_feature(feature0),'feature1': _int64_feature(feature1),'feature2': _bytes_feature(feature2),'feature3': _float_feature(feature3),}# Create a Features message and conver to tf.train.Example.example_proto = tf.train.Example(features=tf.train.Features(feature=feature))return example_proto.SerializeToString()觀測序列化[serialized_example ]和反序列化[tf.train.Example()]的結果
for i in range(n_observations):f0, f1, f2, f3 = feature0[i], feature1[i], feature2[i], feature3[i]# 序列化 tf.train.Example 消息serialized_example = serialize_example(f0, f1, f2, f3) # b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\.....# 反序列化 tf.train.Exampleexample_proto = tf.train.Example.FromString(serialized_example) '''features {feature {key: "feature0"value {int64_list {value: 0}}}feature {key: "feature1"value {int64_list {value: 4}}}feature {key: "feature2"value {bytes_list {value: "goat"}}}feature {key: "feature3"value {float_list {value: 0.9876000285148621}}} }'''2. TFRecord 讀寫
tfrecord 中每一條record按照下面的范式存儲。tfrecord 文件中并非只能存tf.train.Example 序列化的結果,tf.train.Example 只是將字典序列化的一種方法。任何 byte-string都能夠存入TFRecord file。
uint64 length uint32 masked_crc32_of_length byte data[length] uint32 masked_crc32_of_data2.1 寫入1-tf.io.TFRecordWriter()
# Write the `tf.train.Example` observations to the file. with tf.io.TFRecordWriter(filename) as writer: # 獲取寫入句柄for i in range(n_observations):example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])writer.write(example)2.3 讀取-tf.data.TFRecordDataset()
# 讀取tfrecord文件, 獲取序列化的樣本 filenames = [filename] raw_dataset = tf.data.TFRecordDataset(filenames) # tf.data.Dataset 對象 for raw_record in raw_dataset.take(10): # 讀取前10 條print(repr(raw_record)) # raw_record序列化的樣本# 序列化樣本反序列化 # tf.data.Dataset 在圖中執行,feature_description能夠建立數據集shape和type的signature。 feature_description = {'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0), } def _parse_function(example_proto):# Parse the input `tf.train.Example` proto using the dictionary above.# 一次只解析一條數據: use tf.parse example 可以一次解析一個batch的數據return tf.io.parse_single_example(example_proto, feature_description)# 利用tf.data.Dataset.map 函數將_parse_function 應用于數據集raw_dataset中的每一個元素parsed_dataset = raw_dataset.map(_parse_function) # 可以用的數據 # {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.5251196>} # 讀取 filenames = [filename] raw_dataset = tf.data.TFRecordDataset(filenames) # tf.train.Example.ParseFromString反序列化 得到的是tf.train.Example features, 很難直接使用 for raw_record in raw_dataset.take(1):example = tf.train.Example()example.ParseFromString(raw_record.numpy()) # tf.train.Example features 轉 dict of numpy array result = {} for key, feature in example.features.feature.items():# The values are the Feature objects which contain a `kind` which contains:# one of three fields: bytes_list, float_list, int64_listkind = feature.WhichOneof('kind')result[key] = np.array(getattr(feature, kind).value)2.3 data -> dataset -> 存儲-tf.data.experimental.TFRecordWriter()
from_tensor_slices 將data 轉成dataset-> 序列化dataset 中的每一個元素-> 存入tf record 文件
features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3)) for f0,f1,f2,f3 in features_dataset.take(1): # 逐個獲取數據print(f0, f1, f2, f3)# tf.Tensor(False, shape=(), dtype=bool),tf.Tensor(4, shape=(), dtype=int64),tf.Tensor(b'goat', shape=(), dtype=string),tf.Tensor(0.5251196235602504, shape=(), dtype=float64)# 序列化方式1:tf.data.Dataset.map 映射數據集中的每一個元素 # 對于自定義的序列化操作函數serialize_example。為了使其成為TensorFlow graph 的節點,須使用 tf.py_function封裝;之后再使用tf.data.Dataset.map 映射序列化數據集中的每一個元素。 def tf_serialize_example(f0,f1,f2,f3):# (自定義函數,函數輸入,函數輸出)tf_string = tf.py_function(serialize_example,(f0, f1, f2, f3),tf.string)return tf.reshape(tf_string, ()) # The result is a scalar. serialized_features_dataset = features_dataset.map(tf_serialize_example)# 序列化方式2:tf.data.Dataset.from_generator()映射數據集中的每一個元素 def generator():for features in features_dataset:yield serialize_example(*features) serialized_features_dataset = tf.data.Dataset.from_generator(generator, output_types=tf.string, output_shapes=())整個序列化的數據集寫入tfrecord.
# 整個寫入tfrecord filename = 'test.tfrecord' writer = tf.data.experimental.TFRecordWriter(filename) # 與1.0 的接口有些不太一樣 writer.write(serialized_features_dataset)參考資料:TFRecord and tf.train.Example
總結
以上是生活随笔為你收集整理的TensorFlow(4)-TFRecord的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: C++: 21---引用和指针
- 下一篇: C++(STL):14--- forwa