2. Explaining models with SHAP

1. Explaining a logistic regression model with SHAP

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS  # the sklearn.feature_extraction.stop_words module was removed in newer scikit-learn versions
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import string, re
from time import time

Loading the data

df = pd.read_csv("https://query.data.world/s/yd24ckbjzyp7h6zp7bacafpv2lgfkh", encoding="ISO-8859-1")
display(df.shape)
display(df["relevance"].value_counts()/df.shape[0])
=================================================================================
(8000, 15)
no 0.821375
yes 0.177500
not sure 0.001125
Name: relevance, dtype: float64

Remove the 'not sure' rows:

df = df[df.relevance != 'not sure']
df.shape
======================================================================
(7991, 15)
df['relevance'] = df.relevance.map({'yes':1, 'no':0})
df = df[['text', 'relevance']]
df.shape
======================================================================
(7991, 2)

stopwords = ENGLISH_STOP_WORDS
def clean(doc):  # doc is a string of text
    doc = doc.replace("</br>", " ")  # this text contains a lot of <br/> tags
    # remove punctuation and digits
    doc = "".join([char for char in doc if char not in string.punctuation and not char.isdigit()])
    # remove stop words
    doc = " ".join([token for token in doc.split() if token not in stopwords])
    return doc

x = df.text
y = df.relevance
print(x.shape, y.shape)
=========================================================================
(7991,) (7991,)
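Note that clean() defined above is never applied; the raw text goes straight into the vectorizer below. If you do want the cleaning step, an optional one-line sketch (an addition, not in the original post) is:

x = df.text.apply(clean)  # optional: clean the documents before splitting and vectorizing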

x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1)
print(x_train.shape, y_train.shape)
=========================================================================
(5993,) (5993,)



from sklearn.feature_extraction.text import TfidfVectorizer

vect = TfidfVectorizer(min_df=5)
x_train_dtm = vect.fit_transform(x_train)
x_test_dtm = vect.transform(x_test)
model = LogisticRegression(class_weight='balanced')
model.fit(x_train_dtm, y_train)

y_pred_class = model.predict(x_test_dtm)
print("Accuracy: ", accuracy_score(y_test, y_pred_class))
==============================================================================
Accuracy: 0.7382382382382382

Explaining the model with SHAP

A fairly comprehensive example can be found in the article 模型解释–SHAP Value的简单介绍. The usage steps are:

  1. Instantiate a linear explainer with shap.LinearExplainer()
    • Different kinds of models use different explainers (a short hedged sketch follows this list):
      • TreeExplainer: supports XGBoost, LightGBM, CatBoost and scikit-learn models via Tree SHAP.
      • DeepExplainer (Deep SHAP): supports TensorFlow and Keras models using DeepLIFT and Shapley values.
      • GradientExplainer: supports TensorFlow and Keras models.
      • KernelExplainer (Kernel SHAP): applies to any model by combining LIME and Shapley values.
  2. Compute SHAP values for the data to be explained
  3. Visualize the results
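To illustrate the explainer/model pairing above, here is a minimal sketch (an added example, not from the original post) of TreeExplainer applied to a scikit-learn tree ensemble trained on the same document-term matrix; the RandomForestClassifier and the names rf, tree_explainer, rf_shap_values are assumptions made for this example:

from sklearn.ensemble import RandomForestClassifier
import shap

rf = RandomForestClassifier(n_estimators=100)                         # 1. a tree-ensemble model ...
rf.fit(x_train_dtm, y_train)
tree_explainer = shap.TreeExplainer(rf)                               # ... with its matching explainer
rf_shap_values = tree_explainer.shap_values(x_test_dtm.toarray())     # 2. explain the data
shap.summary_plot(rf_shap_values, x_test_dtm.toarray(),
                  feature_names=vect.get_feature_names())             # 3. visualize the result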
import shap
explainer = shap.LinearExplainer(model, x_train_dtm, feature_perturbation="interventional")
shap_values = explainer.shap_values(x_test_dtm)
x_test_array = x_test_dtm.toarray()

from pprint import pprint
pprint(df['text'][0])

plt.figure(dpi=120)
shap.initjs()
shap.summary_plot(shap_values, x_test_array, feature_names=vect.get_feature_names())

[Figure: SHAP summary plot for the logistic regression model]

Notes on the figure above:

  1. Feature importance: features are ordered from most to least important, top to bottom; here "economy" is relatively important for the overall prediction (a bar-chart variant that shows just this ranking is sketched after this list).
  2. The horizontal axis shows each feature's impact (SHAP value) on the corresponding prediction.
  3. Feature value: red means the observed feature value is high, blue means it is low.
  4. Direction of association: e.g. how much "dollar" contributes to deciding whether an article is related to the US economy.
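If only the importance ranking from point 1 is needed, a hedged variant (an added example, not in the original post) is the bar form of the summary plot, which shows the mean absolute SHAP value per feature:

shap.summary_plot(shap_values, x_test_array,
                  feature_names=vect.get_feature_names(),
                  plot_type="bar")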
shap.initjs()
shap.force_plot(
    explainer.expected_value, shap_values[0, :], x_test_array[0, :],
    feature_names=vect.get_feature_names()
)

[Figure: SHAP force plot for the first test sample]

Notes on the figure above:

  1. The original paper describes base_value as "the value that would be predicted if we did not know any features for the current output", i.e. the expected model output over the background data (a quick numerical check is sketched after this list).
  2. Red and blue: feature values that push the prediction higher are shown in red, those that push it lower in blue.
  3. "economy" has a positive influence on whether this article is related to the US economy, pushing the prediction to the right.
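As a quick numerical check of point 1 (an added sketch, not from the original post): for this linear model with interventional feature perturbation, base_value plus the sum of one sample's SHAP values should reproduce the model's raw margin (decision_function output) for that sample:

# base value + total SHAP contribution of sample 0 ≈ raw margin of sample 0
reconstructed = explainer.expected_value + shap_values[0, :].sum()
margin = model.decision_function(x_test_dtm[0])[0]
print(reconstructed, margin)  # the two numbers should agree up to numerical error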

2. Explaining an LSTM with SHAP

Note: TensorFlow 1.15.2 is used here; 1.14 also works.
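If you run on TensorFlow 2.x instead of the 1.x versions mentioned above, shap's DeepExplainer typically needs TF1-style graph mode; a commonly used workaround (an assumption added here, not needed for the TF 1.15/1.14 setup of this post) is:

import tensorflow as tf
tf.compat.v1.disable_v2_behavior()  # only when running on TF 2.x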

from sklearn.preprocessing import LabelEncoder

import warnings
warnings.filterwarnings('ignore')
import re, os, sys
import numpy as np
import pandas as pd

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.layers import Dense, Input, GlobalMaxPooling1D,\
Conv1D, MaxPooling1D, Embedding, LSTM
from keras.models import Model, Sequential
from keras.initializers import Constant
import tensorflow as tf
from tensorflow import keras

max_seq_len = 1000
max_num_words = 2000
emb_dim = 100
valid_split = 0.2

vocab_size = 20000
maxlen = 1000

Loading the data

def load_directory_data(directory):
    data = {}
    data["sentence"] = []
    data["sentiment"] = []
    for file_path in os.listdir(directory):
        with tf.gfile.GFile(os.path.join(directory, file_path), "r") as f:
            data["sentence"].append(f.read())
            data["sentiment"].append(re.match(r"\d+_(\d+)\.txt", file_path).group(1))
    return pd.DataFrame.from_dict(data)


# Merge positive and negative examples, add a polarity column and shuffle.
def load_dataset(directory):
    pos_df = load_directory_data(os.path.join(directory, "pos"))
    neg_df = load_directory_data(os.path.join(directory, "neg"))
    pos_df["polarity"] = 1
    neg_df["polarity"] = 0
    return pd.concat([pos_df, neg_df]).sample(frac=1).reset_index(drop=True)


# Download the IMDB dataset and load the train/test splits.
def download_and_load_datasets(force_download=False):
    dataset = tf.keras.utils.get_file(
        fname="aclImdb.tar.gz",
        origin="http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz",
        extract=True)

    train_df = load_dataset(os.path.join(os.path.dirname(dataset),
                                         "aclImdb", "train"))
    test_df = load_dataset(os.path.join(os.path.dirname(dataset),
                                        "aclImdb", "test"))

    return train_df, test_df


train, test = download_and_load_datasets()

Data preprocessing

train_texts = train['sentence'].values
train_labels = train['polarity'].values
test_texts = test['sentence'].values
test_labels = test['polarity'].values

labels_index = {'pos':1, 'neg':0}

tokenizer = Tokenizer(num_words=max_num_words)
tokenizer.fit_on_texts(train_texts)
train_sequences = tokenizer.texts_to_sequences(train_texts)  # convert texts to sequences of word indices
test_sequences = tokenizer.texts_to_sequences(test_texts)
word_index = tokenizer.word_index
print("Found %s unique tokens."%len(word_index))

# pad the sequences so that all inputs have the same length
train_valid_data = pad_sequences(train_sequences, maxlen=max_seq_len)
test_data = pad_sequences(test_sequences, maxlen=max_seq_len)
train_valid_labels = to_categorical(np.asarray(train_labels))
test_labels = to_categorical(np.asarray(test_labels))

# split off a validation set
indices = np.arange(train_valid_data.shape[0])
np.random.shuffle(indices)

train_valid_data = train_valid_data[indices]
train_valid_labels = train_valid_labels[indices]
num_valid_samples = int(valid_split * train_valid_data.shape[0])
x_train = train_valid_data[:-num_valid_samples]
y_train = train_valid_labels[:-num_valid_samples]
x_val = train_valid_data[-num_valid_samples:]
y_val = train_valid_labels[-num_valid_samples:]

print("划分数据集为训练测试集完毕!")
============================================================================
Found 88582 unique tokens.
Finished splitting the data into training and validation sets!

Model training

batch_size = 64
max_features = vocab_size + 1

print("Defining the LSTM model: ")
lstm = Sequential()  # the deep model is built with the Sequential API for shap's DeepExplainer
lstm.add(Embedding(max_num_words, 128))
lstm.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
lstm.add(Dense(2, activation="sigmoid"))
lstm.compile(loss="binary_crossentropy",
             optimizer='Adam',
             metrics=['accuracy'])
print("Training the LSTM!")
lstm.fit(x_train, y_train,
         batch_size=32, epochs=2,
         validation_data=(x_val, y_val))
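The padded test split (test_data, test_labels) prepared during preprocessing is not used above; as an optional added sketch (not in the original post), it can be evaluated directly:

# optional: evaluate the trained LSTM on the held-out IMDB test split
test_loss, test_acc = lstm.evaluate(test_data, test_labels, batch_size=64)
print("Test accuracy:", test_acc)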

Explaining the LSTM with SHAP

import shap

shap.initjs()
explainer = shap.DeepExplainer(lstm, x_train[:20])
# explaining each prediction evaluates the model on roughly 2x the background samples; here the first 5 validation samples are explained
shap_values = explainer.shap_values(x_val[:5])

Map the word indices of the first 10 validation samples back to their words:

import numpy as np
# build the index-to-word mapping from the tokenizer fitted above, so the indices match the padded sequences
words = tokenizer.word_index
num2word = {}
for w in words.keys():
    num2word[words[w]] = w
x_val_words = np.stack([np.array(list(map(lambda x: num2word.get(x, "NONE"), x_val[i]))) for i in range(10)])
shap.initjs()

shap.force_plot(explainer.expected_value[0], shap_values[0][0], x_val_words[0],
                text_rotation=30,
                matplotlib=True,
                show=False)

[Figure: SHAP force plot for the first validation sample of the LSTM]

Supplementary material: 如何解决机器学习树集成模型的解释性问题 (how to address interpretability for tree-ensemble machine learning models)