EeBlog(テクニカルブログ)

AI(人工知能)実践 第12回 TFRecordの読み込み

前回は複数の画像データを読み込み、TFRecordに詰め込む処理を紹介しました。

今回は、TFRecordを読み込み、画像データを作成する処理を紹介します。
 

TFRecordから複数の画像に戻していきますが、
今回tf.data.TFRecordDatasetを使用して順に処理しています。
古いソースでは、「tf.train.Coodinator」「tf.train.start_queue_runners」などを使用した
サンプルが多いですが、現状は非推奨です。
サンプルソースを読む上で詰まった場合には、下記の公式サイトの説明も参考にしてみてください。
旧:https://www.tensorflow.org/api_guides/python/reading_data#_QueueRunner_
新:https://www.tensorflow.org/api_guides/python/input_dataset#Reader_classes

なお、今回は画像に戻す際の名前は「TFRecord_【ラベル】_【連番】.jpg」としています。
 

ソースコード

import os
import tensorflow as tf
from PIL import Image

TFRECORD_FILE = 'c:\\work\\test\\example.tfrecord'
OUTPUT_DIR = 'c:\\work\\test\\read'

def read_tfrecord(serialized):
    features = tf.parse_single_example(
        serialized,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string)
        })
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    label = tf.cast(features['label'], tf.int32)
    image = tf.decode_raw(features['image'], tf.uint8)    
    return height, width, label, image

with tf.Session() as sess:
    dataset = tf.data.TFRecordDataset([TFRECORD_FILE])
    dataset = dataset.map(read_tfrecord)
    iterator = dataset.make_initializable_iterator()
    next_dataset = iterator.get_next()
    sess.run(iterator.initializer)
    count = 1
    while(True):
        try:
            height, width, label, image = sess.run(next_dataset)            
            img = Image.frombytes('RGB', [height, width], image)
            img.save(os.path.join(OUTPUT_DIR, "TFRecord_{0}_{1}.jpg".format(label, count)))
            count += 1
        except tf.errors.OutOfRangeError:
            print("data end")
            break
        
    sess.close()

 

解説

TFRecordの読み込み用のソースを順に説明していきます。

ライブラリインポート

import os
import tensorflow as tf
from PIL import Image

特に目新しいものはありません。
TFRecord読み込み用にTensorFlowと画像処理用にPillowとファイル操作用にosをインポートしています。
 

データパス等の定義

TFRECORD_FILE = 'c:\\work\\test\\example.tfrecord'
TFRECIRD_CHECK_DIR = 'c:\\work\\test\\read'
count = 1

TFRecord自体のファイルパスと
TFRecordで読み込んだデータを画像として出力するフォルダパスを定義しています。
 

TFRecord読み込み用メソッドの定義

def read_tfrecord(serialized):
    features = tf.parse_single_example(
        serialized,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string)
        })
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    label = tf.cast(features['label'], tf.int32)
    image = tf.decode_raw(features['image'], tf.uint8)    
    return height, width, label, image

ここでは、TFRecordに詰め込まれた複数のデータのうち1つを順に取り出すことを行っています。
引数として渡される「serialized」に入っている値は、
スカラー文字列Tensorですので、そのままでは使用できません。
ですので、parse_single_example(https://www.tensorflow.org/api_docs/python/tf/parse_single_example)を使用して
書き込み時に使用したキーを「features」で指定して、値を取り出します。
最後に、parse_single_exampleの戻り値として受け取った「features」を使用したい型に変換してから、
read_tfrecordメソッドの呼び出し元へと返します。
 

TFRecordDatasetの準備

with tf.Session() as sess:
    dataset = tf.data.TFRecordDataset([TFRECORD_FILE])
    dataset = dataset.map(read_tfrecord)
    iterator = dataset.make_initializable_iterator()
    next_dataset = iterator.get_next()
    sess.run(iterator.initializer)

まずは、お決まりの文言ですが、セッションブロックを定義します。
TensorBoardを使用する場合は細かくブロックを分けたり、
変数・定数をTensorFlowの型を使用して名前を付けつつ定義したりしますが
今回は特に使用しないので、適当につけてしまいます。

続いて、tf.data.TFRecordDataset(https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset)を使用して
TFRecordファイルの読み込みの準備を始めます。
TFRecordDataset.mapメソッドでデータセットに含まれるデータの変換を行う処理を追加しています。
ここで行うデータの変換は、先ほど定義したTFRecord読み込み用メソッドのread_tfrecordを使用します。
尚、TFRecordDataset.shuffleメソッドを使用すると、
TFRecordファイルに含まれるデータの取り出し順をシャッフルすることが出来ます。
その他にもいろいろ操作出来るメソッドが用意されていますので、
公式サイトを流し読みすると良いかもしれません。

次に、make_initializable_iteratorメソッドでデータセットのイテレータを取得します。
取得したイテレータのget_nextメソッドで順にデータを取り出していきます。
取り出しを行う前に、一度初期化を行う必要がありますので、
sess.run(iterator.initializer)を行っています。
 

TFRecordのデータを順に処理

    count = 1
    while(True):
        try:
            height, width, label, image = sess.run(next_dataset)            
            img = Image.frombytes('RGB', [height, width], image)
            img.save(os.path.join(OUTPUT_DIR, "TFRecord_{0}_{1}.jpg".format(label, count)))
            count += 1
        except tf.errors.OutOfRangeError:
            print("data end")
            break
        
    sess.close()

順に処理を行うので、ループを回しますが、
まずは画像の名前で使用する連番用のcount変数を定義します。
続いて、TFRecordの全てのデータを処理する為に無限ループを定義しています。
中のデータ数が分かってる場合や、
一定数のみ処理をする場合などはその数を指定してループで良いかと思います。

ループ内の処理では、TFRecordの最後のデータを読み込んだ後にget_nextメソッドを呼んだ場合に
「tf.errors.OutOfRangeError」がthrowされますので、try-exceptブロックを定義します。

tryブロックの中で、sess.run(next_dataset)を呼び出し、
TFRecordファイルの中のデータを1つず取り出していきます。
今回は、read_tfrecordの戻り値を受け取れますので、
「height」「width」「label」「image」を受け取ります。
受け取ったデータをPillowに含まれるImage.frombytesにて、
バイナリデータをイメージデータに変換します。
その後、変換したイメージデータを、saveメソッドにてjpg形式で保存します。
保存した後は連番を繰り上げるため、count変数を+1しています。

続いてexceptブロックでは、TFRecordが最後まで読み終わった際に回ってくる処理となりますので、
printにてデータ終了を表示した後、無限ループとなっているwhileのループをbreakします。

以上で、TFRecordファイルを読み込み→1データずつ順に格納した際に使用したキーで値を取得→
取得した値(高さ、幅、ラベル、画像)を所定のフォルダに保存の一連の処理となります。

勿論、今回のような使い方ではなく、
学習の際に1データずつ順に取り出して渡していくといった使い方が主であるので、
各所に落ちている古めのサンプルを
Datasetを使用する形に書き直して勉強してみるのも良いのではないでしょうか。