• AI

説明可能AIの実現方法:LIME、SHAP

はじめに

昨今、「AIブーム」ということでAIが着目されています。
しかし、AIの社会実装が順調に進んでいるかと聞かれると、そうではないというのが実態だと思います。
その理由にはさまざまありますが、その1つが「AIのブラックボックス性」だと思います。
そこで、今回は「AIのブラックボックス性」の解消に向けて、「説明可能AI」の実例をサンプルを交えて紹介します。

説明可能AIとAIのブラックボックス性

そもそも「AIのブラックボックス性」とはどういうことなのでしょうか?
機械学習に話を絞って考えると、(理解できるかどうかは別として)内部でどのような計算をしているかは可視化することができます。
一方で、ブラックボックスの辞書的な意味は内部構造や処理過程が不明な装置や仕組みとなります。
ブラックボックスという言葉の辞書的な意味を考えると、「AIのブラックボックス性」という表現はちゃんちゃらおかしいことになります。
では、なぜ「AIのブラックボックス性」という言葉が使われるかというと、機械学習モデルの内部で行っている計算が非常に複雑なせいで簡単に理解することができないからだと思います。
逆に言えば、機械学習モデルの内部で行っている複雑な計算を(直感的な)解釈ができるような形にすることができれば「AIのブラックボックス性」解消につながるはずです。
これを実現するための仕組みが「説明可能AI」になります。

LIME、SHAPの概要

今回は「説明可能AI」の手法としてLIMEとSHAPを利用します。
これらは、ある特定の入力に対する予測の判断根拠を示す手法となっています。
もう少し具体的には、表形式データ(数値データ)であればどの変数が予測に効いているか、画像データであれば画像のどの部分が予測に効いているかを示すことができます。

利用するデータとモデル

今回利用するデータは表形式データ(数値データ)と画像データの2種類です。
また、モデルとして、表形式データ(数値データ)の場合はランダムフォレストを、画像データの場合はResNet50を利用します。
ランダムフォレストは決定木のアンサンブルモデルで、ResNet50は深層学習のモデルとなります。
具体的にどのようなデータを用いるかはそれぞれ以下の通りです。

表形式データ(数値データ)

  • WISDM Activity Classification
  • 以下の画像に記載されているXYZ方向の加速度データと、その際の人間の行動
  • 加速度データから人間の行動を予測する
  • 人の行動は6種類(walking, jogging, upstairs, downstairs, sitting, standing)
  • 1つのデータとして、24時点のXYZ方向の加速度を利用
  • 各加速度の値を、XYZ方向ごとに最大値が1で最小値が0になるように正規化して利用
  • サンプルコードでは、上記リンク先から取得したデータを以下のようなCSV形式に変換して、適切に学習データ(train data)と検証データ(validation data)に分割したファイルを利用

出典: Jennifer R. Kwapisz, Gary M. Weiss and Samuel A. Moore (2010). Activity Recognition using Cell Phone Accelerometers, Proceedings of the Fourth International Workshop on Knowledge Discovery from Sensor Data (at KDD-10), Washington DC.

time,label,data1,data2,data3
1,Jogging,-0.6946377,12.680544,0.50395286
2,Jogging,5.012288,11.264028,0.95342433
3,Jogging,4.903325,10.882658,-0.08172209

画像データ

  • Caltech256
  • 256クラス(ラベル)の画像データ
  • 画像からそのクラス(ラベル)を予測する

LIME、SHAPによる説明実施の方法

LIMEとSHAPによる説明を行うために用いるパッケージはGitHub上で公開されているため、こちらを利用します。
なお、どちらも商用利用可能なライセンス(それぞれ、BSD 2-Clause "Simplified" LicenseMIT License)で公開されています。
また、どちらもPythonのパッケージ管理用の仕組みであるPyPIを通してインストール可能です。

なお、以下の説明では、定義されている引数の一部のみについて記載をしています。
詳しく知りたい方は、公開されている各GitHubを参照ください。

LIMEの利用方法 - 共通の流れ

以下の3処理を実行することで説明を行うことができます。

  1. 説明用のクラスオブジェクトを作成
  2. クラスオブジェクトに説明したいデータを投入し、説明に用いるためのモデルを作成
  3. 結果を図や数値として出力

LIMEの利用方法 - 表形式データ(数値データ)

  1. 以下の引数を設定し、説明用のクラスオブジェクトLimeTabularExplainerを作成
    • training_data:説明用のモデルを作成するために利用するデータセット
    • feature_names:特徴量の名称
    • class_names:クラス(ラベル)名のリスト
    • random_state:乱数のシード値
      • ただし、以下のメソッドでランダムに値を生成する処理があるため、これを設定しても再現性は担保されない
  2. 説明用のクラスオブジェクトに対して、以下の引数を指定しexplain_instanceメソッドを実行
    • data_row:説明したいデータ
    • predict_fn:以下の引数と返却値を持つ予測用の関数
      • 引数:複数データの配列
      • 返却値:引数として与えられた複数データに対する予測結果を、各クラス(ラベル)のたしからしさとしてあらわした配列
    • labels:説明したいクラス(ラベル)のインデックス
    • num_features:表示する特徴量の最大値
  3. explain_instanceメソッドの返却値に対して、以下の通り実行することで図や数値として説明を出力
    • 図:show_in_notebookメソッド(Jupiter Notebook上に出力)
    • 数値:as_mapメソッド

LIMEの利用方法 - 画像データ

  1. 以下の引数を指定した説明用のクラスオブジェクトLimeImageExplainerを作成
    • random_state:乱数のシード値(固定することで再現性を担保)
  2. 説明用のクラスオブジェクトに対して、以下の引数を指定しexplain_instanceメソッドを実行
    • image:説明したいデータ
    • predict_fn:以下の引数と返却値を持つ予測用の関数
      • 引数:複数データの配列
      • 返却値:引数として与えられた複数データに対する予測結果を、各クラス(ラベル)のたしからしさとしてあらわした配列
    • labels:説明したいクラス(ラベル)のインデックス
    • num_samples:サンプルとして利用するデータ数(大きすぎると時間がかかり、少なすぎるとエラーが発生)
  3. explain_instanceメソッドの返却値に対して、以下の通り実行することで図として説明を出力
    • 図:get_image_and_maskメソッドで返却される画像を加工

SHAPの利用方法 - 共通の流れ

以下の3処理を実行することで説明を行うことができます。

  1. 説明用のクラスオブジェクトを作成
  2. 説明に用いるSHAP値を計算(説明に用いるためのモデルを作成)
  3. 結果を図や数値として出力
    • 図を出力するには、shap.initjs()による初期化が必要

SHAPの利用方法 - 表形式データ(数値データ)

今回はランダムフォレスト(決定木のアンサンブルモデル)を利用しているため、説明用のクラスとしてTreeExplainerを利用します。

  1. 以下の引数を指定した説明用のクラスオブジェクトTreeExplainerを作成
    • model:学習済みモデル
  2. 説明用のクラスオブジェクトに対して、以下の引数を指定したshap_valuesメソッドを実行することでSHAP値を計算
    • X:説明したいデータを含むデータ群
      • 取得したSHAP値は、(クラス(ラベル)数, Xのデータ数, データの特徴量の数)の大きさを持つ3次元配列
  3. 計算したSHAP値を利用して、以下の通り説明を取得
    • 図:以下の引数を指定してshap.force_plotメソッドを実行(Jupiter Notebook上に出力)
      • base_value:説明用クラスオブジェクトのexpected_value属性の、説明したいクラス(ラベル)に対する値
      • shap_values:SHAP値の、説明したいクラス(ラベル)及び説明したいデータに対する値
      • features:説明したいデータ
    • 数値:計算したSHAP値自体を利用

SHAPの利用方法 - 画像データ

今回は深層学習のモデルを利用しているため、説明用のクラスとしてGradientExplainerを利用します。

  1. 以下の引数を指定した説明用のクラスオブジェクトGradientExplainerを作成
    • model:学習済みモデル、または、(学習済みモデル, そのレイヤー)のタプル
    • data:説明したいデータを含むデータ群
  2. 説明用のクラスオブジェクトに対して、以下の引数を指定したshap_valuesメソッドを実行することでSHAP値を計算
    • X:説明したいデータを含むデータ群
    • ranked_outputs:上位何ラベルまで表示するかの数値
    • output_rank_order:model出力値の評価方法(“max”, “min”, “max_abs”, “custom”で指定)
    • rseed:乱数のシード値(固定することで再現性を担保)
      • 取得したSHAP値は、(クラス(ラベル)数, Xのデータ数, データの特徴量の数)の大きさを持つ3次元配列
  3. 計算したSHAP値を利用して、以下の通り説明を取得
    • 図:以下の引数を指定しshap.force_plotメソッドを実行(Jupiter Notebook上に出力)
    • ただし、説明したいデータはHWC形式(Pillow形式)である必要がある(画像を表す配列の順番としてHWCとCHWが存在。異なる場合は別途実装必要)
      • shap_values:説明用クラスオブジェクトのexpected_value属性の、説明したいクラス(ラベル)に対する値
      • pixel_values:説明したいデータ
      • labels:説明したいクラス(ラベル)の名前

LIME、SHAPによる説明結果

表形式データ(数値データ)

表形式データ(数値データ)に対してLIMEとSHAPで説明を実施すると、それぞれ以下のように図が出力されます。
説明対象としては、正解も予測もjoggingとなるデータを使用しています。

  • LIMEの場合
  • SHAPの場合

寄与度の数値だけを取り出すと以下のようになります。
ここでは、寄与度の上位6特徴量(変数)のみ取り出しています。

  • LIMEの場合
特徴量(変数) 寄与度
y13 0.41 0.068
y6 0.467 0.066
y18 0.35 0.052
y9 0.569 0.049
y15 0.5 0.044
y11 0.435 0.039
  • SHAPの場合
特徴量(変数) 寄与度
y11 0.435 0.166
y6 0.467 0.133
y18 0.35 0.11
y13 0.41 0.085
y20 0.512 0.037
y17 0.961 0.023

どちらも全体的にY軸方向(足を上下に動かす方向)の特徴量がよく利いていて、走ったり歩いたりした際にこの方向の動きがあるのでこの寄与度はおかしくないことがわかると思います。

画像データ

画像データに対してLIMEとSHAPで説明を実施すると、それぞれ以下のような図が出力されます。
LIMEでは黄色枠で囲まれた黄緑色の部分が予測に効いていて、SHAPでは「250.zebra」と記載された下の画像で赤くなっている部分が予測に効いていることを表しています。

  • LIMEの場合
  • SHAPの場合

画像の場合は、LIMEの方が良さそうなことがわかると思います。
LIMEでは、シマウマの顔の部分や体の縞の部分に着目している結果となっていますが、SHAPではシマウマの部分だけでなく背景の原っぱにも着目していて全体的にどこを見ているのかがわからない結果となっています。

利用したコード

実際に利用したコードは以下のようになります。
なお、Jupyter Notebook上での利用を前提としています。
また、AnacondaでPython 3.7.3をインストールしたうえで、必要なパッケージ(lime, shap, PyTorch, tqdm)をインストールした環境を実行環境としています。
コードは、それぞれ以下の流れで記載しています。

  • import文
  • データ読み込み
  • データ前処理
  • 学習
  • 評価
  • LIMEによる説明
  • SHAPによる説明

表形式データ(数値データ)の場合

import numpy as np
import pandas as pd
import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import matplotlib
import matplotlib.pyplot as plt
import lime
import shap
%matplotlib inline


# CSVファイル読み込み
train_path_data = './wave_train.csv'
val_path_data = './wave_test.csv'
train_data = pd.read_csv(train_path_data)
val_data = pd.read_csv(val_path_data)

# ラベル定義
x_label = ['data1', 'data2', 'data3']
y_label = 'label'

# 学習時の1データ当たりの行数定義
nperseg = 24
noverlap = 12


# 関数定義
# 複数行を1データにまとめる関数
def pre_process(data, nperseg, noverlap, x_labels, y_label):
    # 返却値初期化
    return_data = pd.DataFrame()
    # 複数行を1データにまとめる
    seg_data = [
        data.iloc[i:i+nperseg, :] for i in range(
            0, len(data)-1-nperseg+1, nperseg-noverlap)]

    # 1データ間に複数の分類(ラベル)が存在していないかのチェック
    for seg in seg_data:
        # check the same label
        if len(seg[y_label].unique()) == 1:
            # 目的変数(正解ラベル)取得
            tmp_list = [seg.iloc[0][y_label]]

            # 説明変数取得
            for i in range(nperseg):
                for x_label in x_labels:
                    tmp_list.append(seg.iloc[i][x_label])

            tmp_data = pd.DataFrame([tmp_list])

            return_data = pd.concat([return_data, tmp_data])

    return return_data


# 正規化を行う関数
def min_max_normalization(data, labels, max_list=None, min_list=None):
    if max_list is None:
        max_list = [data[label].max() for label in labels]

    if min_list is None:
        min_list = [data[label].min() for label in labels]

    for label_num, label in enumerate(labels):
        data[label] = (data[label] - min_list[label_num]) / \
            (max_list[label_num] - min_list[label_num])

    return data, max_list, min_list


# データ前処理:時間かかる
# 0-1正規化
train_data, max_list, min_list = min_max_normalization(
    train_data, x_label)
val_data, _, _ = min_max_normalization(
    val_data, x_label, max_list=max_list, min_list=min_list)

# 複数行を1データにまとめる
train_data = pre_process(train_data, nperseg, noverlap, x_label, y_label)
val_data = pre_process(val_data, nperseg, noverlap, x_label, y_label)
train_data = train_data.reset_index(drop=True)
val_data = val_data.reset_index(drop=True)

# 目的変数(正解ラベル)と説明変数への分離
train_x = train_data.iloc[:, 1:]
train_y = train_data.iloc[:, 0]
val_x = val_data.iloc[:, 1:]
val_y = val_data.iloc[:, 0]

# ヘッダー修正(x1, y1, z1, x2,...)
header_list = []
for i in range(1, nperseg * 3 + 1):
    if i % 3 == 1:
        header = 'x' + str(i // 3 + 1)
    elif i % 3 == 2:
        header = 'y' + str(i // 3 + 1)
    else:
        header = 'z' + str(i // 3)

    header_list.append(header)

train_x.columns = header_list
val_x.columns = header_list

# ランダムフォレストでの学習
model_rf = RandomForestClassifier()
model_rf.fit(train_x, train_y)

# 作成したモデルの評価
pred_y = model_rf.predict(val_x)
acc = accuracy_score(val_y, pred_y)
print(acc)

# LIMEによる説明
# クラスオブジェクト初期化
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
    np.array(val_x),
    feature_names=header_list,
    class_names=model_rf.classes_,
    random_state=0
)

# 1つのデータに対して説明を実施
index_lime = 130
pred_label = model_rf.predict(val_x.iloc[index_lime:index_lime+1])[0]
pred_label_no = list(model_rf.classes_).index(pred_label)
exp = lime_explainer.explain_instance(
    np.array(val_x.iloc[index_lime, :]),
    model_rf.predict_proba,
    num_features=10,
    labels=[pred_label_no])

# LIMEの標準機能によるJupyter-Notebook上への表示
exp.show_in_notebook()
# LIMEの標準機能によるHTML形式での保存
exp.save_to_file('./lime_' + str(index_lime) + '.html')
# 説明結果の数値情報出力
print(exp.as_map())

# SHAPによる説明
# クラスオブジェクト初期化
shap.initjs()
shap_explainer = shap.TreeExplainer(model_rf)

# データが多いと計算時間がかかるので、データを削減
_, val_x_10, _, val_y_10 = sklearn.model_selection.train_test_split(
    val_x, val_y, test_size=0.10, random_state=0)
# SHAP値の計算
shap_values = shap_explainer.shap_values(val_x)
# shap_values_10 = shap_explainer.shap_values(val_x_10)

# SHAPの標準機能によるJupyter-Notebook上への表示
index_shap = 130
pred_label = model_rf.predict(val_x.iloc[index_shap:index_shap+1])[0]
pred_label_no = list(model_rf.classes_).index(pred_label)

# SHAPの標準機能によるJupyter-Notebook上への表示
shap.force_plot(
    shap_explainer.expected_value[pred_label_no],
    shap_values[pred_label_no][index_shap, :],
    features=val_x.iloc[index_shap, :])

# 説明結果の数値情報出力
print(shap_values[pred_label_no][index_shap, :])

画像データの場合

ただし、SHAPの場合は標準機能による出力を行うと画像の配列順番(HWC、CWH)による問題が発生するため、GitHub上のコードを修正してHWCの形式に変換することで画像を出力できるようにするための処理を加えています。

import os
import numpy as np
from PIL import Image
import torch
from torch import nn
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm_notebook
import datetime
from skimage.segmentation import mark_boundaries
import lime.lime_image
import shap
from shap.plots import colors
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline


# ハイパパラメータ設定
batch_size = 8
lr = .0001  # for sgd
epochs = 150
log_interval = 1000
num_classes = 256
# GPU利用するなら設定
device = torch.device('cuda')

# 事前学習モデルの利用
model = torchvision.models.resnet50(pretrained=True)

# fine-tuningのために全結合層以外の層の重みを固定
for name, param in model.named_parameters():
    if not name.startswith('fc'):
        param.detach_()

# 出力層をクラス数に合わせる
model.fc = nn.Linear(model.fc.in_features, num_classes)

# GPU利用するなら設定
model.cuda()
model.to(device)

# optimizer設定
optimizer = optim.SGD(model.parameters(), lr=lr)
# loss関数設定
loss_func = nn.CrossEntropyLoss()

# 学習用データセット読み込み
train_dataset = datasets.ImageFolder(
    './train',
    # 画像前処理設定
    transform=transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor()])
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)

# 学習実行
model.train()
for epoch in tqdm_notebook(range(1, epochs + 1)):
    for batch_idx, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

# 学習済みモデル保存
today = datetime.date.today()
path_pytorch_model = 'torch_model/resnet50_' + today.strftime('%Y%m%d') + '.pth'
torch.save(model.state_dict(), path_pytorch_model)

# 作成したモデルの評価
val_dataset = datasets.ImageFolder(
    './validation',
    transform=transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor()])
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False
)

# ラベル一覧
list_labels = val_dataset.classes

count_all = 0
count_acc = 0

model.eval()
with torch.no_grad():
    for images_val, target_val in val_loader:
        count_all += 1
        images_val = images_val.to(device)
        target_val = target_val.to(device)
        output_val = model(images_val)
        if output_val[0].argmax().data == target_val.data:
            count_acc += 1

print('Acc: ' + str(count_acc / count_all))

# 説明用画像
path_zebra = './validation/250.zebra'
path_zebra_data = sorted([os.path.join(
    path_zebra, i) for i in os.listdir(path_zebra)])


# 関数定義
# LIME用の予測(推論)関数
def predict_for_lime(images):
    ret_predict_scores = []

    with torch.no_grad():

        for image in images:

            # resize iage
            to_pil_image = transforms.ToPILImage()
            pil_image = to_pil_image(image)
            image = torchvision.transforms.functional.resize(
                pil_image, (224, 224))

            # to Tensor
            to_tensor = torchvision.transforms.ToTensor()
            tensor_image = to_tensor(np.asarray(image))
            tensor_image = tensor_image[None, :]

            tensor_image = tensor_image.to(device)

            # predict
            output_val = model(tensor_image)
            output_val = output_val.cpu()

            ret_predict_scores.append(output_val.numpy()[0])

    return np.array(ret_predict_scores)


# LIMEによる説明
# クラスオブジェクト初期化
explainer = lime.lime_image.LimeImageExplainer(random_state=0)

index_lime = 6

# 画像の読み込み
with open(path_zebra_data[index_lime], 'rb') as f:
    img = Image.open(f)
    img_pil = img.convert('RGB')
    img_pil = np.asarray(img_pil)

explanation = explainer.explain_instance(
    img_pil,
    predict_for_lime,
    hide_color=0.9,
    num_features=100,
    num_samples=1000)  # number of images that will be sent to classification function

# 説明用画像出力
temp, mask = explanation.get_image_and_mask(
    explanation.top_labels[0], positive_only=False, num_features=10, hide_rest=False)
img_boundry = mark_boundaries(temp / 255.0, mask)
plt.imshow(img_boundry)

# SHAPによる説明
# SHAPによる説明に利用する画像の生成
# 画像の変換に用いるメソッド定義
to_pil_image = transforms.ToPILImage()
to_tensor = transforms.ToTensor()

# 画像用変数の初期化
tensor_images = torch.Tensor()

for path_data in path_zebra_data:
    with open(path_data, 'rb') as f:
        img = Image.open(f)
        img_pil = img.convert('RGB')
        img_pil = np.asarray(img_pil)

    tmp_pil_image = to_pil_image(img_pil)
    tmp_resize_image = transforms.functional.resize(
        tmp_pil_image, (224, 224))

    # to Tensor
    tmp_tensor_image = to_tensor(np.asarray(tmp_resize_image))
    tensor_images = torch.cat((tensor_images, tmp_tensor_image[None, :]))

index_shap = 6

# GPU利用するなら設定
tensor_images = tensor_images.to(device)

# SHAPによる説明1
# クラスオブジェクト初期化
grad_explainer1 = shap.GradientExplainer(model, tensor_images)
shap_values1, indexes1 = grad_explainer1.shap_values1(
    tensor_images[index_shap:index_shap+1], ranked_outputs=2, rseed=0, output_rank_order='max')

# 予測ラベル名取得
index_names1 = np.vectorize(lambda x: list_labels[x])(indexes1.cpu())
to_explain = np.array(tensor_images[index_shap:index_shap+1].cpu())

# 画像描画
labels = index_names1
width = 20
aspect = 0.2
hspace = 0.2
labelpad = None
show = True

multi_output = True
if type(shap_values1) != list:
    multi_output = False
    shap_values1 = [shap_values1]

# make sure labels
if labels is not None:
    assert labels.shape[0] == shap_values1[0].shape[0], "Labels must have same row count as shap_values1 arrays!"
    if multi_output:
        assert labels.shape[1] == len(shap_values1), "Labels must have a column for each output in shap_values1!"
    else:
        assert len(labels.shape) == 1, "Labels must be a vector for single output shap_values1."

label_kwargs = {} if labelpad is None else {'pad': labelpad}

# plot our explanations
x = to_explain.transpose(0, 2, 3, 1)
fig_size = np.array([3 * (len(shap_values1) + 1), 2.5 * (x.shape[0] + 1)])

if fig_size[0] > width:
    fig_size *= width / fig_size[0]

fig, axes = plt.subplots(
    nrows=x.shape[0], ncols=len(shap_values1) + 1, figsize=fig_size)

if len(axes.shape) == 1:
    axes = axes.reshape(1, axes.size)

for row in range(x.shape[0]):
    x_curr = x[row].copy()

    # make sure
    if len(x_curr.shape) == 3 and x_curr.shape[2] == 1:
        x_curr = x_curr.reshape(x_curr.shape[:2])
    if x_curr.max() > 1:
        x_curr /= 255.

    # get a grayscale version of the image
    if len(x_curr.shape) == 3 and x_curr.shape[2] == 3:
        x_curr_gray = (0.2989 * x_curr[:,:,0] + 0.5870 * x_curr[:,:,1] + 0.1140 * x_curr[:,:,2]) # rgb to gray
    else:
        x_curr_gray = x_curr

    axes[row, 0].imshow(x_curr, cmap=plt.get_cmap('gray'))
    axes[row, 0].axis('off')

    if len(shap_values1[0][row].shape) == 2:
        abs_vals = np.stack([np.abs(shap_values1[i]) for i in range(len(shap_values1))], 0).flatten()
    else:
        abs_vals = np.stack([np.abs(shap_values1[i].sum(-1)) for i in range(len(shap_values1))], 0).flatten()

    max_val = np.nanpercentile(abs_vals, 99.9)

    for i in range(len(shap_values1)):
        if labels is not None:
            axes[row,i+1].set_title(labels[row, i], **label_kwargs)

        shap_value = shap_values1[i]
        shap_value = shap_value.transpose(0, 2, 3, 1)

        # sv = shap_values1[i][row] if len(shap_values1[i][row].shape) == 2 else shap_values1[i][row].sum(-1)
        sv = shap_value[row] if len(shap_value[row].shape) == 2 else shap_value[row].sum(-1)
        axes[row, i+1].imshow(x_curr_gray, cmap=plt.get_cmap('gray'), alpha=0.15, extent=(-1, sv.shape[1], sv.shape[0], -1))
        im = axes[row,i+1].imshow(sv, cmap=colors.red_transparent_blue, vmin=-max_val, vmax=max_val)
        axes[row, i+1].axis('off')

if hspace == 'auto':
    fig.tight_layout()
else:
    fig.subplots_adjust(hspace=hspace)

cb = fig.colorbar(im, ax=np.ravel(axes).tolist(), label="SHAP value", orientation="horizontal", aspect=fig_size[0]/aspect)

cb.outline.set_visible(False)
if show:
    plt.show()

fig.savefig('shap_gradient1_シマウマ.png')

# SHAPによる説明2
grad_explainer2 = shap.GradientExplainer(
    (model, model.layer2), tensor_images, local_smoothing=0)
shap_values2, indexes2 = grad_explainer2.shap_values(
    tensor_images[index_shap:index_shap+1], ranked_outputs=2, rseed=0, output_rank_order='max')

# 予測ラベル名取得
index_names2 = np.vectorize(lambda x: list_labels[x])(indexes2.cpu())
to_explain = np.array(tensor_images[index_shap:index_shap+1].cpu())

# 画像描画
labels = index_names2
width = 20
aspect = 0.2
hspace = 0.2
labelpad = None
show = True

multi_output = True
if type(shap_values2) != list:
    multi_output = False
    shap_values2 = [shap_values2]

# make sure labels
if labels is not None:
    assert labels.shape[0] == shap_values2[0].shape[0], "Labels must have same row count as shap_values arrays!"
    if multi_output:
        assert labels.shape[1] == len(shap_values2), "Labels must have a column for each output in shap_values!"
    else:
        assert len(labels.shape) == 1, "Labels must be a vector for single output shap_values."

label_kwargs = {} if labelpad is None else {'pad': labelpad}

# plot our explanations
x = to_explain.transpose(0, 2, 3, 1)
fig_size = np.array([3 * (len(shap_values2) + 1), 2.5 * (x.shape[0] + 1)])

if fig_size[0] > width:
    fig_size *= width / fig_size[0]

fig, axes = plt.subplots(
    nrows=x.shape[0], ncols=len(shap_values2) + 1, figsize=fig_size)

if len(axes.shape) == 1:
    axes = axes.reshape(1, axes.size)

for row in range(x.shape[0]):
    x_curr = x[row].copy()

    # make sure
    if len(x_curr.shape) == 3 and x_curr.shape[2] == 1:
        x_curr = x_curr.reshape(x_curr.shape[:2])
    if x_curr.max() > 1:
        x_curr /= 255.

    # get a grayscale version of the image
    if len(x_curr.shape) == 3 and x_curr.shape[2] == 3:
        x_curr_gray = (0.2989 * x_curr[:,:,0] + 0.5870 * x_curr[:,:,1] + 0.1140 * x_curr[:,:,2]) # rgb to gray
    else:
        x_curr_gray = x_curr

    axes[row, 0].imshow(x_curr, cmap=plt.get_cmap('gray'))
    axes[row, 0].axis('off')

    if len(shap_values2[0][row].shape) == 2:
        abs_vals = np.stack([np.abs(shap_values2[i]) for i in range(len(shap_values2))], 0).flatten()
    else:
        abs_vals = np.stack([np.abs(shap_values2[i].sum(-1)) for i in range(len(shap_values2))], 0).flatten()

    max_val = np.nanpercentile(abs_vals, 99.9)

    for i in range(len(shap_values2)):
        if labels is not None:
            axes[row,i+1].set_title(labels[row, i], **label_kwargs)

        shap_value = shap_values2[i]
        shap_value = shap_value.transpose(0, 2, 3, 1)

        # sv = shap_values[i][row] if len(shap_values[i][row].shape) == 2 else shap_values[i][row].sum(-1)
        sv = shap_value[row] if len(shap_value[row].shape) == 2 else shap_value[row].sum(-1)
        axes[row, i+1].imshow(x_curr_gray, cmap=plt.get_cmap('gray'), alpha=0.15, extent=(-1, sv.shape[1], sv.shape[0], -1))
        im = axes[row,i+1].imshow(sv, cmap=colors.red_transparent_blue, vmin=-max_val, vmax=max_val)
        axes[row, i+1].axis('off')

if hspace == 'auto':
    fig.tight_layout()
else:
    fig.subplots_adjust(hspace=hspace)

cb = fig.colorbar(im, ax=np.ravel(axes).tolist(), label="SHAP value", orientation="horizontal", aspect=fig_size[0]/aspect)

cb.outline.set_visible(False)
if show:
    plt.show()

fig.savefig('shap_gradient2_シマウマ.png')

まとめ

このように、GitHub上に公開されているパッケージを利用することで「説明」を行うことができます。
ただし、今回紹介したLIMEやSHAP以外にも、「説明可能AI」の手法は数多く存在しています。
どのような「説明」が必要なのかをきちんと判断したうえで、適切な手法を選ぶことが必要になると思います。

関連記事

  1. DeepstreamでストリームAI処理する方法について

  2. 外国人に「SCSKってAIの会社?」と言わせた話(国際情報オリンピック…

  3. 今、もっともアツい決定木「XGBoost」

  4. 「エッジAI」への期待と課題

  5. [初心者向け]深層学習の勉強法

  6. RGB画像からの深度推定手法の精度比較

  7. 深層強化学習の学習経過を見てみた

  8. お手軽な機械学習プラットフォーム、H2O

PAGE TOP