SHAP summary_plot小提琴图颜色修改实战:从源码解析到参数定制
第一次用SHAP的summary_plot生成小提琴图时,我盯着那排单调的蓝色violin发愣——明明在matplotlib和seaborn里改颜色易如反掌,为什么这里的color参数毫无反应?这个看似简单的需求,最终让我花了三个晚上深挖SHAP源码。本文将分享从问题定位到两种解决方案的全过程,适合那些不满足于"能用就行"、想真正掌握工具定制的开发者。
1. 问题重现与初步排查
我们从一个典型的SHAP调用开始:
import shap explainer = shap.Explainer(model) shap_values = explainer(X) shap.summary_plot(shap_values, X, plot_type="violin")当尝试添加颜色参数时,以下常见方法全部失效:
# 以下尝试均无效 shap.summary_plot(shap_values, X, plot_type="violin", color="red") shap.summary_plot(shap_values, X, plot_type="violin", cmap="Reds")通过打印函数签名,发现summary_plot确实接受color参数:
import inspect print(inspect.signature(shap.summary_plot))输出显示函数定义包含color=None参数,但实际调用时却被忽略。这种表面合规但实际无效的情况,正是需要深入源码的信号。
2. 源码定位与逻辑分析
在Python环境中通过以下命令找到SHAP安装位置:
pip show shap | grep Location进入源码目录后,关键定位步骤:
- 在
plots/__init__.py中找到summary_plot函数 - 发现其通过
plot_type分支调用不同绘图逻辑 - 定位到
_violin_summary这个真正处理violin图的内部函数
核心发现点在于_violin_summary中的这段代码:
def _violin_summary(..., color=None, ...): # ... for i in range(len(feature_order)): sv = shap_values[:, feature_order[i]] v = ax.violinplot(..., positions=[i], showmeans=False, showextrema=False, widths=0.7 ) for b in v['bodies']: b.set_facecolor('#1E88E5') # 硬编码的蓝色 b.set_edgecolor('black') b.set_alpha(1)问题根源一目了然:虽然函数接收color参数,但在绘制violin时却使用了硬编码的十六进制色值#1E88E5,完全忽略了传入的color参数。
3. 解决方案一:直接修改源码
最快速的解决方案是直接修改源码中的颜色设置:
- 定位到
shap/plots/_violin.py文件 - 找到
_violin_summary函数中的颜色设置部分 - 修改为:
b.set_facecolor(color if color is not None else '#1E88E5')这种方法的优缺点对比:
| 优点 | 缺点 |
|---|---|
| 立即生效 | 需要每次安装/更新后重新修改 |
| 改动简单 | 不利于代码版本管理 |
| 适合快速验证 | 可能被后续更新覆盖 |
实际操作中,可以用patch工具临时修改:
import shap from functools import partial original_func = shap.plots._violin._violin_summary def patched_violin(..., color=None, ...): # 修改后的实现 pass shap.plots._violin._violin_summary = patched_violin4. 解决方案二:创建增强版函数
更可持续的方案是创建自定义函数,保留原函数的同时扩展功能:
def custom_summary_plot(shap_values, features=None, ..., violin_color=None, violin_edgecolor='black'): """ 增强版summary_plot,支持violin颜色定制 参数: violin_color: violin填充色,默认为原库的蓝色 violin_edgecolor: violin边缘色,默认为黑色 """ # 调用原始函数获取基础绘图 fig, ax = plt.subplots() shap.summary_plot(shap_values, features, plot_type="violin", show=False) # 获取当前axes并修改violin颜色 ax = plt.gca() for col in ax.collections: if isinstance(col, matplotlib.collections.PolyCollection): col.set_facecolor(violin_color or '#1E88E5') col.set_edgecolor(violin_edgecolor) return fig这个方案的优势在于:
- 不修改原始库文件
- 明确新增参数控制violin样式
- 保持与原函数相同的调用方式
- 可以进一步扩展其他定制选项
5. 颜色映射的高级应用
对于需要更复杂颜色映射的场景,我们可以扩展函数支持:
def advanced_summary_plot(..., color_map=None): """支持基于特征重要性的颜色渐变""" # 计算特征重要性 importance = np.abs(shap_values).mean(0) # 创建颜色映射 if color_map is None: color_map = plt.cm.Blues norm = plt.Normalize(importance.min(), importance.max()) colors = color_map(norm(importance)) # 绘制并设置颜色 shap.summary_plot(..., show=False) ax = plt.gca() for i, col in enumerate(ax.collections): if isinstance(col, matplotlib.collections.PolyCollection): col.set_facecolor(colors[i])使用示例:
advanced_summary_plot(shap_values, X, color_map=plt.cm.RdBu_r)这种实现可以产生从蓝到红的渐变效果,重要性高的特征显示为深红色,低的显示为浅蓝色。
6. 工程化封装与发布
为了让解决方案更易于团队使用,可以将其打包为独立模块:
- 创建
shap_extensions目录 - 添加
__init__.py和violin.py - 在setup.py中声明包依赖:
from setuptools import setup setup( name='shap_extensions', version='0.1', packages=['shap_extensions'], install_requires=['shap>=0.40', 'matplotlib>=3.0'] )关键设计考虑:
- 保持与原SHAP库的函数签名兼容
- 通过继承或组合扩展功能
- 提供清晰的文档字符串和类型提示
- 编写单元测试验证颜色修改效果
最终调用方式既简洁又明确:
from shap_extensions.violin import summary_plot summary_plot(shap_values, X, violin_color='#FF6D00', violin_edgecolor='#333333')