書籍転載:TensorFlowはじめました ― 実践!最新Googleマシンラーニング(4)

書籍転載:TensorFlowはじめました ― 実践!最新Googleマシンラーニング(4)

TensorFlowでデータの読み込み ― 画像を分類するCIFAR-10の基礎

2016年8月16日

転載4回目。今回から「畳み込みニューラルネットワーク」のモデルを構築して、CIFAR-10のデータセットを使った学習と評価を行う。今回はデータの読み込みを説明。

有山 圭二
  • このエントリーをはてなブックマークに追加

 書籍『TensorFlowはじめました ― 実践!最新Googleマシンラーニング』から全7本の記事を転載します。本稿はその4回目です。今回から「第2章 CIFAR-10の学習と評価」に入ります。今回はデータの読み込みについて説明します。

書籍転載について

 本コーナーは、インプレスR&D[Next Publishing]発行の書籍『TensorFlowはじめました ― 実践!最新Googleマシンラーニング』の中から、特にBuild Insiderの読者に有用だと考えられる項目を編集部が選び、同社の許可を得て転載したものです。

 『TensorFlowはじめました ― 実践!最新Googleマシンラーニング』(Kindle電子書籍もしくはオンデマンドペーパーバック)の詳細や購入はAmazon.co.jpのページをご覧ください。書籍全体の目次は連載INDEXページに掲載しています。プログラムのダウンロードは、「TensorFlowはじめました」のサポート用フォームから行えます。

ご注意

本記事は、書籍の内容を改変することなく、そのまま転載したものです。このため用字用語の統一ルールなどはBuild Insiderのそれとは一致しません。あらかじめご了承ください。

第2章 CIFAR-10の学習と評価

 CIFAR-10*1は、10種類の画像を分類する「多クラス分類」と呼ばれるタスクの画像セットです。

図2.1: CIFAR-10のクラス(名前の左側の数字はラベルの値)
  • *1 CIFAR: Canadian Institute For Advanced Research。

 本章では、「畳み込みニューラルネットワーク」のモデルを構築して、CIFAR-10のデータセットを使った学習と評価を行います。

2.1 データの読み込み

データの入手

 CIFAR-10のデータセットは、トロント大学のAlex Krizhevsky氏が配布しています。

 Pythonで直接読み込める形式(Pickle形式)が用意されてますが、今回はバイナリ形式の「CIFAR-10 binary version (suitable for C programs)」をダウンロードします。

 ダウンロードしたZIPファイルには、6つのファイルが含まれています。

 data_batch_[1-5].binが訓練データ、test_batch.binがテストデータです。以後、これらのファイルをディレクトリ./dataに配置したとして解説を進めます。

CIFAR-10のデータ構造

 CIFAR-10のデータは、JPEGやPNGなど一般的な画像形式ではありません。そのため、データ構造に合わせてプログラムの中で読み込み、さらにTensorFlowで取り扱うことができる形式へ変換する必要があります。

 1つのファイル(データセット)には、10,000個のレコードが含まれています(図2.2)。

図2.2: データセットの構造

 レコードは固定長の3,073バイト。先頭の1バイトがラベルで、残り3,072バイトは縦横32pxの画像データを直列化したものです。一般的なBitmap形式と違い、RGB各チャンネルのデータが1,024バイトずつ並ぶ構造になっています。

 画像ラベルと種類(クラス)の対応は図2.1の通りです。

読み込みと構造変更

 リスト2.1は、CIFAR-10形式のデータセットを読み込むプログラムです。

 まず、Cifar10Readerのコンストラクタに、読み込むデータセット(ファイル)の名前を指定します。次に、readメソッドに「レコード番号」を指定すると対応するレコード(Cifar10Record)が得られます。

Python
# coding: UTF-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np


class Cifar10Record(object):
  width = 32
  height = 32
  depth = 3
  
  def set_label(self, label_byte):
    self.label = np.frombuffer(label_byte, dtype=np.uint8)
  
  def set_image(self, image_bytes):
    byte_buffer = np.frombuffer(image_bytes, dtype=np.int8)
    reshaped_array = np.reshape(byte_buffer,
                                [self.depth, self.width, self.height])
    self.byte_array = np.transpose(reshaped_array, [1, 2, 0])
    self.byte_array = self.byte_array.astype(np.float32)

class Cifar10Reader(object):
  def __init__(self, filename):
    if not os.path.exists(filename):
      print(filename + ' is not exist')
      return
  
    self.bytestream = open(filename, mode="rb")
  
  def close(self):
    if not self.bytestream:
      self.bytestream.close()
  
  def read(self, index):
    result = Cifar10Record()
    
    label_bytes = 1
    image_bytes = result.height * result.width * result.depth
    record_bytes = label_bytes + image_bytes
    
    self.bytestream.seek(record_bytes * index, 0)
    
    result.set_label(self.bytestream.read(label_bytes))
    result.set_image(self.bytestream.read(image_bytes))
    
    return result
リスト2.1: reader.py

 readメソッド内で取得する画像データは、最初、直列化された1次元配列です。TensorFlowで取り扱うには、一般的なBitmap形式の「幅・高さ・チャンネル」の構造に変換する必要があります。

 そこで、リスト2.1ではまず、numpyのreshapeで一次元配列から「チャンネル・幅・高さ」の3次元配列に変換しています。次にnumpyのtranspose関数を用いて「幅・高さ・チャンネル」の形式に配列を転置して、最後にfloat32に型を変換しています。

PNG形式で書き出し

 CIFAR-10形式から画像データを取り出すことができたので、取得した画像データを一般的な画像形式で保存してみましょう(リスト2.2)。

 この行程はTensorFlowを学ぶ上では必要ありませんが、画像データの操作の練習として試してみてください。

Python
# coding: UTF-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf
from PIL import Image

from reader import Cifar10Reader

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('file', None, "処理するファイルのパス")
tf.app.flags.DEFINE_integer('offset', 0, "読み飛ばすレコード数")
tf.app.flags.DEFINE_integer('length', 16, "読み込んで変換するレコード数")

basename = os.path.basename(FLAGS.file)
path = os.path.dirname(FLAGS.file)

reader = Cifar10Reader(FLAGS.file)

stop = FLAGS.offset + FLAGS.length
for index in range(FLAGS.offset, stop):
  image = reader.read(index)
  
  print('label: %d' % image.label)
  imageshow = Image.fromarray(image.byte_array.astype(np.uint8))
  
  file_name = '%s-%02d-%d.png' % (basename, index, image.label)
  file = os.path.join(path, file_name)
  with open(file, mode='wb') as out:
    imageshow.save(out, format='png')

reader.close()
リスト2.2: convert_cifar10_png.py

 処理するファイルをコマンドラインの引数に指定して実行すると、図2.3のような画像が、データセットと同じディレクトリに書き出されます。

$ python3 convert_cifar10_png.py --file ./data/data_batch_1.bin
図2.3: 取り出したCIFAR-10画像(拡大)
図2.3: 取り出したCIFAR-10画像(拡大)

Note: tf.app.flags

リスト2.2冒頭にあるtf.app.flagsは、コマンドラインの引数を簡単に設定する機能を提供します。

 これはGoogleがオープンソースで公開している「gflags」のPython実装「python-gflags」と同等のものです。

 今回は「画像データの読み込み」を行いました。次回は、「推論」(=画像の種類・クラスを判別)について説明します。

※以下では、本稿の前後を合わせて5回分(第1回~第5回)のみ表示しています。
 連載の全タイトルを参照するには、[この記事の連載 INDEX]を参照してください。

1. TensorFlowとは? データフローグラフを構築・実行してみよう

技術書オンリー即売会「技術書典」で頒布された同名出版物をベースとして制作されたTensorFlowの入門書籍を転載開始。その1回目として、データフローグラフや定数といったTensorFlowの基礎を説明する。

2. TensorFlow入門 ― 変数とプレースホルダー

転載2回目。TensorFlowの基礎の第2弾として、変数とプレースホルダーを実際のコードと実行結果で示しながら解説する。

3. TensorFlowの“テンソル(Tensor)”とは? TensorBoardの使い方

転載3回目。テンソル(Tensor)とTensorBoardによるグラフの可視化を解説する。「第1章 TensorFlowの基礎」は今回で完結。

4. 【現在、表示中】≫ TensorFlowでデータの読み込み ― 画像を分類するCIFAR-10の基礎

転載4回目。今回から「畳み込みニューラルネットワーク」のモデルを構築して、CIFAR-10のデータセットを使った学習と評価を行う。今回はデータの読み込みを説明。

5. TensorFlowによる推論 ― 画像を分類するCIFAR-10の基礎

転載5回目。CIFAR-10データセットを使った学習と評価を行う。画像データの読み込みが終わったので、今回は画像の種類(クラス)を判別、つまり「推論」について説明する。

サイトからのお知らせ

Twitterでつぶやこう!