Fire Engine

消防士→ITエンジニア→研究者

機械学習の予測を解釈するKernel SHAPの高速性と拡張性の向上を目指したライブラリを開発した

先日,協力ゲーム理論のシャープレイ値に基づき機械学習モデルの予測を解釈するKernel SHAPという手法の理論と既存のライブラリの実装についてのブログを書いた.

blog.tsurubee.tech

既存のSHAPライブラリであるslundberg/shap(以下,単にSHAPライブラリと呼ぶ)は,SHAPの提案論文*1のファーストオーサーにより開発されており,多くのSHAPのアルゴリズムの実装や可視化の機能が非常に充実している素晴らしいライブラリである.

しかし,私が自身の研究の中でSHAPライブラリの中のKernel SHAPを使っている際に,計算速度と拡張のしやすさの観点で改善したいポイントがいくつか出てきた.今回は,まだ絶賛開発中であるが,Kernel SHAPの高速性と拡張性の向上を目指したShapPackというライブラリのプロトタイプが完成したので,それについて紹介する.

目次

ShapPackの概要

github.com

ShapPackではSHAPライブラリのKernel SHAPと比較して,以下の三つの新たな機能を実装している.

  1. マルチプロセスで並列処理できる機能
  2. 特性関数を独自に実装して組み込める機能
  3. SHAP値を計算しない特徴量を指定できる機能

1と3は高速性に関わる機能で,2は拡張性に関わる機能である.
わかりやすい結果として,1の並列処理の機能により8コアのサーバでSHAP値の計算を実行すると,以下のように計算時間が短縮できる.

SHAPライブラリ ShapPack
5.54 s 0.684 s

上の結果の実験条件や,他の二つの機能の詳細については後述する.
下図は,SHAPライブラリとShapPackの実装の違いの概要をまとめたものである.

f:id:hirotsuru314:20210720165023p:plain
図1 SHAPライブラリとShapPackの実装の違いの概要図

ShapPackでは,サンプリングされた部分集合に対する特性関数の計算がボトルネックになることに着目し,部分集合を分割してマルチプロセスで並列処理することで計算速度を改善している.また,SHAPライブラリでは特性関数としてライブラリに実装済みのものしか利用できないことに着目し,利用者が外から独自に実装した特性関数を組み込める仕組みをとることで拡張性を改善している.

ShapPackの使い方

データとモデルの準備

データとして,scikit-learnに付属しているボストンの住宅価格のデータセットを用いた.データは,特徴量の尺度を揃えるために標準化している.
モデルは,動径基底関数(RBF)カーネルのサポートベクトル回帰(SVR)を用いた.

import numpy as np
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

# Fix seed to reproduce results
SEED = 123
np.random.seed(SEED)

# Prepare dataset
boston = load_boston()
X_train, X_test, y_train, y_test = train_test_split(boston["data"], boston["target"], test_size=0.2, random_state=SEED)
scaler = StandardScaler()
X_train_std = scaler.fit_transform(X_train)
X_test_std = scaler.transform(X_test)

# Prepare model
model = SVR(kernel="rbf")
model.fit(X_train_std, y_train)

SHAP値の計算

ShapPackを用いてSHAP値を計算するためのコードは以下の通りである.

import shappack
i = 2
explainer = shappack.KernelExplainer(model.predict, X_train_std[:100])
shap_value = explainer.shap_values(X_test_std[i])

ShapPackでは,既に広く使われているSHAPライブラリとなるべく同じ使い方ができるようにしており,実際に上のコードは「shappack」を「shap」と置き換えるとそのまま実行できる.
SHAPライブラリとの重要な違いとして,ShapPackのshap_values関数は,前述の三つの機能に紐づく新たな三つの引数(n_workerscharacteristic_funcskip_features)を渡せるようになっており,これらの引数が高速性や拡張性に寄与する.
以下,それぞれの引数の使い方を説明する.

1. n_workers:プロセス数を指定
shap_value = explainer.shap_values(X_test_std[i], n_workers=-1)

n_workersには,SHAP値の計算に使用するプロセス数を指定する.指定しない場合は,n_workers=1となる. n_workers=-1は,実行するサーバのコア数をn_workersに設定することを意味する. この引数により,プログラムをマルチコアのサーバーで実行する場合は,計算時間の短縮が期待できる.

2. characteristic_func:特性関数を指定
def my_characteristic_func(instance, subsets, model, data):
    # own implemented characteristic function

shap_value = explainer.shap_values(X_test_std[i], characteristic_func=my_characteristic_func)

characteristic_funcには,独自に実装した特性関数を渡すことができる.指定しない場合は,SHAPライブラリのKernel SHAPと同等のものが実行される.

3. skip_features:計算をスキップする特徴量を指定
explainer = shappack.KernelExplainer(model.predict, X_train_std[:100], feature_names=boston.feature_names)
shap_value = explainer.shap_values(X_test_std[i], skip_features=["PTRATIO", "TAX"])

skip_featureには,SHAP値の計算をスキップしたい特徴量を指定する.特徴量は,特徴量名またはインデックス番号で指定できる. なお,skip_featuresを特徴量名で指定する場合は,KernelExplainerクラスのfeature_names引数に特徴量名のリストを渡す必要がある.

可視化

現時点では,ShapPackは独自に可視化の仕組みを持っていないため,SHAP値の可視化のためにはSHAPライブラリを使う必要がある.

import shap
shap.initjs()
shap.force_plot(explainer.base_val[0], shap_value, X_test[i], boston.feature_names)

f:id:hirotsuru314:20210720171941p:plain
図2 Kernel SHAPに計算された予測値への貢献度の可視化

ちなみに,上の結果は,あるデータの住宅価格を平均より高く予測しており,その予測に対してRM(1戸当たりの平均部屋数)やLSTAT(低所得者人口の割合)が強く貢献していることを示している.

ShapPackの詳細

ここでは,新たに追加した三つの機能についての詳細をそれぞれ説明する.

1. マルチプロセスで並列処理できる機能

Kernel SHAPを利用する際の大きな課題の一つとして,計算時間が挙げられる.ShapPackではSHAPライブラリのボトルネックとなる箇所をマルチプロセスで並列処理できるようにすることで,計算速度の向上を目指している.
ここでは,まずSHAPライブラリとShapPackの計算時間の比較評価の結果を示す. 次に,SHAPライブラリのKernel SHAPのボトルネックと,ボトルネックを解消するための並列計算の実装について述べる.

計算時間の評価

下のコードのようにJupyter NotebookでSHAP値の計算時間の10回平均値を測定した.

%%timeit -r 10
shap_value = explainer.shap_values(X_test_std[i])

データやモデルは「ShapPackの使い方」で示した例と同じものを使用し,サーバはCPU8コア,メモリ24GB,バックグラウンドデータセットのサイズは100を採用した. 評価結果を下の表1に示す.

f:id:hirotsuru314:20210721092848p:plain:w500
表1 SHAP値の計算時間(単位は秒)

表1から,SHAPライブラリの計算時間が5.54秒であるのに対し,n_workers=8で設定したShapPackの計算時間は0.684秒であり,本実験条件では約8倍の計算時間の短縮が達成できた.これは,単純にSHAPライブラリだとマルチコアのサーバでも1コアしか計算に使わない一方で,ShapPackでマルチコアをフルに使って計算を実行するためである.
ShapPackでマルチコアを使わないn_workers=1でもSHAPライブラリよりも早い結果を示している.これはSHAPライブラリを参考にShapPackの実装を進める中でいくつかの細かい実装の工夫をしたためだと思われるが,その詳細はここでは割愛する.

SHAPライブラリのボトルネック

Jupyter Notebookを用いる場合,以下のように%%prunをセルの先頭に差し込むだけで比較的簡単にコードのプロファイリングができる.

%%prun
shap_value = explainer.shap_values(X_test_std[i])

SHAPライブラリのSHAP値の計算に対する出力結果を下に示す.

85988 function calls (82014 primitive calls) in 5.572 seconds

   Ordered by: internal time

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2    4.719    2.360    4.719    2.360 {sklearn.svm._libsvm.predict}
     1    0.731    0.731    5.455    5.455 _kernel.py:503(run)
  2074    0.022    0.000    0.024    0.000 _kernel.py:477(addsample)
     1    0.012    0.012    5.572    5.572 _kernel.py:204(explain)
    ・・・

この結果から,全体の計算時間(5.572秒)に対してSVMのpredict関数(4.719秒)が支配的であることがわかる.実際に該当する箇所は下のコードの最後のself.model.f(data) である.

# https://github.com/slundberg/shap/blob/master/shap/explainers/_kernel.py
def run(self):
    num_to_run = self.nsamplesAdded * self.N - self.nsamplesRun * self.N
    data = self.synth_data[self.nsamplesRun*self.N:self.nsamplesAdded*self.N,:]
    if self.keep_index:
        index = self.synth_data_index[self.nsamplesRun*self.N:self.nsamplesAdded*self.N]
        index = pd.DataFrame(index, columns=[self.data.index_name])
        data = pd.DataFrame(data, columns=self.data.group_names)
        data = pd.concat([index, data], axis=1).set_index(self.data.index_name)
        if self.keep_index_ordered:
            data = data.sort_index()
    modelOut = self.model.f(data) #Bottleneck!
    ・・・

なぜここの処理が時間がかかるかというと,モデルに渡ってくるdataのサイズが大きいからである.例えば,ボストンの住宅価格のデータセットの場合,特徴量数が13であり,部分集合のサンプリング数のライブラリの推奨値が「2*13+2048=2074」となる.また,特性関数の計算に用いるバックグラウンドデータセット(KernelExplainerクラスの引数の二つ目)のサイズを100とすると,207,400個のデータに対してモデルのpredict関数を実行しなければならない.この計算がKernel SHAPの計算時間のボトルネックになる.

並列計算の実装

サイズの大きなデータに対するモデルの予測を早く計算する方法として,ShapPackではデータを分割してマルチプロセスでモデルの予測を計算するように実装している(図1).もちろん,用いるモデル自体がマルチプロセスでの並列計算をサポートしていれば,そちらの機能を使えばよいが,全てのモデルがマルチプロセス処理に対応しているわけではないので,ShapPack側でマルチプロセス処理を実装し,引数で並列プロセス数を指定できるようにした.

2. 特性関数を独自に実装して組み込める機能

特性関数の実装例

以下にShapPackに組み込む特性関数を独自に実装した例を示す.

def my_characteristic_func(instance, subsets, model, data):
    n_subsets = subsets.shape[0]
    n_data = data.shape[0]
    synth_data = np.tile(data, (n_subsets, 1))
    for i, subset in enumerate(subsets):
        offset = i * n_data
        features_idx = np.where(subset == 1.0)[0]
        synth_data[offset : offset + n_data, features_idx] = instance[:, features_idx][0]
    model_preds = model(synth_data)
    ey = np.zeros(n_subsets)
    for i in range(n_subsets):
        ey[i] = np.min(model_preds[i * n_data : i * n_data + n_data])
    return ey

shap_value = explainer.shap_values(X_test_std[i], characteristic_func=my_characteristic_func)

この例は,SHAPライブラリのKernel SHAPの特性関数 E[f(x_S,X_{\bar{S}})]の期待値 Eを最小値計算 minに置き換えた例である. コードの詳細は述べないが,上のように十数行でKernel SHAPのオリジナルの特性関数を少し拡張したものを実装できる.
特性関数は「特徴量の不在をどのように再現するか」という問題が含まれており,SHAP値の計算の上では重要なポイントである.特性関数を独自に設計するというのは簡単ではないが,用途やデータに合わせてを特性関数を設計したいというケースもある(事例は後述).独自に特性関数を設計して適用したい場合,SHAPライブラリでは既存のソースコードに直接変更を加えなければならないが,ShapPackでは関数を実装して引数として渡すだけで,独自に設計した特性関数を適用できる.

特性関数を独自に設計した事例

特定の用途に適した特性関数を独自に設計している事例として,異常検知の解釈のために特性関数を設計した事例がある. 例えば,解釈したいデータ点(以下,インスタンス)に応じて適応的に参照値を選択するために,インスタンスと学習データとのinfluence weightingを用いている例がある*2.これは,ShapPackに組み込む特性関数の中で実装可能である.

また,インスタンスの近傍での異常スコアの最小化問題を解くことで特徴量の不在を近似する特性関数が提案されている*3. この手法では,インスタンスと部分集合の両方に適応的に参照値を決定することになるが,ShapPackの特性関数では,上の特性関数の実装例のinstancesubsetsでどちらも関数に渡たされる仕様になっているため,こちらもShapPackに組み込む特性関数の中で実装可能である.

3. SHAP値を計算しない特徴量を指定できる機能

SHAPライブラリでは,機械学習モデルの入力となる特徴量すべてに対してSHAP値を計算するが,ShapPackでは指定した特徴量のSHAP値の計算をスキップすることができる.これは例えば,人間が経験的に予測にほとんど影響を与えない特徴量を事前知識として持っている場合や,ある特徴量の影響がないと仮定した上での他の特徴量の貢献度を知りたい場合などを想定している.また,この機能により,指定する特徴量がある場合は計算が必要なSHAP値を減らすことができ,計算速度の向上にも貢献する.

実装としては,指定した特徴量は部分集合に必ず含まれているとし,その存在の有無の効果を測らないようにしている.例えば,以下のように特徴量「PTRATIO」と「TAX」を指定して実行した場合の可視化の結果は,図2の結果から「PTRATIO」と「TAX」が除外されていることがわかる.

explainer = shappack.KernelExplainer(model.predict, X_train_std[:100], feature_names=boston.feature_names)
skip_features=["PTRATIO", "TAX"]
shap_value = explainer.shap_values(X_test_std[i], skip_features=skip_features, n_workers=-1)
feature_names = np.delete(boston.feature_names, explainer.skip_idx)
x_test = np.delete(X_test[i], explainer.skip_idx)
shap.force_plot(explainer.base_val[0], shap_value, x_test, feature_names)

f:id:hirotsuru314:20210721091246p:plain
図3 PTRATIOとTAXの計算を除外した場合の貢献度の可視化

今後の展望

まず,現在のShapPackの大きな制約としてシャープレイ値の推定方法がKernel SHAPに限定されることが挙げられる. 一方,SHAPライブラリはKernel SHAP以外にDeep SHAPやTree SHAPなど豊富なアルゴリズムの実装がある. そのため,ShapPackでも高速性や拡張性を意識しつつKernel SHAP以外のアルゴリズムを実装していきたい.

次に,シャープレイ値を用いた機械学習の解釈の興味深い発展として,特徴量間の因果関係を導入したシャープレイ値がいくつか提案されている. 例えば,Shapley Flow*4や Causal Shapley Values*5などがある.これらの手法について調査してShapPackに実装できないか検討していきたい.

参考文献