EeBlog(テクニカルブログ)

AI(人工知能)実践 第6回 TensorFlow(アヤメの分類)1

AI関連 第6回です。
第4回、第5回ではNumpyについて触れてきました。
第6回では機械学習ライブラリで有名なTensorFlowを触れていきます。
 

TensorFlowの中身に入る前に、前提知識として、Tensor(読み方:テンソル)についてお話します。
 
Tensorとは、wikipedia(https://ja.wikipedia.org/wiki/)によると

テンソル(英: tensor, 独: Tensor)とは、線形的な量または線形的な幾何概念を一般化したもので、
基底を選べば、多次元の配列として表現できるようなものである。
しかし、テンソル自身は、特定の座標系によらないで定まる対象である。個々のテンソルについて、
対応する量を記述するのに必要な配列の添字の組の数は、そのテンソルの階数とよばれる。

とあります。
プログラマの方であれば、単純に多次元配列の事を指していると考えるとイメージしやすいかと思います。
 
テンソルに関連するものとして頻出する単語として、スカラー・ベクトル・行列がありますが、
それぞれ下記のようになります。
 
スカラーとは、大きさを示すものであり、0階のテンソルとも表現出来ます。
つまり、距離や速さ(速度はベクトル)、重さ、時間等が該当します。
プログラマであれば、配列ではなく、単純なint等のオブジェクトとして表現されるものとすると、
イメージしやすいかと思います。
 
ベクトルとは、大きさと方向を示すものであり、1階のテンソルとも表現できます。
つまり、速度や加速度、運動量等が該当します。
プログラマであれば、1次元の配列で表現されるものとするとイメージしやすいかと思います。
 
行列とは、数あるいは式(関数)を長方形状に並べて括弧で囲んだものであり、
2階のテンソルとも表現出来ます。つまり、行と列で構成されたものです。
Microsft Excelや表をイメージするとわかりやすいかもしれません。
 
 
では、ざっくりとした説明が終わったところで、TensorFlowを使った機械学習へ移っていきます。
今回からはチュートリアルにもあるアヤメの分類のコードをベースに、複数回に分けて説明していきます。
尚、TensorFlowのバージョンについては、2018/6時点で最新の1.8を使用する想定をしています。
※下記のコードは古いバージョン(1.3)のチュートリアル
(https://www.tensorflow.org/versions/r1.3/get_started/estimator#load_the_iris_csv_data_to_tensorflow)
にあるコードをベースにしていますが、TensorFlowバージョン1.8段階で非推奨部分は、
違うやり方に書き換えてあります。
 

ソース全文

import pandas as pd
import os
import urllib

import tensorflow as tf
import numpy as np

IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

#データファイルが無い場合ダウンロード
if not os.path.exists(IRIS_TRAINING):
    raw = urllib.urlopen(IRIS_TRAINING_URL).read()
    with open(IRIS_TRAINING,'w') as f:
        f.write(raw)

if not os.path.exists(IRIS_TEST):
    raw = urllib.urlopen(IRIS_TEST_URL).read()
    with open(IRIS_TEST,'w') as f:
        f.write(raw)
    
#データ読み込み用のメソッド作成
def read_file(file_name):
    data=pd.read_csv(file_name, usecols = [0, 1, 2, 3])
    label = pd.read_csv(file_name, usecols = [4])
    return data.values, label.values

#データ読み込み
training_x, training_y = read_file(IRIS_TRAINING)
test_x, test_y = read_file(IRIS_TEST)

# すべてのfeatureが実数値データであることを指定する
feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]

# 10, 20, 10のノードを持つ3層のDNNを作成
classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                        hidden_units=[10, 20, 10],
                                        n_classes=3,
                                        model_dir="./iris_model")

# 学習用データ(辞書)を返す関数の作成
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": np.array(training_x)},
    y=np.array(training_y),
    num_epochs=None,
    shuffle=True)

# 学習の実行。stepsを分割して実行しても同様の結果を得られる
classifier.train(input_fn=train_input_fn, steps=2000)

# 評価用データ(辞書)を返す関数の作成
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": np.array(test_x)},
    y=np.array(test_y),
    num_epochs=1,
    shuffle=False)

# 評価
accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"]

print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

解説

前述のソースコードを順に説明していきます。

ライブラリインポート

import pandas as pd
import os
import urllib

import tensorflow as tf
import numpy as np

必要となるライブラリのインポートを行っています。
上から順に

pandas : 高性能で使いやすいデータ構造とデータ解析ツールを提供するオープンソースのライブラリ
os : OS 依存の機能を利用するポータブルな方法を提供するPython標準モジュール
usllib: URL を扱う幾つかのモジュールを集めたPython標準パッケージ
tensorflow : Googleの開発した機械学習用のオープンソースのライブラリ
numpy :数値計算を効率的に行うためのオープンソースの拡張モジュール

です。
 

学習・評価用データのダウンロード元・保存ファイル名定義

IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

学習・評価用データファイルダウンロードや、データ読み込みの際に使うため、
ファイル名とダウンロード元を定義しています。
 

学習・評価用のデータのダウンロード

#データファイルが無い場合ダウンロード
if not os.path.exists(IRIS_TRAINING):
    raw = urllib.urlopen(IRIS_TRAINING_URL).read()
    with open(IRIS_TRAINING,'w') as f:
        f.write(raw)

if not os.path.exists(IRIS_TEST):
    raw = urllib.urlopen(IRIS_TEST_URL).read()
    with open(IRIS_TEST,'w') as f:
        f.write(raw)

os.path.exsistsメソッドにて、前述のデータファイルが存在するかチェックしています。
ファイルが存在しない場合、urllib.urlopenメソッド及びreadメソッドにて、
ダウンロード元にアクセスし、データの読み込みを行っています。
その後、with openでファイルを上書きモードで開き(ファイルが存在しなければ新規作成される)、
writeで読み込んだデータをファイルに書き込んでいます。
 
次回に続きます。