TFRecord フォーマットは、TensorFlow がサポートしているデータセットの表現形式の一つ。 このフォーマットは、一言で表すと TensorFlow で扱うデータを Protocol Buffers でシリアライズしたものになっている。 特に、Dataset API との親和性に優れていたり、Cloud TPU を扱う上で実用上はほぼ必須といった特徴がある。 今回は、そんな TFRecord の扱い方について見ていくことにする。
使った環境は次のとおり。
$ sw_vers ProductName: macOS ProductVersion: 11.5 BuildVersion: 20G71 $ python -V Python 3.9.6 $ pip list | grep -i tensorflow tensorflow 2.5.0 tensorflow-datasets 4.3.0 tensorflow-estimator 2.5.0 tensorflow-metadata 1.1.0
もくじ
下準備
あらかじめ TensorFlow をインストールしておく。
$ pip install tensorflow tensorflow_datasets
そして、Python のインタプリタを起動する。
$ python
tensorflow
パッケージを tf
という名前でインポートしておく。
>>> import tensorflow as tf
概要
TFRecord フォーマットを TensorFlow の Python API から扱おうとすると、いくつかのオブジェクト (クラス) が登場する。 ただ、意外とその数が多いので、理解する上でとっつきにくさを生んでいる感じがある。 そこで、まずは一通りトップダウンで説明することにする。
それぞれの関係は、あるオブジェクトが別のオブジェクトを内包するようになっている。 階層構造で表すと、以下のような感じ。 階層構造で上にあるオブジェクトが、下にあるオブジェクトを内包する。
tf.Example
tf.train.Features
tf.train.Feature
tf.train.BytesList
tf.train.FloatList
tf.train.Int64List
tf.Example
tf.Example
は、データセットに含まれる特定のサンプル (データポイント) に対応したオブジェクトになっている。
たとえば、教師あり学習のデータセットなら、あるサンプルの説明変数と目的変数のペアがこれに当たるイメージ。
ただ、サンプルに対応しているオブジェクトというだけなので、別に必要なら何を入れても構わない。
たとえば、画像データなら付随するメタデータとして横幅 (Width) と縦幅 (Height) のピクセル数が必要とかはあるはず。
このオブジェクトは単一の tf.train.Features
というオブジェクトを内包する。
tf.train.Features
tf.train.Features
は、名前から複数の特徴量を束ねるオブジェクトっぽいけど、まあ概ねその理解で正しいと思う。
概ね、というのは前述したとおりメタデータ的なものや説明変数も含まれるため。
このオブジェクトは複数の tf.train.Feature
を内包する。
tf.train.Feature
tf.train.Feature
は、特定の特徴量ないしメタデータや説明変数に対応したオブジェクト。
このオブジェクトは単一の tf.train.BytesList
または tf.train.FloatList
または tf.train.Int64List
を内包する。
tf.train.BytesList
tf.train.BytesList
は、特徴量としてバイト列のリストを扱うために用いるオブジェクト。
このオブジェクトは bytes
型のリストを内包する。
任意のバイト列を扱えるので、何らかのオブジェクトをシリアライズしたものを入れることができる。
詳しくは後述するけど、この特性は割と重要になってくる。
なぜなら、他の tf.train.FloatList
や tf.train.Int64List
は一次元配列しか扱えないため。
tf.train.FloatList
tf.train.FloatList
は、特徴量として浮動小数点のリストを扱うために用いるオブジェクト。
このオブジェクトは浮動小数点のリストを内包する。 前述したとおり、リストは一次元のものしか扱えない。
tf.train.Int64List
tf.train.Int64List
は、特徴量として整数のリストを扱うために用いるオブジェクト。
このオブジェクトは整数のリストを内包する。 前述したとおり、リストは一次元のものしか扱えない。
基本的な使い方
一通りのオブジェクトの説明が終わったので、ここからは実際にコードを実行しながら試してみよう。 先ほどの説明とは反対に、ボトムアップでの実行になる。 これは、そうでないとオブジェクトを組み立てられないため。
まず、最もプリミティブなオブジェクトである tf.train.Int64List
と tf.train.FloatList
と tf.train.BytesList
から。
これらは前述したとおりバイト列・浮動小数点・整数のリストを内包するオブジェクトになっている。
>>> int64_list = tf.train.Int64List(value=[1, 2, 3]) >>> int64_list value: 1 value: 2 value: 3 >>> float_list = tf.train.FloatList(value=[1., 2., 3.]) >>> float_list value: 1.0 value: 2.0 value: 3.0 >>> bytes_list = tf.train.BytesList(value=[b'x', b'y', b'z']) >>> bytes_list value: "x" value: "y" value: "z"
前述したとおり、value
には一次元配列しか渡せないらしい。
渡そうとすると次のようにエラーになる。
>>> import numpy as np >>> x = np.random.randint(low=0, high=100, size=(3, 2)) >>> tf.train.Int64List(value=x) Traceback (most recent call last): File "<stdin>", line 1, in <module> TypeError: only integer scalar arrays can be converted to a scalar index >>> tf.train.Int64List(value=[[1, 2], [3, 4]]) Traceback (most recent call last): File "<stdin>", line 1, in <module> TypeError: [1, 2] has type list, but expected one of: int, long
この仕様だと、画像データとか扱うときに面倒くさくない?と思うはず。 そんなときは、多次元配列を次のようにバイト列にシリアライズしてやれば良い。
>>> serialized_x = tf.io.serialize_tensor(x) >>> serialized_x <tf.Tensor: shape=(), dtype=string, numpy=b'\x08\t\x12\x08\x12\x02\x08\x03\x12\x02\x08\x02"0\x08\x00\x00\x00\x00\x00\x00\x00^\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x00\x00\x00\x00\x00\x006\x00\x00\x00\x00\x00\x00\x00H\x00\x00\x00\x00\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00'>
バイト列になっていれば tf.train.BytesList
に入れることができる。
>>> tf.train.BytesList(value=[serialized_x.numpy()]) value: "\010\t\022\010\022\002\010\003\022\002\010\002\"0\010\000\000\000\000\000\000\000^\000\000\000\000\000\000\000\017\000\000\000\000\000\000\0006\000\000\000\000\000\000\000H\000\000\000\000\000\000\000G\000\000\000\000\000\000\000"
なお、もちろん多次元配列は Flatten して、別で保存しておいた shape の情報を使って復元してもかまわない。
続いては tf.train.Feature
を使って先ほどの *List オブジェクトをラップする。
型ごとに引数が異なるため、そこだけ注意する。
>>> int64_feature = tf.train.Feature(int64_list=int64_list) >>> int64_feature int64_list { value: 1 value: 2 value: 3 } >>> float_feature = tf.train.Feature(float_list=float_list) >>> float_feature float_list { value: 1.0 value: 2.0 value: 3.0 } >>> bytes_feature = tf.train.Feature(bytes_list=bytes_list) >>> bytes_feature bytes_list { value: "x" value: "y" value: "z" }
続いては、tf.train.Features
を使って、複数の tf.train.Feature
を束ねる。
>>> feature_mappings = { ... 'feature0': int64_feature, ... 'feature1': float_feature, ... 'feature2': bytes_feature, ... } >>> features = tf.train.Features(feature=feature_mappings) >>> features feature { key: "feature0" value { int64_list { value: 1 value: 2 value: 3 } } } feature { key: "feature1" value { float_list { value: 1.0 value: 2.0 value: 3.0 } } } feature { key: "feature2" value { bytes_list { value: "x" value: "y" value: "z" } } }
あとは tf.train.Example
でラップするだけ。
>>> example = tf.train.Example(features=features) >>> example features { feature { key: "feature0" value { int64_list { value: 1 value: 2 value: 3 } } } feature { key: "feature1" value { float_list { value: 1.0 value: 2.0 value: 3.0 } } } feature { key: "feature2" value { bytes_list { value: "x" value: "y" value: "z" } } } }
上記で完成した tf.train.Example
オブジェクトがデータセットの中の特定のサンプルに対応することになる。
まあ、使っているのがダミーデータなのでちょっとイメージがつきにくいかもしれないけど。
tf.train.Example
オブジェクトは SerializeToString()
メソッドを使うことでバイト列にシリアライズできる。
つまり、.tfrecord
の拡張子がついた TFRecord ファイルは、このシリアライズされたバイト列が書き込まれている。
>>> serialized_data = example.SerializeToString() >>> serialized_data b'\nL\n\x17\n\x08feature2\x12\x0b\n\t\n\x01x\n\x01y\n\x01z\n\x1c\n\x08feature1\x12\x10\x12\x0e\n\x0c\x00\x00\x80?\x00\x00\x00@\x00\x00@@\n\x13\n\x08feature0\x12\x07\x1a\x05\n\x03\x01\x02\x03'
ちなみに、これまでに登場したオブジェクトも、それぞれ単独で SerializeToString()
を使えばシリアライズできる。
>>> int64_list.SerializeToString() b'\n\x03\x01\x02\x03' >>> int64_feature.SerializeToString() b'\x1a\x05\n\x03\x01\x02\x03' >>> features.SerializeToString() b'\n\x13\n\x08feature0\x12\x07\x1a\x05\n\x03\x01\x02\x03\n\x17\n\x08feature2\x12\x0b\n\t\n\x01x\n\x01y\n\x01z\n\x1c\n\x08feature1\x12\x10\x12\x0e\n\x0c\x00\x00\x80?\x00\x00\x00@\x00\x00@@'
そして、シリアライズしたバイト列は、tf.train.Example.FromString()
関数を使ってデシリアライズできる。
>>> deserialized_object = tf.train.Example.FromString(serialized_data) >>> deserialized_object features { feature { key: "feature0" value { int64_list { value: 1 value: 2 value: 3 } } } feature { key: "feature1" value { float_list { value: 1.0 value: 2.0 value: 3.0 } } } feature { key: "feature2" value { bytes_list { value: "x" value: "y" value: "z" } } } }
データセットを TFRecord ファイルに変換する
基本的な使い方がわかったところで、続いては実際にデータセットを TFRecord 形式のファイルに変換してみよう。
使う題材は特に何でも良いんだけど、今回は tensorflow-datasets 経由でロードした CIFAR10 を使うことにする。
>>> import tensorflow_datasets as tfds >>> ds_train = tfds.load('cifar10', split='train')
このデータセットには (32, 32, 3)
の形状を持った画像のテンソルと、それに対応したラベルが入っている。
画像のデータが一次元になっていないので、わざわざ Flatten する代わりに前述したシリアライズしてバイト列にする作戦でいこう。
>>> from pprint import pprint >>> pprint(ds_train.element_spec) {'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}
まず、特定のサンプルに対応したテンソルとラベルを前述した手順でシリアライズする関数を次のように定義する。
>>> def serialize_example(image, label): ... """1 サンプルを Protocol Buffers で TFRecord フォーマットにシリアライズする関数""" ... # 画像データをバイト列にシリアライズする ... serialized_image = tf.io.serialize_tensor(image) ... image_bytes_list = tf.train.BytesList(value=[serialized_image.numpy()]) ... # ラベルデータ ... label_int64_list = tf.train.Int64List(value=[label.numpy()]) ... # 特徴量を Features にまとめる ... feature_mappings = { ... 'image': tf.train.Feature(bytes_list=image_bytes_list), ... 'label': tf.train.Feature(int64_list=label_int64_list), ... } ... features = tf.train.Features(feature=feature_mappings) ... # Example にまとめる ... example_proto = tf.train.Example(features=features) ... # バイト列にする ... return example_proto.SerializeToString() ...
続いて、データセットから取り出したサンプルに上記の関数を定義するヘルパー関数を次のように定義する。
>>> def tf_serialize_example(element): ... """シリアライズ処理を tf.data.Dataset に適用するためのヘルパー関数""" ... # イメージとラベルを取り出す ... image = element['image'] ... label = element['label'] ... tf_string = tf.py_function( ... serialize_example, ... (image, label), ... tf.string, ... ) ... return tf.reshape(tf_string, ()) ...
Dataset API を使って、上記の関数をデータセットに適用する。
>>> serialized_ds_train = ds_train.map(tf_serialize_example)
イテレータにしてサンプルをひとつ取り出してみよう。
>>> ite = iter(serialized_ds_train) >>> next(ite) <tf.Tensor: shape=(), dtype=string, numpy=b'\n\xb6\x18\n\x0e\n\x05label\x12\x05\x1a\x03\n\x01\x07\n\xa3\x18\n\x05image\x12\x99 ...
ちゃんとシリアライズされたバイト列が確認できる。
あとは、シリアライズしたバイト列が取り出せる Dataset オブジェクトを引数にして tf.data.experimental.TFRecordWriter#write()
を呼び出すだけ。
>>> filename = 'cifar10-train.tfrecord'
>>> writer = tf.data.experimental.TFRecordWriter(filename)
>>> writer.write(serialized_ds_train)
上記はデータセットを丸ごと 1 つのファイルにしてる。 公式ドキュメントを見ると、パフォーマンスを考えると 100 ~ 200MB 程度のサイズで複数に分割するのがおすすめらしい。 これは、おそらく GCS とかにアップロードして並列で読み出すときの話。
カレントディレクトリを確認すると、次のようにファイルが書き出されているはず。
$ du -m cifar10-train.tfrecord
161 cifar10-train.tfrecord
$ file cifar10-train.tfrecord
cifar10-train.tfrecord: data
TFRecord ファイルからデータを読み出す
次は上記のファイルを読み込んでデシリアライズする。
まず、tf.data.TFRecordDataset
に TFRecord ファイルのパスを指定する。
これで、シリアライズしたバイト列を読み出せる Dataset オブジェクトが得られる。
>>> loaded_ds_train = tf.data.TFRecordDataset(filename)
上記からは tf.Example
に対応したバイト列が 1 つずつ読み出せる。
なので、それを元のテンソルに戻す関数を次のように定義する。
>>> def deserialize_example(example_proto): ... """バイト列をデシリアライズしてオブジェクトに戻す関数""" ... # バイト列のフォーマット ... feature_description = { ... 'image': tf.io.FixedLenFeature([], tf.string), ... 'label': tf.io.FixedLenFeature([], tf.int64), ... } ... # Tensor オブジェクトの入った辞書に戻す ... parsed_element = tf.io.parse_single_example(example_proto, ... feature_description) ... # 画像はバイト列になっているのでテンソルに戻す ... parsed_element['image'] = tf.io.parse_tensor(parsed_element['image'], ... out_type=tf.uint8) ... return parsed_element ...
上記を先ほどの Dataset オブジェクトに適用する。
>>> deserialized_ds_train = loaded_ds_train.map(deserialize_example)
試しに中身を取り出してみると、ちゃんと画像とラベルのテンソルが復元できていることがわかる。
>>> ite = iter(deserialized_ds_train) >>> next(ite) {'image': <tf.Tensor: shape=(32, 32, 3), dtype=uint8, numpy= array([[[143, 96, 70], [141, 96, 72], [135, 93, 72], ..., [ 96, 37, 19], [105, 42, 18], [104, 38, 20]], [[128, 98, 92], [146, 118, 112], [170, 145, 138], ..., [108, 45, 26], [112, 44, 24], [112, 41, 22]], [[ 93, 69, 75], [118, 96, 101], [179, 160, 162], ..., [128, 68, 47], [125, 61, 42], [122, 59, 39]], ..., [[187, 150, 123], [184, 148, 123], [179, 142, 121], ..., [198, 163, 132], [201, 166, 135], [207, 174, 143]], [[187, 150, 117], [181, 143, 115], [175, 136, 113], ..., [201, 164, 132], [205, 168, 135], [207, 171, 139]], [[195, 161, 126], [187, 153, 123], [186, 151, 128], ..., [212, 177, 147], [219, 185, 155], [221, 187, 157]]], dtype=uint8)>, 'label': <tf.Tensor: shape=(), dtype=int64, numpy=7>}
いじょう。