mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%

push

github

web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* Fixed bugs introduced in game and interaction_values

* Pyright Safe Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer. Approximator, Imputer and ExactComputer can either exist or not, depending on dynamic Type

* Fixed a bug introduced by refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step to reinstall Python on Windows due to issues with tkinter. The linked GitHub issue was solved. Doing this as a first try.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests. The prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
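
The commit message mentions introducing a generic TypeVar in the approximator to reduce type-checker ignores. The following is a minimal sketch of that general pattern, not shapiq's actual classes: every name in it (`BaseEstimator`, `MeanEstimator`, `ScalarResult`, `VectorResult`, `ResultT`) is hypothetical and chosen only for illustration.

```python
from __future__ import annotations

from typing import Generic, TypeVar


class ScalarResult:
    """Hypothetical result type holding a single number."""

    def __init__(self, value: float) -> None:
        self.value = value


class VectorResult:
    """Hypothetical result type holding several numbers."""

    def __init__(self, values: list[float]) -> None:
        self.values = values


# The type variable ties each subclass to the result type it produces.
ResultT = TypeVar("ResultT")


class BaseEstimator(Generic[ResultT]):
    """Hypothetical base class parameterised by its result type."""

    def estimate(self, budget: int) -> ResultT:
        raise NotImplementedError


class MeanEstimator(BaseEstimator[ScalarResult]):
    """Subclass that fixes the result type to ScalarResult."""

    def estimate(self, budget: int) -> ScalarResult:
        return ScalarResult(sum(range(budget)) / budget)


result = MeanEstimator().estimate(4)
# A static checker can infer ScalarResult here, so accessing .value
# needs neither a cast nor an ignore comment.
print(result.value)
```

Because the return type is carried by the class parameter, call sites that would otherwise see a broad union type get a precise one, which is typically where ignore comments can be dropped.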

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line
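
The percentages in this summary can be checked from the raw line counts. This is a small sketch that assumes each figure is simply covered lines divided by relevant lines; `coverage_percent` is a hypothetical helper, not part of any Coveralls tooling.

```python
def coverage_percent(covered: int, relevant: int) -> float:
    """Return covered / relevant expressed as a percentage."""
    return 100.0 * covered / relevant


# Counts taken from the build summary above.
new_lines = coverage_percent(304, 360)
overall = coverage_percent(4987, 5374)

# Change relative to the previous build's reported 93.522%.
delta = overall - 93.522

print(f"new or added lines: {new_lines:.2f}%")
print(f"overall: {overall:.3f}%")
print(f"change: {delta:+.1f}%")
```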

Source File

/src/shapiq/plot/waterfall.py (94.87% covered)
"""Wrapper for the waterfall plot from the ``shap`` package.

Note:
    Code and implementation was taken and adapted from the [SHAP package](https://github.com/shap/shap)
    which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).

"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.transforms import ScaledTranslation

from ._config import BLUE, RED
from .utils import abbreviate_feature_names, format_labels, format_value

if TYPE_CHECKING:
    from matplotlib.axes import Axes  # uncovered (new line)

    from shapiq.interaction_values import InteractionValues  # uncovered (previously covered)

__all__ = ["waterfall_plot"]


def _draw_waterfall_plot(
    values: np.ndarray,
    base_values: float,
    feature_names: np.ndarray | list[str],
    *,
    max_display: int = 10,
    show: bool = False,
) -> Axes | None:
    """The waterfall plot from the SHAP package.

    Note:
        This function was taken and adapted from the [SHAP package](https://github.com/shap/shap/blob/master/shap/plots/_waterfall.py)
        which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).
        Do not use this function directly, use the ``waterfall_plot`` function instead.

    Args:
        values: The values to plot.
        base_values: The base value.
        feature_names: The names of the features.
        max_display: The maximum number of features to display.
        show: Whether to show the plot.

    Returns:
        The plot if ``show`` is ``False``.

    """
    # Turn off interactive plot
    if show is False:
        plt.ioff()

    # init variables we use for tracking the plot locations
    num_features = min(max_display, len(values))
    row_height = 0.5
    rng = range(num_features - 1, -1, -1)
    order = np.argsort(-np.abs(values))
    pos_lefts = []
    pos_inds = []
    pos_widths = []
    pos_low = []
    pos_high = []
    neg_lefts = []
    neg_inds = []
    neg_widths = []
    neg_low = []
    neg_high = []
    loc = base_values + values.sum()
    yticklabels = ["" for _ in range(num_features + 1)]

    # size the plot based on how many features we are plotting
    plt.gcf().set_size_inches(8, num_features * row_height + 3.5)

    # see how many individual (vs. grouped at the end) features we are plotting
    num_individual = num_features if num_features == len(values) else num_features - 1

    # compute the locations of the individual features and plot the dashed connecting lines
    for i in range(num_individual):
        sval = values[order[i]]
        loc -= sval
        if sval >= 0:
            pos_inds.append(rng[i])
            pos_widths.append(sval)
            pos_lefts.append(loc)
        else:
            neg_inds.append(rng[i])
            neg_widths.append(sval)
            neg_lefts.append(loc)
        if num_individual != num_features or i + 4 < num_individual:
            plt.plot(
                [loc, loc],
                [rng[i] - 1 - 0.4, rng[i] + 0.4],
                color="#bbbbbb",
                linestyle="--",
                linewidth=0.5,
                zorder=-1,
            )
        yticklabels[rng[i]] = str(feature_names[order[i]])

    # add a last grouped feature to represent the impact of all the features we didn't show
    if num_features < len(values):
        yticklabels[0] = f"{int(len(values) - num_features + 1)} other features"
        remaining_impact = base_values - loc
        if remaining_impact < 0:
            pos_inds.append(0)
            pos_widths.append(-remaining_impact)
            pos_lefts.append(loc + remaining_impact)
        else:
            neg_inds.append(0)
            neg_widths.append(-remaining_impact)
            neg_lefts.append(loc + remaining_impact)

    points = (
        pos_lefts
        + list(np.array(pos_lefts) + np.array(pos_widths))
        + neg_lefts
        + list(np.array(neg_lefts) + np.array(neg_widths))
    )
    dataw = np.max(points) - np.min(points)

    # draw invisible bars just for sizing the axes
    label_padding = np.array([0.1 * dataw if w < 1 else 0 for w in pos_widths])
    plt.barh(
        pos_inds,
        np.array(pos_widths) + label_padding + 0.02 * dataw,
        left=np.array(pos_lefts) - 0.01 * dataw,
        color=RED.hex,
        alpha=0,
    )
    label_padding = np.array([-0.1 * dataw if -w < 1 else 0 for w in neg_widths])
    plt.barh(
        neg_inds,
        np.array(neg_widths) + label_padding - 0.02 * dataw,
        left=np.array(neg_lefts) + 0.01 * dataw,
        color=BLUE.hex,
        alpha=0,
    )

    # define variable we need for plotting the arrows
    head_length = 0.08
    bar_width = 0.8
    xlen = plt.xlim()[1] - plt.xlim()[0]
    fig = plt.gcf()
    ax = plt.gca()
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    width = bbox.width
    bbox_to_xscale = xlen / width
    hl_scaled = bbox_to_xscale * head_length
    dpi = fig.dpi
    renderer = fig.canvas.get_renderer()  # pyright: ignore[reportAttributeAccessIssue]

    # draw the positive arrows
    for i in range(len(pos_inds)):
        dist = pos_widths[i]
        arrow_obj = plt.arrow(
            pos_lefts[i],
            pos_inds[i],
            max(dist - hl_scaled, 0.000001),
            0,
            head_length=min(dist, hl_scaled),
            color=RED.hex,
            width=bar_width,
            head_width=bar_width,
        )

        if pos_low is not None and i < len(pos_low):
            plt.errorbar(  # uncovered
                pos_lefts[i] + pos_widths[i],
                pos_inds[i],
                xerr=np.array([[pos_widths[i] - pos_low[i]], [pos_high[i] - pos_widths[i]]]),
                ecolor=BLUE.hex,
            )

        txt_obj = plt.text(
            pos_lefts[i] + 0.5 * dist,
            pos_inds[i],
            format_value(pos_widths[i], "%+0.02f"),
            horizontalalignment="center",
            verticalalignment="center",
            color="white",
            fontsize=12,
        )
        text_bbox = txt_obj.get_window_extent(renderer=renderer)
        arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)

        # if the text overflows the arrow then draw it after the arrow
        if text_bbox.width > arrow_bbox.width:
            txt_obj.remove()

            txt_obj = plt.text(
                pos_lefts[i] + (5 / 72) * bbox_to_xscale + dist,
                pos_inds[i],
                format_value(pos_widths[i], "%+0.02f"),
                horizontalalignment="left",
                verticalalignment="center",
                color=RED.hex,
                fontsize=12,
            )

    # draw the negative arrows
    for i in range(len(neg_inds)):
        dist = neg_widths[i]

        arrow_obj = plt.arrow(
            neg_lefts[i],
            neg_inds[i],
            -max(-dist - hl_scaled, 0.000001),
            0,
            head_length=min(-dist, hl_scaled),
            color=BLUE.hex,
            width=bar_width,
            head_width=bar_width,
        )

        if neg_low is not None and i < len(neg_low):
            plt.errorbar(  # uncovered
                neg_lefts[i] + neg_widths[i],
                neg_inds[i],
                xerr=np.array([[neg_widths[i] - neg_low[i]], [neg_high[i] - neg_widths[i]]]),
                ecolor=RED.hex,
            )

        txt_obj = plt.text(
            neg_lefts[i] + 0.5 * dist,
            neg_inds[i],
            format_value(neg_widths[i], "%+0.02f"),
            horizontalalignment="center",
            verticalalignment="center",
            color="white",
            fontsize=12,
        )
        text_bbox = txt_obj.get_window_extent(renderer=renderer)
        arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)

        # if the text overflows the arrow then draw it after the arrow
        if text_bbox.width > arrow_bbox.width:
            txt_obj.remove()  # uncovered

            plt.text(  # uncovered
                neg_lefts[i] - (5 / 72) * bbox_to_xscale + dist,
                neg_inds[i],
                format_value(neg_widths[i], "%+0.02f"),
                horizontalalignment="right",
                verticalalignment="center",
                color=BLUE.hex,
                fontsize=12,
            )

    # draw the y-ticks twice, once in gray and then again with just the feature names in black
    # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
    ytick_pos = list(range(num_features)) + list(np.arange(num_features) + 1e-8)
    plt.yticks(
        ytick_pos,
        yticklabels[:-1] + [label.split("=")[-1] for label in yticklabels[:-1]],
        fontsize=13,
    )

    # Check that the y-ticks are not drawn outside the plot
    max_label_width = (
        max([label.get_window_extent(renderer=renderer).width for label in ax.get_yticklabels()])
        / dpi
    )
    if max_label_width > 0.1 * fig.get_size_inches()[0]:
        required_width = max_label_width / 0.1
        fig_height = fig.get_size_inches()[1]
        fig.set_size_inches(required_width, fig_height, forward=True)

    # put horizontal lines for each feature row
    for i in range(num_features):
        plt.axhline(i, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

    # mark the prior expected value and the model prediction
    plt.axvline(
        base_values,
        0,
        1 / num_features,
        color="#bbbbbb",
        linestyle="--",
        linewidth=0.5,
        zorder=-1,
    )
    fx = base_values + values.sum()
    plt.axvline(fx, 0, 1, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1)

    # clean up the main axis
    plt.gca().xaxis.set_ticks_position("bottom")
    plt.gca().yaxis.set_ticks_position("none")
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["top"].set_visible(False)
    plt.gca().spines["left"].set_visible(False)
    ax.tick_params(labelsize=13)

    # draw the E[f(X)] tick mark
    xmin, xmax = ax.get_xlim()
    ax2 = ax.twiny()
    ax2.set_xlim(xmin, xmax)
    ax2.set_xticks(
        [base_values, base_values + 1e-8],
    )  # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
    ax2.set_xticklabels(
        ["\n$E[f(X)]$", "\n$ = " + format_value(base_values, "%0.03f") + "$"],
        fontsize=12,
        ha="left",
    )
    ax2.spines["right"].set_visible(False)
    ax2.spines["top"].set_visible(False)
    ax2.spines["left"].set_visible(False)

    # draw the f(x) tick mark
    ax3 = ax2.twiny()
    ax3.set_xlim(xmin, xmax)
    # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks
    ax3.set_xticks([base_values + values.sum(), base_values + values.sum() + 1e-8])
    ax3.set_xticklabels(
        ["$f(x)$", "$ = " + format_value(fx, "%0.03f") + "$"],
        fontsize=12,
        ha="left",
    )
    tick_labels = ax3.xaxis.get_majorticklabels()
    tick_labels[0].set_transform(
        tick_labels[0].get_transform() + ScaledTranslation(-10 / 72.0, 0, fig.dpi_scale_trans),
    )
    tick_labels[1].set_transform(
        tick_labels[1].get_transform() + ScaledTranslation(12 / 72.0, 0, fig.dpi_scale_trans),
    )
    tick_labels[1].set_color("#999999")
    ax3.spines["right"].set_visible(False)
    ax3.spines["top"].set_visible(False)
    ax3.spines["left"].set_visible(False)

    # adjust the position of the E[f(X)] = x.xx label
    tick_labels = ax2.xaxis.get_majorticklabels()
    tick_labels[0].set_transform(
        tick_labels[0].get_transform() + ScaledTranslation(-20 / 72.0, 0, fig.dpi_scale_trans),
    )
    tick_labels[1].set_transform(
        tick_labels[1].get_transform()
        + ScaledTranslation(22 / 72.0, -1 / 72.0, fig.dpi_scale_trans),
    )

    tick_labels[1].set_color("#999999")

    # color the y tick labels that have the feature values as gray
    # (these fall behind the black ones with just the feature name)
    tick_labels = ax.yaxis.get_majorticklabels()
    for i in range(num_features):
        tick_labels[i].set_color("#999999")

    if show:
        plt.show()  # uncovered
        return None  # uncovered
    return plt.gca()


def waterfall_plot(
    interaction_values: InteractionValues,
    *,
    feature_names: np.ndarray | list[str] | None = None,
    show: bool = False,
    max_display: int = 10,
    abbreviate: bool = True,
) -> Axes | None:
    """Draws a waterfall plot with the interaction values.

    The waterfall plot shows the individual contributions of the features to the interaction values.
    The plot is based on the waterfall plot from the SHAP[1]_ package.

    Args:
        interaction_values: The interaction values as an interaction object.
        feature_names: The names of the features. Defaults to ``None``.
        show: Whether to show the plot. Defaults to ``False``.
        max_display: The maximum number of interactions to display. Defaults to ``10``.
        abbreviate: Whether to abbreviate the feature names. Defaults to ``True``.

    Returns:
        The plot if ``show`` is ``False``.

    References:
        .. [1] SHAP is available at https://github.com/shap/shap

    """
    if feature_names is None:
        feature_mapping = {i: str(i) for i in range(interaction_values.n_players)}
    else:
        if abbreviate:
            feature_names = abbreviate_feature_names(feature_names)
        feature_mapping = {i: feature_names[i] for i in range(interaction_values.n_players)}

    # create the data for the waterfall plot in the correct format
    data = []
    for feature_tuple, value in interaction_values.dict_values.items():
        if len(feature_tuple) > 0:
            data.append((format_labels(feature_mapping, feature_tuple), str(value)))
    data = np.array(data, dtype=object)
    values = data[:, 1].astype(float)
    feature_names = data[:, 0]

    return _draw_waterfall_plot(
        values,
        float(interaction_values.baseline_value),
        feature_names,
        max_display=max_display,
        show=show,
    )
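
The bar bookkeeping in `_draw_waterfall_plot` can be hard to follow between the matplotlib calls. The sketch below reproduces only that arithmetic in plain Python, under the assumption that no two values tie in magnitude. `waterfall_bars` is a hypothetical helper written for this illustration and is not part of shapiq.

```python
def waterfall_bars(values, base_value, labels, max_display=10):
    """Return (label, left, width) rows, ordered from the top bar down.

    Each bar spans from ``left`` to ``left + width``; a negative width
    means the bar points to the left.
    """
    num_features = min(max_display, len(values))
    # One row is reserved for the grouped remainder when values are cut off.
    num_individual = num_features if num_features == len(values) else num_features - 1

    # Largest absolute contribution first, as in the listing above.
    order = sorted(range(len(values)), key=lambda i: -abs(values[i]))

    # Start at the prediction f(x) and walk back towards the base value.
    loc = base_value + sum(values)
    rows = []
    for i in order[:num_individual]:
        loc -= values[i]
        rows.append((labels[i], loc, values[i]))

    if num_features < len(values):
        # Whatever distance is left to the base value belongs to the hidden features.
        rest = loc - base_value
        hidden = len(values) - num_individual
        rows.append((f"{hidden} other features", base_value, rest))
    return rows


rows = waterfall_bars([0.5, -2.0, 1.0, 0.25], 10.0, ["a", "b", "c", "d"], max_display=3)
for label, left, width in rows:
    print(label, left, width)
```

With these inputs the top bar ("b") ends at the prediction 9.75, and the grouped bar starts at the base value 10.0, so the rows chain together from the prediction back to the baseline.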