R&Dセンター 技術開発部 古林 隆宏

TensorFlow+Keras入門 ~ ウチでもできた⁉ ディープラーニング ~

第5回 Java WebアプリでTensorFlow(実装編)

第5回は、TensorFlowモデルを実利用するにあたり、前回検討した「Pythonでモデルを構築し、それをJavaから利用する」仕組みをTensorFlow Servingを用いて実現していきます。

記事内で必要に応じてPython等のコードを示すことがありますが、言語自体の機能や構文に関する解説は割愛させていただきますのでご了承ください。
なお、この記事にて掲載しているコードはすべてMITライセンスのもと利用を許諾するものとします。

TensorFlow, the TensorFlow logo and any related marks are trademarks of Google Inc.

1) おさらい:Pythonで作ってJavaから使いたい

前回の記事では、Javaで書かれている既存アプリケーションに「AI機能」を組み込む、という明日にも我が身に降りかかりそうな状況を想定し、ディープラーニングの学習済みモデルを利用する方式を検討しました。

結論としては、学習済みモデルの部分を既存アプリケーションとは独立したサービスとして構成することで多くのベネフィットが得られそうということでした。
TensorFlowを使っていればTensorFlow Servingを用いて簡単に学習済みモデルをサービス化することができますので、今回はTensorFlow Servingを利用して実際に仕組みを構築していきます。

本記事では以下のような手順で実装方法を紹介していきます:

  1. Pythonから学習済みモデルを書き出す
  2. TensorFlow Serving環境を構築しサーバを起動
  3. Javaから利用できることを確認する

2) 学習済みモデルの書き出し

学習済みモデルをTensorFlow Servingに載せるには、SavedModel形式で書き出す必要があります。
第3回で作成したPythonコードに対して必要な変更は以下の2点です:

  • Estimator作成時に書き出し用の情報を付加するようにする
  • 実際に書き出すコードを追加する

(1) Estimator作成部分の変更

Estimator作成のための関数のうち最後に値を返す部分は、学習を行う場合と学習はせず利用のみする場合で条件分岐していました。
今回行いたい「書き出し」操作にも学習は必要ないので、学習はせず利用のみする場合の分岐を利用し、書き出し用の情報を付加するような変更を加えます。

<             return tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities)
---
>             output = tf.estimator.export.ClassificationOutput(scores=probabilities)
>             return tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities,
>                                               export_outputs={'result': output})

上記の通り、EstimatorSpecオブジェクトの作成時にClassificationOutputを渡すように変更しています。これにより、このモデルがデータの分類(Classification)タスクであり、出力としてprobabilities変数の指すもの、つまり犬猫それぞれの確率を返すということを定義しています。
この定義はTensorFlow Servingで学習済みモデルを動作させる際に利用されます。

上記の変更を適用すると、mydnn_tf.pyは以下のようになります。

import tensorflow as tf
 
def MyDNN(input_shape=(32, 32, 1), output_size=10, learning_rate=0.001,
          keep_prob=0.5, model_dir='tfmodel'):
    def mydnn_fn(features, labels, mode):
        input_layer = tf.reshape(features["img"], [-1] + list(input_shape))
 
        layer = tf.layers.conv2d(filters=20, kernel_size=5, strides=2,
                                 activation=tf.nn.relu,
                                 inputs=input_layer)
        layer = tf.layers.max_pooling2d(pool_size=3, strides=2,
                                        inputs=layer)
        layer = tf.layers.batch_normalization(inputs=layer)
         
        layer = tf.layers.conv2d(filters=50, kernel_size=5, strides=2,
                                 activation=tf.nn.relu,
                                 inputs=layer)
        layer = tf.layers.max_pooling2d(pool_size=3, strides=2,
                                        inputs=layer)
        layer = tf.layers.batch_normalization(inputs=layer)
      
        layer = tf.contrib.layers.flatten(layer)
        layer = tf.layers.dense(units=100, activation=tf.nn.relu,
                                inputs=layer)
        layer = tf.layers.dropout(rate=(1 - keep_prob),
                                  training=(mode == tf.estimator.ModeKeys.TRAIN),
                                  inputs=layer)
        output_layer = tf.layers.dense(units=output_size, 
                                       inputs=layer)
 
        # for prediction
        if mode == tf.estimator.ModeKeys.PREDICT:
            probabilities = tf.nn.softmax(output_layer)
            output = tf.estimator.export.ClassificationOutput(scores=probabilities)
            return tf.estimator.EstimatorSpec(mode=mode, predictions=probabilities,
                                              export_outputs={'result': output})
 
        # for training
        labels = tf.reshape(labels, [-1, output_size])
        loss = tf.losses.softmax_cross_entropy(onehot_labels=labels,            
                                               logits=output_layer)
        optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-1)       
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
 
        # for evaluation
        classes = tf.argmax(input=output_layer, axis=1)
        eval_metric_ops = {
          "accuracy": tf.metrics.accuracy(labels=tf.argmax(labels, axis=1),
                                          predictions=classes)
        }
 
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)
 
 
    estimator = tf.estimator.Estimator(model_fn=mydnn_fn, model_dir=model_dir)
 
    return estimator
  1. return estimator
  2. (2) 実際に書き出すコードの追加

モデルを書き出すには、Estimatorオブジェクトのexport_savedmodelメソッドを呼び出します。
第1引数は保存先ディレクトリ、第2引数は入力データの受け取り方を定義する関数です。

    def input_receiver_fn():
        # ここでtf.estimator.export.ServingInputReceiverを作成して返す
 
    model.export_savedmodel(export_dir, input_receiver_fn)

このコードでは第2引数にわたす関数としてinput_receiver_fnを定義しています。
最終的にはtf.estimator.export.ServingInputReceiverオブジェクトを生成して返すのがこの関数の役目です。
ServingInputReceiverは名前の通りTensorFlow Servingに送られてきた入力データの処理の仕方を定義するためのものです。

では、input_receiver_fnの中身を見てみましょう。

    def input_receiver_fn():
        input_example = tf.placeholder(dtype=tf.string)
        input_size_xyc = input_size_x * input_size_y * input_size_c
        feature_spec = {'img' : tf.FixedLenFeature(dtype=tf.float32,
                                                   shape=(input_size_xyc))}
        features = tf.parse_example(input_example, feature_spec)
        receiver_tensors = {'examples': input_example}
        return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

ServingInputReceiverのイニシャライザの第1引数featuresは、クライアントから送られてくるデータを受け止めてEstimatorで利用可能な形に加工したものです
(※ 正確には、そのような加工処理を行うTensorFlowの計算工程を示すものです)。
クライアントからは文字列の形にシリアライズされたデータが送られてきますので、それをうまく構造化データにもどしてやる必要があるわけです。

受け取った文字列はまず文字列形式のplaceholderであるinput_exampleで受け止めることとします。
そしてその文字列をtf.parse_example関数に渡してパースします(ちなみにexampleという名前は、TensorFlow Servingのサーバ/クライアント間でデータの受け渡しに使われるクラスtf.train.Exampleに由来しています)。
parse_exampleの第2引数として与えているfeature_specは、受け取るデータは固定長のfloatデータであり、Estimatorに渡すためにimgというキーを添えることを指定しています。
parse_exampleの返り値としてfeaturesが得られます。

ただし、このままでは、input_exampleは数多くあるplaceholderのひとつに過ぎませんので、クライアントから受け取ったデータをinput_exampleで受け止めたいということを明示的に伝える必要があります。
それを行っているのがServingInputReceiverのイニシャライザの第2引数receiver_tensorsです。

全体として、cat_dog_dnn_tf.pyのコードは以下のようになります。
既存の「学習モード」「判定モード」に加えて「書き出しモード」を選択できるようにし、上記の書き出し部分のコードを追加しています。

import argparse
import tensorflow as tf
import os, sys
import numpy as np
from datetime import datetime
from mydnn_tf_layers import MyDNN
 
ImageDataGenerator = tf.keras.preprocessing.image.ImageDataGenerator
load_img = tf.keras.preprocessing.image.load_img
img_to_array = tf.keras.preprocessing.image.img_to_array
 
input_size_x = 224
input_size_y = 224
batch_size = 20
input_size_c = 3
output_size = 2
model_dir = 'ckpt/tf/'
 
export_dir = 'tf_savedmodels/'
 
parser = argparse.ArgumentParser()
parser.add_argument("--export", action="store_true", default=False,
                    help="学習済みモデルを本番向けに書き出す")
parser.add_argument("--infer", action="store", nargs="?", type=str,
                    help="学習は行わず、このオプションで指定した画像ファイルの判定処理を行う")
parser.add_argument("--epochs", action="store", nargs="?", default=10, type=int,
                    help="学習データ全体を何周するか")
args = parser.parse_args()
 
epochs = args.epochs
 
model = MyDNN(input_shape=(input_size_x, input_size_y, input_size_c),
              output_size=output_size, model_dir=model_dir)
 
rain = not (args.infer or args.export)
if train:
    print("学習モード")
    # 学習データの読み込み
    keras_idg = ImageDataGenerator(rescale=1.0 / 255)
    train_generator = keras_idg.flow_from_directory('data/train',
                          target_size=(input_size_x, input_size_y),
                          batch_size=1,
                          class_mode='categorical',
                          shuffle=True)
    valid_generator = keras_idg.flow_from_directory('data/valid',
                          target_size=(input_size_x, input_size_y),
                          batch_size=1,
                          class_mode='categorical')
 
    # 学習の実行
    num_data_train_dog = len(os.listdir('data/train/dog'))
    num_data_train_cat = len(os.listdir('data/train/cat'))
    num_data_train = num_data_train_dog + num_data_train_cat
 
    num_data_valid_dog = len(os.listdir('data/valid/dog'))
    num_data_valid_cat = len(os.listdir('data/valid/cat'))
    num_data_valid = num_data_valid_dog + num_data_valid_cat
 
    steps_per_epoch = num_data_train / batch_size
    validation_steps = num_data_valid / batch_size
 
    def my_input_fn(generator):
        gen_fn = lambda: generator 
        dataset = tf.data.Dataset.from_generator(gen_fn, (tf.float32, tf.float32))
        dataset = dataset.map(lambda f, l: ({"img": f}, l)).batch(batch_size)
        return dataset.make_one_shot_iterator().get_next()
 
    for epoch in range(epochs):
        print("epoch ",  epoch)
        print("training...")
        model.train(input_fn=lambda: my_input_fn(train_generator),
                    steps=steps_per_epoch)
        print("evaluation:")
        eval_results = model.evaluate(input_fn=lambda: my_input_fn(valid_generator),
                                      steps=validation_steps)
        print(eval_results)
 
elif args.infer:
    print("判定モード")
    # 判定する画像の読み込み
    image_infer = load_img(args.infer, target_size=(input_size_x, input_size_y))
    data_infer = img_to_array(image_infer)
    data_infer = np.expand_dims(data_infer, axis=0)
    data_infer = data_infer / 255.0
    print(data_infer.shape)
 
    # 判定処理の実行
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"img": data_infer},
                                                          batch_size=1,
                                                          shuffle=False)
    result_generator = model.predict(predict_input_fn) 
    result = next(result_generator) * 100
 
    # 判定結果の出力
    if result[0] > result[1]:
        print('Cat (%.1f%%)' % result[0])
    else:
        print('Dog (%.1f%%)' % result[1])
 
else:
    print("書き出しモード")
    def input_receiver_fn():
        input_example = tf.placeholder(dtype=tf.string)
        input_size_xyc = input_size_x * input_size_y * input_size_c
        feature_spec = {'img' : tf.FixedLenFeature(dtype=tf.float32,
                                                   shape=(input_size_xyc))}
        features = tf.parse_example(input_example, feature_spec)
        receiver_tensors = {'examples': input_example}
        return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
 
    model.export_savedmodel(export_dir, input_receiver_fn)

3) TensorFlow Serving環境の構築とサーバの起動

パッケージ管理システムとしてaptを利用できる場合、環境の構築は非常に簡単です。
公式サイト紹介されている通り、GoogleのTensorFlow Serving用リポジトリを追加後apt-getコマンドでインストールすればOKです。

$ echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
 
$ curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
 
$ sudo apt-get update && sudo apt-get install tensorflow-model-server

サーバの起動は以下のように行います。

$ tensorflow_model_server --port=9000 --model_name=(学習済みモデルの識別名) --model_base_path=(学習済みモデルの保存先ディレクトリのフルパス)

この記述では指定した1つの学習済みモデルのみを読み込みますが、設定ファイルを別途記述すれば複数の学習済みモデルを同時にサービス化することもできます。

4) JavaからTensorFlow Servingを呼び出す

(1) Javaプロジェクトの準備

TensorFlow ServingにはGoogleが開発したRPCの実装であるgRPCを用いて接続します。
Javaの外部ライブラリとして、gRPC本体と、gRPCで使用するTensorFlow Serving用のインタフェースを定義するものを参照する必要があります。

Mavenを使う場合、gRPC本体は依存関係に以下を追加すれば利用可能です。

<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-netty</artifactId>
  <version>1.9.0</version>
</dependency>
<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-protobuf</artifactId>
  <version>1.9.0</version>
</dependency>
<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-stub</artifactId>
  <version>1.9.0</version>
</dependency>

gRPCで使用するTensorFlow Serving用のインタフェースについては、その定義がprotobufという形式で配布されており、公式な手順としてはこれをJavaのクラスに自分で変換して埋め込むということが必要です。
有志開発者によりあとはMavenでビルドするだけの状態に整えられたプロジェクトが公開されていますので、利用するとよいでしょう。

(2) 画像のロード

今回作成した学習済みモデルは、入力として固定長のfloat値の集合を受け取りますので、判定したい画像をfloat値のリストに変換してやる必要があります。
そのような操作を行うライブラリを探してみたのですが、そのものずばり実現できるものは見つからなかったので下記のようなクラスを作成しました。

package jp.scsk.furuba.wi.service;
 
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
 
import javax.imageio.ImageIO;
 
public class ImageLoader {
    private static final int NUM_CHANNELS = 3;
    private final int width;
    private final int height;
 
    public ImageLoader(int width, int height) {
        this.width = width;
        this.height = height;
    }
 
    public List<Float> load(String path) {
        BufferedImage rawImage;
        try (InputStream imgIs = new FileInputStream(new File(path))) {
            rawImage = ImageIO.read(imgIs);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
 
        BufferedImage image = new BufferedImage(width, height, rawImage.getType());
        Image scaledImage =
            rawImage.getScaledInstance(width, height, Image.SCALE_AREA_AVERAGING);
        image.getGraphics().drawImage(scaledImage, 0, 0, width, height, null);
 
        List<Float> flatImage = new ArrayList<>(width * height * NUM_CHANNELS);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                int rgb = image.getRGB(x, y);
                // intで表現されているRGB値をRGBの各チャンネルに分解し、0-1のfloatに変換
                flatImage.add(toFloatRed(rgb));
                flatImage.add(toFloatGreen(rgb));
                flatImage.add(toFloatBlue(rgb));
            }
        }
        return flatImage;
    }
 
    private static final float MAX_PIXEL_VALUE = 255;
 
    public static float toFloatRed(int rgb) {
        return (rgb >> 16 &amp; 0xff) / MAX_PIXEL_VALUE;
    }
 
    public static float toFloatGreen(int rgb) {
        return (rgb >> 8 &amp; 0xff) / MAX_PIXEL_VALUE;
    }
 
    public static float toFloatBlue(int rgb) {
        return (rgb &amp; 0xff) / MAX_PIXEL_VALUE;
    }
}

画像を読み込んで指定のサイズに変換し、Floatのリストとして出力するようにしています。
今回手元にある学習済みモデルは224ピクセル×224ピクセル×3チャンネルのデータを入力として受け取りますので、たとえば以下のように使うことになります。

ImageLoader loader = new ImageLoader(224, 224);
List<Float> flatImage = loader.load("(画像のパス)");

なお、処理にAWT関連のクラスを使用している関係上、サーバ環境で動かすとエラーとなる場合があります。
その際にはJavaの起動オプションとして-Djava.awt.headless=trueを指定してください。

(3) データをサーバで処理

無事入力データを得たところで、それを実際にサーバで処理していきます。
まずはサーバとの通信を担うStubと呼ばれるオブジェクトを作成します。

ManagedChannel grpcChannel = ManagedChannelBuilder
    .forAddress("(サーバの名前かIPアドレス)", 9000)
    .usePlaintext(true)
    .build();
PredictionServiceBlockingStub grpcStub =
    PredictionServiceGrpc.newBlockingStub(grpcChannel);

Stubにはいくつか種類がありますが、今回は同期的に処理を書きたいので、サーバからの応答を待つ間ブロックするBlockingStubを利用しています。

次にサーバに送信するリクエストを作成し、送信します。
判定したい画像のデータはflatImage変数に入っている想定です。

ClassificationRequest request;
{
    ExampleList exampleList = ExampleList.newBuilder()
        .addExamples(buildExample(flatImage))
        .build();
 
    request = ClassificationRequest.newBuilder()
        .setModelSpec(ModelSpec.newBuilder().setName("(学習済みモデルの識別名)"))
        .setInput(Input.newBuilder().setExampleList(exampleList))
        .build();
}

ClassificationRequestなどのクラスはprotoファイルから生成したもので、基本的にBuilderによってオブジェクトを作成します。
ClassificationRequestのBuilderにおいて、setModelSpecではTensorFlow Servingの保持しているどのモデルに宛てたリクエストなのかということを指定し、
setInputでは送信する入力データを指定しています。

送信する入力データであるexampleListの作成において、このコードでは1件だけ画像データをセットしていますが、addExamplesメソッドを繰り返し呼ぶことで複数件のデータをまとめてリクエストに含めることもできます。

上記のコードで使用しているbuildExample関数の定義は以下のとおりです:

private static Example buildExample(List<Float> flatImage) {
    FloatList floatList = FloatList.newBuilder().addAllValue(flatImage).build();
    Feature feature = Feature.newBuilder().setFloatList(floatList).build();
    Features features = Features.newBuilder().putFeature("img", feature).build();
    Example example = Example.newBuilder().setFeatures(features).build();
    return example;
}

List<Float>であるところのflatImageを、TensorFlow Servingが要求する入力データの形であるExampleに変換しています。

その後、grpcStubを通してリクエストの送信を行います。

ClassificationResponse response = grpcStub
    .withDeadlineAfter(10, TimeUnit.SECONDS)
    .classify(request);

(4) 結果の確認

先程得たresponseから処理の結果が取得できます。
処理結果は、入力データ1件ごとにClassificationsというクラスのオブジェクトに格納されています。これは分類タスクの結果を格納するものでで、getClassesメソッドを用いて分類ごとのスコア(ここでは猫のスコアと犬のスコア)にアクセスできます。

List<Classifications> results = response.getResult().getClassificationsList()
Classifications firstResult = results.get(0)
 
System.out.println(String.format("猫らしさ:%f", firstResult.getClasses(0).getScore()));
System.out.println(String.format("犬らしさ:%f", firstResult.getClasses(1).getScore()));

ここまでのコード(buildExample関数を除く)をまとめると、以下の通りになります。

ImageLoader loader = new ImageLoader(224, 224);
List<Float> flatImage = loader.load("(画像のパス)");
 
ManagedChannel grpcChannel = ManagedChannelBuilder
    .forAddress("(サーバの名前かIPアドレス)", 9000)
   .usePlaintext(true)
    .build();
PredictionServiceBlockingStub grpcStub = PredictionServiceGrpc.newBlockingStub(grpcChannel);
 
ClassificationRequest request;
{
    ExampleList exampleList = ExampleList.newBuilder()
        .addExamples(buildExample(flatImage))
        .build();
 
    request = ClassificationRequest.newBuilder()
        .setModelSpec(ModelSpec.newBuilder().setName("(学習済みモデルの識別名)"))
        .setInput(Input.newBuilder().setExampleList(exampleList))
        .build();
}
 
ClassificationResponse response = grpcStub
    .withDeadlineAfter(10, TimeUnit.SECONDS)
    .classify(request);
 
List<Classifications> results = response.getResult().getClassificationsList()
Classifications firstResult = results.get(0)
 
System.out.println(String.format("猫らしさ:%f", firstResult.getClasses(0).getScore()));
System.out.println(String.format("犬らしさ:%f", firstResult.getClasses(1).getScore()));

これを手元にある犬の画像に対して実行したところ、下記のような結果を得ました:

猫らしさ:0.016830
犬らしさ:0.983170

無事、JavaからTensorFlowの学習済みモデルを利用することができました!

おわりに

TensorFlow Servingを用いると、学習済みモデルをサービス化し、gRPCが使える限り好きな言語から利用することができます。
これにより、ディープラーニング関連の作業はPython環境で行いつつ、Javaのお堅いシステムからそれを利用するということが実現できます。

さて、Kerasお仕着せの楽勝ディープラーニングからスタートしたこの解説も、やや駆け足ではありましたが、本番のビジネスで使える道具立てがそろうところまで、たどり着きました。
ここから先は現場の課題に応じて、うまくディープラーニングや、その他の手段を適用していくという段階になります。
この解説が、あなたの現場の課題解決にいつか役立つことを祈念しつつ、このあたりで終わりにしたいと思います。

ここまでお読みいただき、ありがとうございました。

1 2 3 4

5

PAGE TOP