書籍転載:TensorFlowはじめました ― 実践!最新Googleマシンラーニング(6)

書籍転載:TensorFlowはじめました ― 実践!最新Googleマシンラーニング(6)

TensorFlowによる学習 ― 画像を分類するCIFAR-10の基礎

2016年8月23日

転載6回目。CIFAR-10データセットを使った学習と評価を行う。「推論」(=画像の種類・クラスを判別)が終わったので、今回は「学習」(=訓練)について説明する。

有山 圭二
  • このエントリーをはてなブックマークに追加

書籍転載について

 本コーナーは、インプレスR&D[Next Publishing]発行の書籍『TensorFlowはじめました ― 実践!最新Googleマシンラーニング』の中から、特にBuild Insiderの読者に有用だと考えられる項目を編集部が選び、同社の許可を得て転載したものです。

 『TensorFlowはじめました ― 実践!最新Googleマシンラーニング』(Kindle電子書籍もしくはオンデマンドペーパーバック)の詳細や購入はAmazon.co.jpのページをご覧ください。書籍全体の目次は連載INDEXページに掲載しています。プログラムのダウンロードは、「TensorFlowはじめました」のサポート用フォームから行えます。

ご注意

本記事は、書籍の内容を改変することなく、そのまま転載したものです。このため用字用語の統一ルールなどはBuild Insiderのそれとは一致しません。あらかじめご了承ください。

2.3 学習(learn)

 モデルそのものは推論(inference)の機能しか持ちません。現在のパラメーターを使って出力層のノードそれぞれについて確からしさを「推論」するのがモデルの役割です。

 「学習(learn)」とは、推論の結果と期待する結果との誤差をもとに、推論をより期待する結果に近づくようにパラメーターを更新することを言います。学習を「訓練(train)」と呼ぶこともありますが、同じ意味と考えて差し支えありません。

 学習は、以下の手順で行います。

  1. 推論(inference)
  2. 損失関数を使って推論の結果と正解の誤差(損失)を求める
  3. 最適化アルゴリズムを使ってパラメーターを更新する

 リスト2.12は、学習に必要なオペレーションを加えたグラフ(図2.6)を構築します。

図2.6: 学習のオペレーションを追加したグラフ
Python
def _loss(logits, label):
  labels = tf.cast(label, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits, labels, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  return cross_entropy_mean


def _train(total_loss, global_step):
  opt = tf.train.GradientDescentOptimizer(learning_rate=0.001)
  grads = opt.compute_gradients(total_loss)
  train_op = opt.apply_gradients(grads, global_step=global_step)
  return train_op


filenames = [
  os.path.join(
    FLAGS.data_dir,'data_batch_%d.bin' % i) for i in range(1, 6)
  ]


def main(argv=None):
  global_step = tf.Variable(0, trainable=False)
  
  train_placeholder = tf.placeholder(tf.float32,
                                     shape=[32, 32, 3],
                                     name='input_image')
  label_placeholder = tf.placeholder(tf.int32, shape=[1], name='label')
  
  # (width, height, depth) -> (batch, width, height, depth)
  image_node = tf.expand_dims(train_placeholder, 0)
  
  logits = model.inference(image_node)
  total_loss = _loss(logits, label_placeholder)
  train_op = _train(total_loss, global_step)
  
  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    
    total_duration = 0
    
    for epoch in range(1, FLAGS.epoch + 1):
      start_time = time.time()
      
      for file_index in range(5):
        print('Epoch %d: %s' % (epoch, filenames[file_index]))
        reader = Cifar10Reader(filenames[file_index])
        
        for index in range(10000):
          image = reader.read(index)
          
          _, loss_value, logits_value = sess.run(
            [train_op, total_loss, logits],
            feed_dict={
              train_placeholder: image.byte_array,
              label_placeholder: image.label
            })
          
          assert not np.isnan(loss_value), \
            'Model diverged with loss = NaN'
          
          if index % 1000 == 0:
            print('[%d]: %r' % (image.label, logits_value))
        
        reader.close()
      
      duration = time.time() - start_time
      total_duration += duration
      
      tf.train.SummaryWriter(FLAGS.checkpoint_dir, sess.graph)
    
    print('Total duration = %d sec' % total_duration)
リスト2.12: train.py
損失関数(Loss Function)

 損失関数(誤差関数)は、推論の結果と正解の「誤差」を求めます。

 リスト2.13では、損失関数としてCross Entropy (交差エントロピー)を使うことで、推論の結果logitsと正解labelの誤差(cross_entropy_mean)を求めています。

Python
def _loss(logits, label):
  labels = tf.cast(label, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits, labels, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  return cross_entropy_mean
リスト2.13: _loss関数
最適化アルゴリズム(Optimizer)

 最適化アルゴリズムは、損失関数が計算した誤差をもとに、より正解に近い推論ができるようにパラメーターを更新します。

 リスト2.14では、勾配降下法 (Gradient Descent Algorithm)を用いてパラメーターを更新します。

Python
def _train(total_loss, global_step):
  opt = tf.train.GradientDescentOptimizer(learning_rate=0.001)
  grads = opt.compute_gradients(total_loss)
  train_op = opt.apply_gradients(grads, global_step=global_step)
  return train_op
リスト2.14: _train関数

 一般的に、学習が上手く進んでいると誤差が小さくなります。反対に学習に失敗していると誤差は大きくなっていきます。前者を「収束」、後者を「発散」とそれぞれ呼びます。

Note: 学習率(Leaning Rate)

学習率は、勾配をもとにどの程度の割合でパラメーターを更新するのかを決める値です。学習率が大きすぎるとパラメーターの更新量が大きくなりすぎて誤差が収束しなかったり、発散したりして学習に失敗します。逆に学習率が小さすぎると、勾配に対するパラメーターの更新量が小さすぎて学習が進まなかったりします。

学習率のように人が決定する必要がある値を「ハイパーパラメーター」と呼びます。

学習の結果

 リスト2.12のプログラムを実行すると、入力画像に対する推論と学習を実行します。リスト2.15は、学習を2エポック*6実行したときの「data_batch_1.bin」の1枚目の画像(Frog: 6)に対する推論の結果(logits)です。

# 初期状態
[6]: [[ 0.00440715, -0.0027222 ,  0.01911113, -0.00304644,  0.02481071,
       -0.01777996,  0.00692139,  0.00362923,  0.02579042,  0.0111508 ]]

# 1 Epoch
[6]: [[-0.01063725,  0.6968801 ,  0.73776543, -0.07745715,  0.18181412,
       -0.49579746,  0.29869688, -0.81775153,  0.01088678, -0.43416238]]

# 2 Epoch
[6]: [[-3.4853344 , -1.36028135, -0.60522491,  2.92168188,  0.6786142 ,
        1.10281229,  4.0964551 , -0.12701799, -1.98207819, -0.99662489]]
リスト2.15: 各エポックのlogits

 初期状態ではすべてのクラスに対して低い値を示していますが、2エポックを過ぎると学習が進み「Frog: 6」と判定できるようになりました。

  • *6 データセットに含まれるすべてのデータについて処理を終えると「1エポック(Epoch)」と数えます。

 今回は「学習」(=訓練)を行いました。次回は、「評価」について説明します。

※以下では、本稿の前後を合わせて5回分(第3回~第7回)のみ表示しています。
 連載の全タイトルを参照するには、[この記事の連載 INDEX]を参照してください。

3. TensorFlowの“テンソル(Tensor)”とは? TensorBoardの使い方

転載3回目。テンソル(Tensor)とTensorBoardによるグラフの可視化を解説する。「第1章 TensorFlowの基礎」は今回で完結。

4. TensorFlowでデータの読み込み ― 画像を分類するCIFAR-10の基礎

転載4回目。今回から「畳み込みニューラルネットワーク」のモデルを構築して、CIFAR-10のデータセットを使った学習と評価を行う。今回はデータの読み込みを説明。

5. TensorFlowによる推論 ― 画像を分類するCIFAR-10の基礎

転載5回目。CIFAR-10データセットを使った学習と評価を行う。画像データの読み込みが終わったので、今回は画像の種類(クラス)を判別、つまり「推論」について説明する。

6. 【現在、表示中】≫ TensorFlowによる学習 ― 画像を分類するCIFAR-10の基礎

転載6回目。CIFAR-10データセットを使った学習と評価を行う。「推論」(=画像の種類・クラスを判別)が終わったので、今回は「学習」(=訓練)について説明する。

7. TensorFlowによる評価 ― 画像を分類するCIFAR-10の基礎

転載7回目(最終回)。CIFAR-10データセットを使った学習と評価を行う。「学習」(=訓練)が終わったので、今回は「評価」について説明する。「第2章 CIFAR-10の学習と評価」は今回で完結。

サイトからのお知らせ

Twitterでつぶやこう!


Build Insider賛同企業・団体

Build Insiderは、以下の企業・団体の支援を受けて活動しています(募集概要)。

ゴールドレベル

  • 日本マイクロソフト株式会社
  • グレープシティ株式会社