将SHAP总结图保存为PDF/SVG。

16

我目前在处理一个分类问题,并想创建特征重要性的可视化。我使用Python的XGBoost软件包,它已经提供了特征重要性图。但是,我发现shap (https://github.com/slundberg/shap)是一个基于树分类器的Python库,可以创建非常好的特征重要性图形。所有的东西都运行良好,我也可以将创建的图形保存为PNG,但是,如果我尝试将其保存为PDF或SVG,则会发生异常。这是我的操作:

首先,我训练XGBoost模型,并得到一个名为bst的模型。

train = remove_labels_for_binary_df(dataset_fc_baseline_1[0].train)
test = remove_labels_for_binary_df(dataset_fc_baseline_1[0].test)
results, bst = xgboost_with_bst(*transform_feat_to_num(train, test))

然后我创建SHAP值,使用它们创建一个摘要图并保存可视化结果。如果我将图像另存为plt.savefig('shap.png'),一切都正常运作。

import shap
import matplotlib.pyplot as plt

shap.initjs()

explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(train)
fig = shap.summary_plot(shap_values, train, show=False)
plt.savefig('shap.png')

不过,我需要PDF或SVG图形而非PNG格式,因此尝试使用plt.savefig('shap.pdf')进行保存。通常情况下这样做是有效的,但对于shap图却会产生以下异常。

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-39-49d17973f438> in <module>()
  1 fig = shap.summary_plot(shap_values, train, show=False)
----> 2 plt.savefig('shap.pdf')

 C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\pyplot.py in 
savefig(*args, **kwargs)
708 def savefig(*args, **kwargs):
709     fig = gcf()
--> 710     res = fig.savefig(*args, **kwargs)
711     fig.canvas.draw_idle()   # need this if 'transparent=True' to reset 
colors
712     return res

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\figure.py in 
savefig(self, fname, **kwargs)
2033             self.set_frameon(frameon)
2034 
-> 2035         self.canvas.print_figure(fname, **kwargs)
2036 
2037         if frameon:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\backend_bases.py in 
print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, 
**kwargs)
2261                 orientation=orientation,
2262                 bbox_inches_restore=_bbox_inches_restore,
-> 2263                 **kwargs)
2264         finally:
2265             if bbox_inches and restore_bbox:

C:\Users\Studio\Anaconda3\lib\site- 
packages\matplotlib\backends\backend_pdf.py in print_pdf(self, filename, 
**kwargs)
2584                 RendererPdf(file, image_dpi, height, width),
2585                 bbox_inches_restore=_bbox_inches_restore)
-> 2586             self.figure.draw(renderer)
2587             renderer.finalize()
2588             if not isinstance(filename, PdfPages):

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in 
draw_wrapper(artist, renderer, *args, **kwargs)
 53                 renderer.start_filter()
 54 
---> 55             return draw(artist, renderer, *args, **kwargs)
 56         finally:
 57             if artist.get_agg_filter() is not None:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\figure.py in 
draw(self, renderer)
1473 
1474             mimage._draw_list_compositing_images(
-> 1475                 renderer, self, artists, self.suppressComposite)
1476 
1477             renderer.close_group('figure')

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\image.py in 
_draw_list_compositing_images(renderer, parent, artists, suppress_composite)
139     if not_composite or not has_images:
140         for a in artists:
--> 141             a.draw(renderer)
142     else:
143         # Composite any adjacent images together

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in 
draw_wrapper(artist, renderer, *args, **kwargs)
 53                 renderer.start_filter()
 54 
---> 55             return draw(artist, renderer, *args, **kwargs)
 56         finally:
 57             if artist.get_agg_filter() is not None:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\axes\_base.py in 
draw(self, renderer, inframe)
2605             renderer.stop_rasterizing()
2606 
-> 2607         mimage._draw_list_compositing_images(renderer, self, 
 artists)
2608 
2609         renderer.close_group('axes')

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\image.py in 
_draw_list_compositing_images(renderer, parent, artists, suppress_composite)
139     if not_composite or not has_images:
140         for a in artists:
--> 141             a.draw(renderer)
142     else:
143         # Composite any adjacent images together

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in 
draw_wrapper(artist, renderer, *args, **kwargs)
 58                 renderer.stop_filter(artist.get_agg_filter())
 59             if artist.get_rasterized():
---> 60                 renderer.stop_rasterizing()
 61 
 62     draw_wrapper._supports_rasterization = True

C:\Users\Studio\Anaconda3\lib\site- 
packages\matplotlib\backends\backend_mixed.py in stop_rasterizing(self)
128 
129             height = self._height * self.dpi
--> 130             buffer, bounds = 
self._raster_renderer.tostring_rgba_minimized()
131             l, b, w, h = bounds
132             if w > 0 and h > 0:

C:\Users\Studio\Anaconda3\lib\site- 
packages\matplotlib\backends\backend_agg.py in tostring_rgba_minimized(self)
138                 [extents[0] + extents[2], self.height - extents[1]]]
139         region = self.copy_from_bbox(bbox)
--> 140         return np.array(region), extents
141 
142     def draw_path(self, gc, path, transform, rgbFace=None):

ValueError: negative dimensions are not allowed

你有没有什么办法来解决这个问题?


1
你找到解决方案了吗? - FooBar
很遗憾,目前还没有解决这个问题的方法。 - Roqua
7个回答

8
在保存图形时,需要添加 matplotlib=True,show=False 参数:
def heart_disease_risk_factors(model, patient):

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(patient)
    shap.initjs()

    return shap.force_plot(explainer.expected_value[1],shap_values[1],\
        patient,matplotlib=True,show=False)


plt.clf()
data_for_prediction = X_test.iloc[2,:].astype(float)
heart_disease_risk_factors(model, data_for_prediction)
plt.savefig("gg.png",dpi=150, bbox_inches='tight')

@zmag已经对代码进行了适当的缩进编辑。 - Ramisha Rani K
如果在Jupyter笔记本中,禁用%matplotlib inline - 可能会强制显示图形,从而在保存之前清空内存。 - st0ne

7
默认情况下,summary_plot 会调用 plt.show() 来确保图表显示。但是,如果您将 show=False 传递给 summary_plot,则可以保存它。例如:
#shap summary plot plotting
import matplotlib.pyplot as pl
shap.summary_plot(shap_values, X_train,max_display=10,show=False)
pl.savefig("shap_summary.svg",dpi=700) #.png,.pdf will also support here
pyplot.show()

此建议也适用于 shap.waterfall_plot,用于解释任何特定样本的模型预测。 - Heelara

4
这是使用 rasterized=True 绘制时(如果有超过 500 个数据点,则 shap 会这样做),在 NumPy 和 matplotlib 之间引起的 问题, 并已在最新版本的 matplotlib 中解决。

2
最简单的方法是按如下保存:
 fig = shap.summary_plot(shap_values, X_test, plot_type="bar", feature_names=["a", "b"], show=False)
plt.savefig("trial.png")

注意:默认情况下,summary_plot 调用 plt.show() 来确保绘图显示。但是,如果你将 show=False 传递给 summary_plot,则不会显示绘图。 https://github.com/slundberg/shap/issues/153

2
我认为最简单的方法是:
shap.summary_plot(shap_values, X, show=False)
plt.savefig('mygraph.pdf', format='pdf', dpi=600, bbox_inches='tight')
plt.show()

您的答案可以通过添加额外的支持信息来改进。请[编辑]以添加更多细节,例如引用或文档,以便其他人可以确认您的答案是否正确。您可以在帮助中心中找到有关如何编写良好答案的更多信息。 - Ethan

-1
请尝试这个:
shap.plots.force(shape_values[0], show=False, matplotlib=True).savefig('shap.pdf')

-2

保存为 PDF:

plt.savefig("shap.pdf", format='pdf', dpi=1000, bbox_inches='tight')

保存为eps格式:

plt.savefig("shap.eps", format='eps', dpi=1000, bbox_inches='tight')

更多信息请参考:

matplotlib.pyplot.savefig matplotlib

请查看链接以了解更多信息,例如 bbox_inches='tight' 的含义。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接