• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%
18499875864

push

github

web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed introduced bugs in game and interaction_values

* Pyright Save Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer. Approximator, Imputer and ExactComputer can either exist or not, depending on dynamic Type

* Fixed a bug caused by the refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step to reinstall Python on Windows due to issues with tkinter. The linked GitHub issue was solved. Doing this as a first try.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests. The prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.47
/src/shapiq/plot/beeswarm.py
1
"""Wrapper for the beeswarm plot from the ``shap`` package.
2

3
Note:
4
    The code and implementation were taken and adapted from the [SHAP package](https://github.com/shap/shap)
5
    which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).
6

7
"""
8

9
from __future__ import annotations
1✔
10

11
from typing import TYPE_CHECKING, Literal
1✔
12

13
import matplotlib.colors as mcolors
1✔
14
import matplotlib.pyplot as plt
1✔
15
import numpy as np
1✔
16
import pandas as pd
1✔
17
from matplotlib.patches import Rectangle
1✔
18

19
from shapiq.interaction_values import InteractionValues, aggregate_interaction_values
1✔
20

21
from .utils import abbreviate_feature_names
1✔
22

23
if TYPE_CHECKING:
1✔
NEW
24
    from collections.abc import Sequence
×
25

NEW
26
    from matplotlib.axes import Axes
×
NEW
27
    from matplotlib.figure import Figure
×
28

29

30
__all__ = ["beeswarm_plot"]
1✔
31

32

33
def _get_red_blue_cmap() -> mcolors.LinearSegmentedColormap:
    """Builds the blue-to-red colormap used to color points in the beeswarm plot.

    The map starts blue at 0.0, ends red at 1.0 and has one interior stop shared by all
    channels. Masked, over-range and under-range values are all rendered in a neutral gray.

    Returns:
        The ``"red_blue"`` colormap.
    """
    anchor = 0.494949494949495  # interior stop shared by every channel
    neutral_gray = [0.51615537, 0.51615111, 0.5161729]

    segment_data: dict[
        Literal["red", "green", "blue", "alpha"], Sequence[tuple[float, float, float]]
    ] = {
        "red": [
            (0.0, 0.0, 0.0),
            (anchor, 0.6035590338007161, 0.6035590338007161),
            (1.0, 1.0, 1.0),
        ],
        "green": [
            (0.0, 0.5433775692459107, 0.5433775692459107),
            (anchor, 0.14541587318267168, 0.14541587318267168),
            (1.0, 0.0, 0.0),
        ],
        "blue": [
            (0.0, 0.983379062301401, 0.983379062301401),
            (anchor, 0.6828490076357064, 0.6828490076357064),
            (1.0, 0.31796406298163893, 0.31796406298163893),
        ],
        # fully opaque over the whole range
        "alpha": [(0, 1.0, 1.0), (anchor, 1.0, 1.0), (1.0, 1.0, 1.0)],
    }

    colormap = mcolors.LinearSegmentedColormap("red_blue", segment_data)
    # use the same opaque gray for bad (masked), over-range and under-range values
    for assign_fallback in (colormap.set_bad, colormap.set_over, colormap.set_under):
        assign_fallback(list(neutral_gray), 1.0)
    return colormap
64

65

66
def _get_config(row_height: float) -> dict:
1✔
67
    """Returns the configuration for the beeswarm plot.
68

69
    Args:
70
        row_height: Height of each row in the plot.
71

72
    Returns:
73
        Configuration dictionary.
74
    """
75
    config_dict = {
1✔
76
        "dot_size": 10,
77
        "margin_y": 0.01,
78
        "color_nan": "#777777",
79
        "color_lines": "#cccccc",
80
        "color_rectangle": "#eeeeee",
81
        "alpha_rectangle": 0.5,
82
    }
83
    margin = max(-0.1875 * row_height + 0.3875, 0.15)
1✔
84
    margin_label = 0.5 - min(row_height / 3, 0.2)
1✔
85
    config_dict["margin_plot"] = margin
1✔
86
    config_dict["margin_label"] = margin_label
1✔
87
    config_dict["fontsize_ys"] = 10 if row_height <= 0.2 else 11
1✔
88
    return config_dict
1✔
89

90

91
def _beeswarm(interaction_values: np.ndarray, rng: np.random.Generator) -> np.ndarray:
1✔
92
    """Creates vertical offsets for a beeswarm plot.
93

94
    Args:
95
        interaction_values: Interaction values for a given feature.
96
        rng: Random number generator.
97

98
    Returns:
99
        Vertical offsets (ys) for each point.
100
    """
101
    num_interactions = len(interaction_values)
1✔
102
    nbins = 100
1✔
103
    quant = np.round(
1✔
104
        nbins
105
        * (interaction_values - np.min(interaction_values))
106
        / (np.max(interaction_values) - np.min(interaction_values) + 1e-9)
107
    )
108

109
    inds = np.argsort(quant + rng.uniform(-1e-6, 1e-6, num_interactions))
1✔
110

111
    layer = 0
1✔
112
    last_bin = -1
1✔
113
    ys = np.zeros(num_interactions)
1✔
114
    for ind in inds:
1✔
115
        if quant[ind] != last_bin:
1✔
116
            layer = 0
1✔
117
        ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
1✔
118
        layer += 1
1✔
119
        last_bin = quant[ind]
1✔
120
    return ys
1✔
121

122

123
def _calculate_range(num_sub_features: int, i: int, margin: float) -> tuple[float, float]:
1✔
124
    """Calculates the y-axis range for a given sub-feature index in a beeswarm plot.
125

126
    Args:
127
        num_sub_features: Total number of sub-features in the interaction.
128
        i: Index of the current sub-feature.
129
        margin: Margin to apply to the y-axis range.
130

131
    Returns:
132
        A tuple containing the minimum and maximum y-axis values for the sub-feature.
133
    """
134
    if num_sub_features > 1:
1✔
135
        if i == 0:
1✔
136
            y_min = margin / 2 - 0.5
1✔
137
            y_max = 0.5 - margin / 4
1✔
138
        elif i == num_sub_features - 1:
1✔
139
            y_min = margin / 4 - 0.5
1✔
140
            y_max = 0.5 - margin / 2
1✔
141
        else:
142
            y_min = margin / 4 - 0.5
1✔
143
            y_max = 0.5 - margin / 4
1✔
144
    else:
145
        y_min = margin / 2 - 0.5
1✔
146
        y_max = 0.5 - margin / 2
1✔
147
    return y_min, y_max
1✔
148

149

150
def beeswarm_plot(
    interaction_values_list: list[InteractionValues],
    data: pd.DataFrame | np.ndarray,
    *,
    max_display: int | None = 10,
    feature_names: list[str] | None = None,
    abbreviate: bool = True,
    alpha: float = 0.8,
    row_height: float = 0.4,
    ax: Axes | None = None,
    rng_seed: int | None = 42,
    show: bool = True,
) -> Axes | None:
    """Plots a beeswarm plot of SHAP-IQ interaction values. Based on the SHAP beeswarm plot[1]_.

    The beeswarm plot visualizes how the magnitude and direction of interaction effects are
    distributed across all samples in the data, revealing dependencies between the feature's
    value and the strength of the interaction.

    Interactions are ranked by the mean of their absolute values over all samples. Each
    displayed interaction gets one row per participating feature; the points in a row are
    colored by that feature's value (``NaN`` feature values are drawn in gray).

    Args:
        interaction_values_list: A list containing InteractionValues objects, one per row of
            ``data``.
        data: The input data used to compute the interaction values.
        max_display: Maximum number of interactions to display. If ``None``, all interactions
            are displayed. Defaults to 10.
        feature_names: Names of the features. If not given, feature indices will be used.
            Defaults to ``None``.
        abbreviate: Whether to abbreviate feature names. Defaults to ``True``.
        alpha: The transparency level for the plotted points, greater than 0 (transparent) and
            at most 1 (opaque). Defaults to 0.8.
        row_height: The height in inches allocated for each row on the plot. Only used when
            ``ax`` is ``None``; otherwise it is derived from the height of the figure that owns
            ``ax``. Defaults to 0.4.
        ax: ``Matplotlib Axes`` object to plot on. If ``None``, the current axes and figure of
            pyplot are used and the figure is resized to fit the plot.
        rng_seed: Random seed for reproducibility. Defaults to 42.
        show: Whether to show the plot. Defaults to ``True``. If ``False``, the function returns
            the axis of the plot.

    Returns:
        If ``show`` is ``False``, the function returns the axis of the plot. Otherwise, it returns
        ``None``.

    Raises:
        ValueError: If ``interaction_values_list`` is not a non-empty list, if its length does
            not match the number of rows in ``data``, if ``row_height`` is not positive, if
            ``alpha`` is outside ``(0, 1]``, if ``max_display`` is smaller than 1, if the number
            of feature names does not match the number of players, or if there is no non-empty
            interaction to plot.
        TypeError: If ``data`` is neither a pandas DataFrame nor a numpy array.

    References:
        .. [1] SHAP is available at https://github.com/shap/shap
    """
    if not isinstance(interaction_values_list, list) or len(interaction_values_list) == 0:
        error_message = "shap_interaction_values must be a non-empty list."
        raise ValueError(error_message)
    if not isinstance(data, (pd.DataFrame, np.ndarray)):
        error_message = f"data must be a pandas DataFrame or a numpy array. Got: {type(data)}."
        raise TypeError(error_message)
    if len(interaction_values_list) != len(data):
        error_message = "Length of shap_interaction_values must match number of rows in data."
        raise ValueError(error_message)
    if row_height <= 0:
        error_message = "row_height must be a positive value."
        raise ValueError(error_message)
    if alpha <= 0 or alpha > 1:
        error_message = "alpha must be between 0 and 1."
        raise ValueError(error_message)
    if max_display is not None and max_display < 1:
        # zero would leave nothing to plot; negative values would silently drop entries via slicing
        error_message = "max_display must be a positive integer or None."
        raise ValueError(error_message)

    n_samples = len(data)
    n_players = interaction_values_list[0].n_players

    if feature_names is not None:
        if abbreviate:
            feature_names = abbreviate_feature_names(feature_names)
    else:
        feature_names = ["F" + str(i) for i in range(n_players)]

    if len(feature_names) != n_players:
        error_message = "Length of feature_names must match n_players."
        raise ValueError(error_message)

    feature_mapping = dict(enumerate(feature_names))

    list_of_abs_interaction_values = [abs(iv) for iv in interaction_values_list]
    global_values: InteractionValues = aggregate_interaction_values(
        list_of_abs_interaction_values, aggregation="mean"
    )  # to match the order in bar plots

    # the empty interaction (baseline) is not plotted
    ranked_candidates = [(k, v) for k, v in global_values.interactions.items() if len(k) != 0]
    if not ranked_candidates:
        error_message = "interaction_values_list must contain at least one non-empty interaction."
        raise ValueError(error_message)
    interaction_keys, all_global_interaction_vals = zip(*ranked_candidates, strict=False)

    # Sort interactions by aggregated importance
    feature_order = np.argsort(all_global_interaction_vals)[::-1]
    if max_display is None:
        max_display = len(feature_order)
    num_interactions_to_display = min(max_display, len(feature_order))
    feature_order = feature_order[:num_interactions_to_display]

    interactions_to_plot = [interaction_keys[i] for i in feature_order]

    x_numpy = data.to_numpy(dtype=float) if isinstance(data, pd.DataFrame) else data.astype(float)

    # per displayed interaction: its value for every sample
    shap_values_dict = {
        interaction: np.array([sv.dict_values[interaction] for sv in interaction_values_list])
        for interaction in interactions_to_plot
    }

    total_sub_features = sum(len(inter) for inter in interactions_to_plot)
    if ax is None:
        fig_height = total_sub_features * row_height + 1.5
        # widen the figure with the longest feature label
        longest_label = max(
            len(feature_mapping[f]) for interaction in interactions_to_plot for f in interaction
        )
        fig_width = 8 + 0.3 * longest_label
        ax = plt.gca()
        fig = plt.gcf()
        fig.set_size_inches(fig_width, fig_height)
    else:
        fig: Figure = ax.get_figure()  # pyright: ignore[reportAssignmentType]. Axes will always be a figure as Subfigure would not provide get_size_inches()
        # the caller controls the figure size, so the row height follows from it
        row_height = (fig.get_size_inches()[1] - 1.5) / total_sub_features
    config_dict = _get_config(row_height)

    cmap = _get_red_blue_cmap()

    y_level = 0  # start plotting from the bottom
    tick_positions = []
    tick_texts = []
    h_lines = []  # horizontal lines between interaction groups
    rectangles = []  # (bottom_y, height) of the shaded background bands

    margin_label = config_dict["margin_label"]
    # iterate through interactions in reverse order for plotting (bottom-up)
    for interaction_index, interaction in enumerate(reversed(interactions_to_plot)):
        num_sub_features = len(interaction)

        # shade every other interaction group
        if interaction_index % 2 == 0:
            bottom_y = y_level - 0.5
            height = num_sub_features
            if bottom_y == -0.5:
                # extend the lowest band so it also covers the bottom axis margin
                bottom_y -= config_dict["margin_y"]
                height += config_dict["margin_y"]
            rectangles.append((bottom_y, height))

        # y tick labels: the feature names of the group separated by "x" labels, evenly spaced
        # around the vertical midpoint of the group
        group_midpoint_y = y_level + (num_sub_features - 1) / 2.0
        num_labels = num_sub_features + max(num_sub_features - 1, 0)
        half_label_extent = margin_label * (num_labels - 1) / 2
        positions = (
            np.linspace(
                group_midpoint_y - half_label_extent,
                group_midpoint_y + half_label_extent,
                num_labels,
            )
            if num_sub_features > 1
            else np.array([group_midpoint_y])
        )
        for i, label in enumerate(reversed(interaction)):
            # names sit on the even positions, the "x" separators on the odd ones
            tick_positions.append(positions[2 * i])
            tick_texts.append(feature_mapping[label])
            if i < num_sub_features - 1:
                tick_positions.append(positions[2 * i + 1])
                tick_texts.append("x")

        # add horizontal lines at both boundaries of every inner group
        if 0 < interaction_index < len(interactions_to_plot) - 1:
            h_lines.append(group_midpoint_y - num_sub_features / 2.0)
            h_lines.append(group_midpoint_y + num_sub_features / 2.0)

        current_shap_values = shap_values_dict[interaction]

        # calculate beeswarm offsets; they are shared by all rows of this interaction
        ys_raw = _beeswarm(current_shap_values, rng=np.random.default_rng(rng_seed))
        ys_raw_min = np.min(ys_raw)
        ys_raw_max = np.max(ys_raw)
        range_y = ys_raw_max - ys_raw_min if ys_raw_max != ys_raw_min else 1.0

        for i, sub_feature_idx in enumerate(interaction):
            # rescale the raw offsets into the band reserved for this row
            y_min, y_max = _calculate_range(num_sub_features, i, config_dict["margin_plot"])
            ys = y_min + (ys_raw - ys_raw_min) * (y_max - y_min) / range_y
            feature_values = x_numpy[:, sub_feature_idx]

            # nan handling - plotting as gray
            nan_mask = np.isnan(feature_values)
            valid_mask = ~nan_mask

            valid_feature_values = feature_values[valid_mask]
            if len(valid_feature_values) > 0:
                vmin = np.min(valid_feature_values)
                vmax = np.max(valid_feature_values)
            else:
                vmin = 0
                vmax = 1
            if vmin == vmax:
                # avoid a zero-width color range for constant features
                vmin -= 1e-9
                vmax += 1e-9

            ax.scatter(
                x=current_shap_values[nan_mask],
                y=y_level + ys[nan_mask],
                color=config_dict["color_nan"],
                s=config_dict["dot_size"],
                alpha=alpha * 0.5,
                linewidth=0,
                rasterized=n_samples > 500,
                zorder=2,
            )

            # valid points
            ax.scatter(
                x=current_shap_values[valid_mask],
                y=y_level + ys[valid_mask],
                c=feature_values[valid_mask],
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                s=config_dict["dot_size"],
                alpha=alpha,
                linewidth=0,
                rasterized=n_samples > 500,
                zorder=2,
            )
            y_level += 1

    # add horizontal grid lines between interaction groups (shared boundaries only once)
    for y_line in sorted(set(h_lines)):
        ax.axhline(
            y=y_line,
            color=config_dict["color_lines"],
            linestyle="--",
            linewidth=0.5,
            alpha=0.8,
            zorder=-1,
        )

    ax.xaxis.grid(
        visible=True,
        color=config_dict["color_lines"],
        linestyle="--",
        linewidth=0.5,
        alpha=0.8,
        zorder=-1,
    )

    ax.axvline(x=0, color="#999999", linestyle="-", linewidth=1, zorder=1)
    ax.set_axisbelow(True)

    ax.set_xlabel("SHAP-IQ Interaction Value (impact on model output)", fontsize=12)
    ax.set_ylabel("")

    ax.tick_params(axis="y", length=0)
    ax.tick_params(axis="x", labelsize=10)

    # the background bands span the full x-range of the plotted data
    x_left, x_right = ax.get_xlim()
    for bottom_y, height in rectangles:
        rect = Rectangle(
            (x_left, bottom_y),
            x_right - x_left,
            height,
            facecolor=config_dict["color_rectangle"],
            edgecolor=config_dict["color_rectangle"],
            alpha=config_dict["alpha_rectangle"],
            zorder=-3,
        )
        ax.add_patch(rect)

    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_texts, fontsize=config_dict["fontsize_ys"])

    ax.set_ylim(-0.5 - config_dict["margin_y"], y_level - 0.5 + config_dict["margin_y"])

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)

    # colorbar that only labels the two ends of the feature-value scale
    m = plt.cm.ScalarMappable(cmap=cmap)
    m.set_array([0, 1])
    cb = fig.colorbar(m, ax=ax, ticks=[0, 1], aspect=80)
    cb.set_ticklabels(["Low", "High"])
    cb.set_label("Feature value", size=12, labelpad=0)
    cb.ax.tick_params(labelsize=11, length=0)
    cb.set_alpha(1)
    cb.outline.set_visible(False)  # pyright: ignore[reportCallIssue]. TODO(advueu963): This seems to work but statically not safe

    # lay out the figure that owns ``ax``; pyplot's current figure may be a different one when the
    # caller passed their own axes
    fig.tight_layout(rect=(0, 0, 0.95, 1))

    if not show:
        return ax
    plt.show()
    return None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc