CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: TFRecord フォーマットについて

TFRecord フォーマットは、TensorFlow がサポートしているデータセットの表現形式の一つ。 このフォーマットは、一言で表すと TensorFlow で扱うデータを Protocol Buffers でシリアライズしたものになっている。 特に、Dataset API との親和性に優れていたり、Cloud TPU を扱う上で実用上はほぼ必須といった特徴がある。 今回は、そんな TFRecord の扱い方について見ていくことにする。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.5
BuildVersion:   20G71
$ python -V
Python 3.9.6
$ pip list | grep -i tensorflow
tensorflow               2.5.0
tensorflow-datasets      4.3.0
tensorflow-estimator     2.5.0
tensorflow-metadata      1.1.0

もくじ

下準備

あらかじめ TensorFlow をインストールしておく。

$ pip install tensorflow tensorflow_datasets

そして、Python のインタプリタを起動する。

$ python

tensorflow パッケージを tf という名前でインポートしておく。

>>> import tensorflow as tf

概要

TFRecord フォーマットを TensorFlow の Python API から扱おうとすると、いくつかのオブジェクト (クラス) が登場する。 ただ、意外とその数が多いので、理解する上でとっつきにくさを生んでいる感じがある。 そこで、まずは一通りトップダウンで説明することにする。

それぞれの関係は、あるオブジェクトが別のオブジェクトを内包するようになっている。 階層構造で表すと、以下のような感じ。 階層構造で上にあるオブジェクトが、下にあるオブジェクトを内包する。

  • tf.Example
    • tf.train.Features
      • tf.train.Feature
        • tf.train.BytesList
        • tf.train.FloatList
        • tf.train.Int64List

tf.Example

tf.Example は、データセットに含まれる特定のサンプル (データポイント) に対応したオブジェクトになっている。 たとえば、教師あり学習のデータセットなら、あるサンプルの説明変数と目的変数のペアがこれに当たるイメージ。 ただ、サンプルに対応しているオブジェクトというだけなので、別に必要なら何を入れても構わない。 たとえば、画像データなら付随するメタデータとして横幅 (Width) と縦幅 (Height) のピクセル数が必要とかはあるはず。

このオブジェクトは単一の tf.train.Features というオブジェクトを内包する。

tf.train.Features

tf.train.Features は、名前から複数の特徴量を束ねるオブジェクトっぽいけど、まあ概ねその理解で正しいと思う。 概ね、というのは前述したとおりメタデータ的なものや説明変数も含まれるため。

このオブジェクトは複数の tf.train.Feature を内包する。

tf.train.Feature

tf.train.Feature は、特定の特徴量ないしメタデータや説明変数に対応したオブジェクト。

このオブジェクトは単一の tf.train.BytesList または tf.train.FloatList または tf.train.Int64List を内包する。

tf.train.BytesList

tf.train.BytesList は、特徴量としてバイト列のリストを扱うために用いるオブジェクト。

このオブジェクトは bytes 型のリストを内包する。 任意のバイト列を扱えるので、何らかのオブジェクトをシリアライズしたものを入れることができる。 詳しくは後述するけど、この特性は割と重要になってくる。 なぜなら、他の tf.train.FloatListtf.train.Int64List は一次元配列しか扱えないため。

tf.train.FloatList

tf.train.FloatList は、特徴量として浮動小数点のリストを扱うために用いるオブジェクト。

このオブジェクトは浮動小数点のリストを内包する。 前述したとおり、リストは一次元のものしか扱えない。

tf.train.Int64List

tf.train.Int64List は、特徴量として整数のリストを扱うために用いるオブジェクト。

このオブジェクトは整数のリストを内包する。 前述したとおり、リストは一次元のものしか扱えない。

基本的な使い方

一通りのオブジェクトの説明が終わったので、ここからは実際にコードを実行しながら試してみよう。 先ほどの説明とは反対に、ボトムアップでの実行になる。 これは、そうでないとオブジェクトを組み立てられないため。

まず、最もプリミティブなオブジェクトである tf.train.Int64Listtf.train.FloatListtf.train.BytesList から。 これらは前述したとおりバイト列・浮動小数点・整数のリストを内包するオブジェクトになっている。

>>> int64_list = tf.train.Int64List(value=[1, 2, 3])
>>> int64_list
value: 1
value: 2
value: 3

>>> float_list = tf.train.FloatList(value=[1., 2., 3.])
>>> float_list
value: 1.0
value: 2.0
value: 3.0

>>> bytes_list = tf.train.BytesList(value=[b'x', b'y', b'z'])
>>> bytes_list
value: "x"
value: "y"
value: "z"

前述したとおり、value には一次元配列しか渡せないらしい。 渡そうとすると次のようにエラーになる。

>>> import numpy as np
>>> x = np.random.randint(low=0, high=100, size=(3, 2))
>>> tf.train.Int64List(value=x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: only integer scalar arrays can be converted to a scalar index
>>> tf.train.Int64List(value=[[1, 2], [3, 4]])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: [1, 2] has type list, but expected one of: int, long

この仕様だと、画像データとか扱うときに面倒くさくない?と思うはず。 そんなときは、多次元配列を次のようにバイト列にシリアライズしてやれば良い。

>>> serialized_x = tf.io.serialize_tensor(x)
>>> serialized_x
<tf.Tensor: shape=(), dtype=string, numpy=b'\x08\t\x12\x08\x12\x02\x08\x03\x12\x02\x08\x02"0\x08\x00\x00\x00\x00\x00\x00\x00^\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x00\x00\x00\x00\x00\x006\x00\x00\x00\x00\x00\x00\x00H\x00\x00\x00\x00\x00\x00\x00G\x00\x00\x00\x00\x00\x00\x00'>

バイト列になっていれば tf.train.BytesList に入れることができる。

>>> tf.train.BytesList(value=[serialized_x.numpy()])
value: "\010\t\022\010\022\002\010\003\022\002\010\002\"0\010\000\000\000\000\000\000\000^\000\000\000\000\000\000\000\017\000\000\000\000\000\000\0006\000\000\000\000\000\000\000H\000\000\000\000\000\000\000G\000\000\000\000\000\000\000"

なお、もちろん多次元配列は Flatten して、別で保存しておいた shape の情報を使って復元してもかまわない。

続いては tf.train.Feature を使って先ほどの *List オブジェクトをラップする。 型ごとに引数が異なるため、そこだけ注意する。

>>> int64_feature = tf.train.Feature(int64_list=int64_list)
>>> int64_feature
int64_list {
  value: 1
  value: 2
  value: 3
}

>>> float_feature = tf.train.Feature(float_list=float_list)
>>> float_feature
float_list {
  value: 1.0
  value: 2.0
  value: 3.0
}

>>> bytes_feature = tf.train.Feature(bytes_list=bytes_list)
>>> bytes_feature
bytes_list {
  value: "x"
  value: "y"
  value: "z"
}

続いては、tf.train.Features を使って、複数の tf.train.Feature を束ねる。

>>> feature_mappings = {
...     'feature0': int64_feature,
...     'feature1': float_feature,
...     'feature2': bytes_feature,
... }
>>> features = tf.train.Features(feature=feature_mappings)
>>> features
feature {
  key: "feature0"
  value {
    int64_list {
      value: 1
      value: 2
      value: 3
    }
  }
}
feature {
  key: "feature1"
  value {
    float_list {
      value: 1.0
      value: 2.0
      value: 3.0
    }
  }
}
feature {
  key: "feature2"
  value {
    bytes_list {
      value: "x"
      value: "y"
      value: "z"
    }
  }
}

あとは tf.train.Example でラップするだけ。

>>> example = tf.train.Example(features=features)
>>> example
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 1
        value: 2
        value: 3
      }
    }
  }
  feature {
    key: "feature1"
    value {
      float_list {
        value: 1.0
        value: 2.0
        value: 3.0
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "x"
        value: "y"
        value: "z"
      }
    }
  }
}

上記で完成した tf.train.Example オブジェクトがデータセットの中の特定のサンプルに対応することになる。 まあ、使っているのがダミーデータなのでちょっとイメージがつきにくいかもしれないけど。

tf.train.Example オブジェクトは SerializeToString() メソッドを使うことでバイト列にシリアライズできる。 つまり、.tfrecord の拡張子がついた TFRecord ファイルは、このシリアライズされたバイト列が書き込まれている。

>>> serialized_data = example.SerializeToString()
>>> serialized_data
b'\nL\n\x17\n\x08feature2\x12\x0b\n\t\n\x01x\n\x01y\n\x01z\n\x1c\n\x08feature1\x12\x10\x12\x0e\n\x0c\x00\x00\x80?\x00\x00\x00@\x00\x00@@\n\x13\n\x08feature0\x12\x07\x1a\x05\n\x03\x01\x02\x03'

ちなみに、これまでに登場したオブジェクトも、それぞれ単独で SerializeToString() を使えばシリアライズできる。

>>> int64_list.SerializeToString()
b'\n\x03\x01\x02\x03'
>>> int64_feature.SerializeToString()
b'\x1a\x05\n\x03\x01\x02\x03'
>>> features.SerializeToString()
b'\n\x13\n\x08feature0\x12\x07\x1a\x05\n\x03\x01\x02\x03\n\x17\n\x08feature2\x12\x0b\n\t\n\x01x\n\x01y\n\x01z\n\x1c\n\x08feature1\x12\x10\x12\x0e\n\x0c\x00\x00\x80?\x00\x00\x00@\x00\x00@@'

そして、シリアライズしたバイト列は、tf.train.Example.FromString() 関数を使ってデシリアライズできる。

>>> deserialized_object = tf.train.Example.FromString(serialized_data)
>>> deserialized_object
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 1
        value: 2
        value: 3
      }
    }
  }
  feature {
    key: "feature1"
    value {
      float_list {
        value: 1.0
        value: 2.0
        value: 3.0
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "x"
        value: "y"
        value: "z"
      }
    }
  }
}

データセットを TFRecord ファイルに変換する

基本的な使い方がわかったところで、続いては実際にデータセットを TFRecord 形式のファイルに変換してみよう。

使う題材は特に何でも良いんだけど、今回は tensorflow-datasets 経由でロードした CIFAR10 を使うことにする。

>>> import tensorflow_datasets as tfds
>>> ds_train = tfds.load('cifar10', split='train')

このデータセットには (32, 32, 3) の形状を持った画像のテンソルと、それに対応したラベルが入っている。 画像のデータが一次元になっていないので、わざわざ Flatten する代わりに前述したシリアライズしてバイト列にする作戦でいこう。

>>> from pprint import pprint
>>> pprint(ds_train.element_spec)
{'id': TensorSpec(shape=(), dtype=tf.string, name=None),
 'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None),
 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}

まず、特定のサンプルに対応したテンソルとラベルを前述した手順でシリアライズする関数を次のように定義する。

>>> def serialize_example(image, label):
...     """1 サンプルを Protocol Buffers で TFRecord フォーマットにシリアライズする関数"""
...     # 画像データをバイト列にシリアライズする
...     serialized_image = tf.io.serialize_tensor(image)
...     image_bytes_list = tf.train.BytesList(value=[serialized_image.numpy()])
...     # ラベルデータ
...     label_int64_list = tf.train.Int64List(value=[label.numpy()])
...     # 特徴量を Features にまとめる
...     feature_mappings = {
...         'image': tf.train.Feature(bytes_list=image_bytes_list),
...         'label': tf.train.Feature(int64_list=label_int64_list),
...     }
...     features = tf.train.Features(feature=feature_mappings)
...     # Example にまとめる
...     example_proto = tf.train.Example(features=features)
...     # バイト列にする
...     return example_proto.SerializeToString()
... 

続いて、データセットから取り出したサンプルに上記の関数を定義するヘルパー関数を次のように定義する。

>>> def tf_serialize_example(element):
...     """シリアライズ処理を tf.data.Dataset に適用するためのヘルパー関数"""
...     # イメージとラベルを取り出す
...     image = element['image']
...     label = element['label']
...     tf_string = tf.py_function(
...         serialize_example, 
...         (image, label),
...         tf.string,
...     )
...     return tf.reshape(tf_string, ())
... 

Dataset API を使って、上記の関数をデータセットに適用する。

>>> serialized_ds_train = ds_train.map(tf_serialize_example)

イテレータにしてサンプルをひとつ取り出してみよう。

>>> ite = iter(serialized_ds_train)
>>> next(ite)
<tf.Tensor: shape=(), dtype=string, numpy=b'\n\xb6\x18\n\x0e\n\x05label\x12\x05\x1a\x03\n\x01\x07\n\xa3\x18\n\x05image\x12\x99
...

ちゃんとシリアライズされたバイト列が確認できる。

あとは、シリアライズしたバイト列が取り出せる Dataset オブジェクトを引数にして tf.data.experimental.TFRecordWriter#write() を呼び出すだけ。

>>> filename = 'cifar10-train.tfrecord'
>>> writer = tf.data.experimental.TFRecordWriter(filename)
>>> writer.write(serialized_ds_train)

上記はデータセットを丸ごと 1 つのファイルにしてる。 公式ドキュメントを見ると、パフォーマンスを考えると 100 ~ 200MB 程度のサイズで複数に分割するのがおすすめらしい。 これは、おそらく GCS とかにアップロードして並列で読み出すときの話。

カレントディレクトリを確認すると、次のようにファイルが書き出されているはず。

$ du -m cifar10-train.tfrecord
161    cifar10-train.tfrecord
$ file cifar10-train.tfrecord 
cifar10-train.tfrecord: data

TFRecord ファイルからデータを読み出す

次は上記のファイルを読み込んでデシリアライズする。 まず、tf.data.TFRecordDataset に TFRecord ファイルのパスを指定する。 これで、シリアライズしたバイト列を読み出せる Dataset オブジェクトが得られる。

>>> loaded_ds_train = tf.data.TFRecordDataset(filename)

上記からは tf.Example に対応したバイト列が 1 つずつ読み出せる。 なので、それを元のテンソルに戻す関数を次のように定義する。

>>> def deserialize_example(example_proto):
...     """バイト列をデシリアライズしてオブジェクトに戻す関数"""
...     # バイト列のフォーマット
...     feature_description = {
...         'image': tf.io.FixedLenFeature([], tf.string),
...         'label': tf.io.FixedLenFeature([], tf.int64),
...     }
...     # Tensor オブジェクトの入った辞書に戻す
...     parsed_element = tf.io.parse_single_example(example_proto,
...                                                 feature_description)
...     # 画像はバイト列になっているのでテンソルに戻す
...     parsed_element['image'] = tf.io.parse_tensor(parsed_element['image'],
...                                                  out_type=tf.uint8)
...     return parsed_element
... 

上記を先ほどの Dataset オブジェクトに適用する。

>>> deserialized_ds_train = loaded_ds_train.map(deserialize_example)

試しに中身を取り出してみると、ちゃんと画像とラベルのテンソルが復元できていることがわかる。

>>> ite = iter(deserialized_ds_train)
>>> next(ite)
{'image': <tf.Tensor: shape=(32, 32, 3), dtype=uint8, numpy=
array([[[143,  96,  70],
        [141,  96,  72],
        [135,  93,  72],
        ...,
        [ 96,  37,  19],
        [105,  42,  18],
        [104,  38,  20]],

       [[128,  98,  92],
        [146, 118, 112],
        [170, 145, 138],
        ...,
        [108,  45,  26],
        [112,  44,  24],
        [112,  41,  22]],

       [[ 93,  69,  75],
        [118,  96, 101],
        [179, 160, 162],
        ...,
        [128,  68,  47],
        [125,  61,  42],
        [122,  59,  39]],

       ...,

       [[187, 150, 123],
        [184, 148, 123],
        [179, 142, 121],
        ...,
        [198, 163, 132],
        [201, 166, 135],
        [207, 174, 143]],

       [[187, 150, 117],
        [181, 143, 115],
        [175, 136, 113],
        ...,
        [201, 164, 132],
        [205, 168, 135],
        [207, 171, 139]],

       [[195, 161, 126],
        [187, 153, 123],
        [186, 151, 128],
        ...,
        [212, 177, 147],
        [219, 185, 155],
        [221, 187, 157]]], dtype=uint8)>, 'label': <tf.Tensor: shape=(), dtype=int64, numpy=7>}

いじょう。

参考

www.tensorflow.org