Shap LSTM(Keras,TensorFlow)ValueError:形状不匹配:无法广播对象到单一形状

3

我正在尝试从简单的LSTM模型中绘制摘要。 当我调用shap.summary_plot时,出现ValueError: shape mismatch: objects cannot be broadcast to a single shape错误。 这里是重现问题的Colab链接。

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, LSTM
import shap

# Create random training values.
#
# train_x is [
#   [
#        [0.3, 0.54 ... 0.8],
#        [0.4, 0.6 ... 0.55],
#        ...
#   ],
#   [
#        [0.3, 0.54 ... 0.8],
#        [0.4, 0.6 ... 0.55],
#        ...
#   ],
#   ...
# ]
#
# train_y is corresponding classification of train_x sequences, always 0 or 1
# [0, 1, 0, 1, 0, ... 0]

SAMPLES_CNT = 1000

train_x = np.random.rand(SAMPLES_CNT,5,4)
train_y = np.vectorize(lambda x: int(round(x)))(np.random.rand(SAMPLES_CNT))

val_x = np.random.rand(int(SAMPLES_CNT * 0.1),5,4)
val_y = np.vectorize(lambda x: int(round(x)))(np.random.rand(int(SAMPLES_CNT * 0.1)))

# Train model

model = Sequential()
model.add(LSTM(32,input_shape=train_x.shape[1:], return_sequences=False, stateful=False))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001, decay=1e-6),
              loss='binary_crossentropy',metrics=['accuracy'])

fit = model.fit(train_x, train_y, batch_size=64, epochs=2, 
                validation_data=(val_x, val_y), shuffle=False)

explainer = shap.DeepExplainer(model, train_x[:10])
shap_vals = explainer.shap_values(val_x[:10])
shap.summary_plot(shap_vals, val_x[:10], plot_type="bar")

崩溃信息:


---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-78-906a7898852e> in <module>
----> 1 shap.summary_plot(shap_vals, val_x[:10], feature_names=feature_names, plot_type="bar")
      2 

/usr/local/lib/python3.7/site-packages/shap/plots/summary.py in summary_plot(shap_values, features, feature_names, max_display, plot_type, color, axis_color, title, alpha, show, sort, color_bar, plot_size, layered_violin_max_num_bins, class_names, class_inds, color_bar_label, auto_size_plot)
    442             pl.barh(
    443                 y_pos, global_shap_values[feature_inds], 0.7, left=left_pos, align='center',
--> 444                 color=color(i), label=class_names[ind]
    445             )
    446             left_pos += global_shap_values[feature_inds]

/usr/local/lib/python3.7/site-packages/matplotlib/pyplot.py in barh(y, width, height, left, align, **kwargs)
   2421 def barh(y, width, height=0.8, left=None, *, align='center', **kwargs):
   2422     return gca().barh(
-> 2423         y, width, height=height, left=left, align=align, **kwargs)
   2424 
   2425 

/usr/local/lib/python3.7/site-packages/matplotlib/axes/_axes.py in barh(self, y, width, height, left, align, **kwargs)
   2544         kwargs.setdefault('orientation', 'horizontal')
   2545         patches = self.bar(x=left, height=height, width=width, bottom=y,
-> 2546                            align=align, **kwargs)
   2547         return patches
   2548 

/usr/local/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1563     def inner(ax, *args, data=None, **kwargs):
   1564         if data is None:
-> 1565             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1566 
   1567         bound = new_sig.bind(ax, *args, **kwargs)

/usr/local/lib/python3.7/site-packages/matplotlib/axes/_axes.py in bar(self, x, height, width, bottom, align, **kwargs)
   2339         x, height, width, y, linewidth = np.broadcast_arrays(
   2340             # Make args iterable too.
-> 2341             np.atleast_1d(x), height, width, y, linewidth)
   2342 
   2343         # Now that units have been converted, set the tick locations.

<__array_function__ internals> in broadcast_arrays(*args, **kwargs)

/usr/local/lib/python3.7/site-packages/numpy/lib/stride_tricks.py in broadcast_arrays(*args, **kwargs)
    262     args = [np.array(_m, copy=False, subok=subok) for _m in args]
    263 
--> 264     shape = _broadcast_shape(*args)
    265 
    266     if all(array.shape == shape for array in args):

/usr/local/lib/python3.7/site-packages/numpy/lib/stride_tricks.py in _broadcast_shape(*args)
    189     # use the old-iterator because np.nditer does not handle size 0 arrays
    190     # consistently
--> 191     b = np.broadcast(*args[:32])
    192     # unfortunately, it cannot handle 32 or more arguments directly
    193     for pos in range(32, len(args), 31):

ValueError: shape mismatch: objects cannot be broadcast to a single shape

我是在做什么错误吗?还是这是一个bug?我已经一整天都在苦思冥想了。提前感谢您的帮助。

1个回答

7
根据文档,`shap.summary_plot` 的前两个参数分别为:
- `shap_values`: numpy.array 对于单输出解释,这是一个 SHAP 值矩阵(# 样本数 x # 特征数)。 对于多输出解释,这是一个包含多个 SHAP 值矩阵的列表。
- `features`: numpy.array 或 pandas.DataFrame 或 list 特征值矩阵(# 样本数 x # 特征数)或特征名称列表的简称。
在您提供的代码中,`shap_vals[0]` 和 `val_x[:10]` 的形状都是 `(10, 5, 4)`。因此,您应该将前两个维度展平,或选择您感兴趣的时间点。例如:
shap.summary_plot(shap_vals[0][:, 0, :], val_x[:10][:, 0, :], feature_names=feature_names, plot_type="bar")

这将产生以下图表:enter image description here

非常感谢您的帮助!我真的很感激。我还在苦苦挣扎。当我尝试从最后一层为Dense(2, activation="softmax")sparse_categorical_crossentropy损失函数的模型中获取Shap值时。https://colab.research.google.com/drive/1pMh0n4WVVkNnM-uBr9D-u4AyqKu7hrud。我是否做错了什么明显的事情?(再次感谢!) - stringCode
2
我认为这是因为当解释器期望维度为1时,输出维度为2。您可以将最后一层更改为Dense(1, activation="sigmoid")并在编译模型时使用loss='binary_crossentropy' - rvinas

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接