• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18449618793

12 Oct 2025 09:16PM UTC coverage: 93.266% (-0.6%) from 93.845%
18449618793

Pull #430

github

web-flow
Merge 4a26a5ad3 into dede390c9
Pull Request #430: Enhance type safety and fix bugs across the codebase

278 of 326 new or added lines in 46 files covered. (85.28%)

12 existing lines in 9 files now uncovered.

4986 of 5346 relevant lines covered (93.27%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.52
/src/shapiq/plot/beeswarm.py
1
"""Wrapper for the beeswarm plot from the ``shap`` package.
2

3
Note:
4
    Code and implementation was taken and adapted from the [SHAP package](https://github.com/shap/shap)
5
    which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).
6

7
"""
8

9
from __future__ import annotations
1✔
10

11
from typing import TYPE_CHECKING, Literal
1✔
12

13
import matplotlib.colors as mcolors
1✔
14
import matplotlib.pyplot as plt
1✔
15
import numpy as np
1✔
16
import pandas as pd
1✔
17
from matplotlib.patches import Rectangle
1✔
18

19
from shapiq.interaction_values import InteractionValues, aggregate_interaction_values
1✔
20

21
from .utils import abbreviate_feature_names
1✔
22

23
if TYPE_CHECKING:
1✔
NEW
24
    from collections.abc import Sequence
×
25

NEW
26
    from matplotlib.axes import Axes
×
NEW
27
    from matplotlib.figure import Figure
×
28

29

30
# Only the plotting entry point is exported; the underscore-prefixed helpers are private.
__all__ = ["beeswarm_plot"]
1✔
31

32

33
def _get_red_blue_cmap() -> mcolors.LinearSegmentedColormap:
    """Build the blue-to-red colormap used to encode feature values.

    Low values map to blue and high values to red, passing through a purple tone at the
    midpoint. Invalid (NaN), over-range and under-range values are all drawn in an opaque
    neutral gray.

    Returns:
        The ``"red_blue"`` LinearSegmentedColormap.
    """
    midpoint = 0.494949494949495
    # (position, level) pairs per channel. Every channel is continuous, so the level is used
    # on both sides of each anchor when the segment data is assembled below.
    channel_points: dict[Literal["red", "green", "blue", "alpha"], list[tuple[float, float]]] = {
        "red": [(0.0, 0.0), (midpoint, 0.6035590338007161), (1.0, 1.0)],
        "green": [(0.0, 0.5433775692459107), (midpoint, 0.14541587318267168), (1.0, 0.0)],
        "blue": [
            (0.0, 0.983379062301401),
            (midpoint, 0.6828490076357064),
            (1.0, 0.31796406298163893),
        ],
        "alpha": [(0.0, 1.0), (midpoint, 1.0), (1.0, 1.0)],
    }
    segment_data: dict[
        Literal["red", "green", "blue", "alpha"], Sequence[tuple[float, float, float]]
    ] = {
        channel: [(position, level, level) for position, level in points]
        for channel, points in channel_points.items()
    }
    colormap = mcolors.LinearSegmentedColormap("red_blue", segment_data)

    neutral_gray = np.array([0.51615537, 0.51615111, 0.5161729]).tolist()
    for assign_color in (colormap.set_bad, colormap.set_over, colormap.set_under):
        assign_color(neutral_gray, 1.0)
    return colormap
64

65

66
def _get_config(row_height: float) -> dict:
    """Collect the styling and layout constants of the beeswarm plot.

    Three entries depend on ``row_height``:

    * ``margin_plot`` shrinks linearly as rows get taller, but never below ``0.15``.
    * ``margin_label`` is ``0.5`` minus a third of the row height, with that reduction
      capped at ``0.2``.
    * ``fontsize_ys`` is ``10`` for rows up to ``0.2`` high and ``11`` otherwise.

    Args:
        row_height: Height of each row in the plot.

    Returns:
        The configuration dictionary.
    """
    plot_margin = max(-0.1875 * row_height + 0.3875, 0.15)
    label_margin = 0.5 - min(row_height / 3, 0.2)
    tick_fontsize = 10 if row_height <= 0.2 else 11
    return {
        "dot_size": 10,
        "margin_y": 0.01,
        "color_nan": "#777777",
        "color_lines": "#cccccc",
        "color_rectangle": "#eeeeee",
        "alpha_rectangle": 0.5,
        "margin_plot": plot_margin,
        "margin_label": label_margin,
        "fontsize_ys": tick_fontsize,
    }
89

90

91
def _beeswarm(interaction_values: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Compute unscaled vertical offsets that spread overlapping points into a swarm.

    The values are quantized into 100 bins over their range. Inside each bin, points are
    stacked alternately at offsets 0, +1, -1, +2, -2, ... A tiny random jitter, drawn from
    ``rng``, randomizes the stacking order within a bin. The result is therefore
    reproducible for a seeded generator.

    Args:
        interaction_values: Interaction values for a given feature.
        rng: Random number generator.

    Returns:
        Vertical offsets (ys) for each point, as integer-valued floats.
    """
    n_points = len(interaction_values)
    n_bins = 100
    lowest = np.min(interaction_values)
    spread = np.max(interaction_values) - lowest + 1e-9
    bin_ids = np.round(n_bins * (interaction_values - lowest) / spread)

    # The jitter is far smaller than the distance between bins, so it only reorders ties.
    jitter = rng.uniform(-1e-6, 1e-6, n_points)
    visit_order = np.argsort(bin_ids + jitter)

    offsets = np.zeros(n_points)
    current_bin = -1
    slot = 0
    for point in visit_order:
        if bin_ids[point] != current_bin:
            # Entering a new bin: restart stacking at the center line.
            current_bin = bin_ids[point]
            slot = 0
        offsets[point] = np.ceil(slot / 2) * ((slot % 2) * 2 - 1)
        slot += 1
    return offsets
121

122

123
def _calculate_range(num_sub_features: int, i: int, margin: float) -> tuple[float, float]:
    """Return the vertical extent ``(y_min, y_max)`` of one row within an interaction group.

    The extent is relative to the row's center. A side of the row that lies on the outer
    edge of its group keeps a ``margin / 2`` gap. A side that faces another row of the same
    group keeps only ``margin / 4``, which pulls the rows of a group closer together. A
    single-row group keeps ``margin / 2`` on both sides.

    Args:
        num_sub_features: Total number of sub-features in the interaction.
        i: Position of the row inside the group, counted from the bottom (0).
        margin: Margin to apply to the y-axis range.

    Returns:
        A tuple containing the minimum and maximum y-axis values for the row.
    """
    if num_sub_features <= 1:
        return margin / 2 - 0.5, 0.5 - margin / 2

    lower_gap = margin / 2 if i == 0 else margin / 4
    upper_gap = margin / 2 if i == num_sub_features - 1 else margin / 4
    return lower_gap - 0.5, 0.5 - upper_gap
148

149

150
def _validate_beeswarm_inputs(
    interaction_values_list: list[InteractionValues],
    data: pd.DataFrame | np.ndarray,
    *,
    max_display: int | None,
    row_height: float,
    alpha: float,
) -> None:
    """Validates the arguments of ``beeswarm_plot``.

    Raises:
        ValueError: If ``interaction_values_list`` is not a non-empty list, its length does not
            match the number of rows in ``data``, ``row_height`` is not positive, ``alpha`` is
            not in ``(0, 1]``, or ``max_display`` is smaller than 1.
        TypeError: If ``data`` is neither a pandas DataFrame nor a numpy array.
    """
    if not isinstance(interaction_values_list, list) or len(interaction_values_list) == 0:
        error_message = "shap_interaction_values must be a non-empty list."
        raise ValueError(error_message)
    if not isinstance(data, (pd.DataFrame, np.ndarray)):
        error_message = f"data must be a pandas DataFrame or a numpy array. Got: {type(data)}."
        raise TypeError(error_message)
    if len(interaction_values_list) != len(data):
        error_message = "Length of shap_interaction_values must match number of rows in data."
        raise ValueError(error_message)
    if row_height <= 0:
        error_message = "row_height must be a positive value."
        raise ValueError(error_message)
    if alpha <= 0 or alpha > 1:
        error_message = "alpha must be between 0 and 1."
        raise ValueError(error_message)
    # Without this check a negative value slices interactions away silently (-1 drops the
    # least important one), and 0 fails later with an unrelated error.
    if max_display is not None and max_display < 1:
        error_message = f"max_display must be a positive integer or None. Got: {max_display}."
        raise ValueError(error_message)


def _select_interactions(
    interaction_values_list: list[InteractionValues], max_display: int | None
) -> list[tuple[int, ...]]:
    """Returns the interactions to plot, ordered from most to least important.

    Importance is the mean absolute interaction value over all samples, which matches the
    order used in the bar plots. The empty interaction ``()`` (the baseline) is never returned.

    Args:
        interaction_values_list: One InteractionValues object per sample.
        max_display: Maximum number of interactions to return, or ``None`` for all of them.

    Returns:
        The selected interactions, most important first.
    """
    global_values: InteractionValues = aggregate_interaction_values(
        [abs(iv) for iv in interaction_values_list], aggregation="mean"
    )
    lookup = global_values.interaction_lookup
    all_values = global_values.values  # noqa: PD011  # since ruff thinks this is a dataframe
    # Read every value through its lookup index and drop the baseline wherever it appears,
    # rather than assuming it is stored first and that keys and values align by position.
    interaction_keys = [key for key in lookup if key != ()]
    importances = np.array([all_values[lookup[key]] for key in interaction_keys], dtype=float)
    feature_order = np.argsort(importances)[::-1]
    if max_display is not None:
        feature_order = feature_order[:max_display]
    return [interaction_keys[i] for i in feature_order]


def _interaction_tick_labels(
    interaction: tuple[int, ...],
    y_level: float,
    margin_label: float,
    feature_mapping: dict[int, str],
) -> tuple[list[float], list[str]]:
    """Computes the y-tick positions and texts that label one interaction group.

    The group occupies rows ``y_level`` to ``y_level + len(interaction) - 1``. Feature names
    are placed bottom-up in reversed interaction order and separated by ``"x"`` ticks. All
    ticks are spaced ``margin_label`` apart around the group's midpoint, so the interaction
    reads in its natural order from top to bottom.

    Returns:
        The tick positions and the matching tick texts.
    """
    num_sub_features = len(interaction)
    group_midpoint_y = y_level + (num_sub_features - 1) / 2.0
    texts: list[str] = []
    for feature in reversed(interaction):
        if texts:
            texts.append("x")
        texts.append(feature_mapping[feature])
    if num_sub_features <= 1:
        return [group_midpoint_y], texts
    half_span = margin_label * (len(texts) - 1) / 2
    positions = np.linspace(group_midpoint_y - half_span, group_midpoint_y + half_span, len(texts))
    return positions.tolist(), texts


def _scatter_feature_row(
    ax: Axes,
    shap_values: np.ndarray,
    y_positions: np.ndarray,
    feature_values: np.ndarray,
    *,
    cmap: mcolors.Colormap,
    config_dict: dict,
    alpha: float,
    rasterized: bool,
) -> None:
    """Draws one beeswarm row, coloring every point by the value of a single feature.

    Points whose feature value is NaN are drawn in ``config_dict["color_nan"]`` at half of
    ``alpha``. All other points are colored with ``cmap``, scaled to the min/max of the
    non-NaN feature values (or to ``[0, 1]`` if every value is NaN).
    """
    nan_mask = np.isnan(feature_values)
    valid_mask = ~nan_mask

    valid_feature_values = feature_values[valid_mask]
    if len(valid_feature_values) > 0:
        vmin = np.min(valid_feature_values)
        vmax = np.max(valid_feature_values)
    else:
        vmin = 0
        vmax = 1
    if vmin == vmax:
        # Widen a degenerate range so the color scale never has zero width.
        vmin -= 1e-9
        vmax += 1e-9

    ax.scatter(
        x=shap_values[nan_mask],
        y=y_positions[nan_mask],
        color=config_dict["color_nan"],
        s=config_dict["dot_size"],
        alpha=alpha * 0.5,
        linewidth=0,
        rasterized=rasterized,
        zorder=2,
    )
    ax.scatter(
        x=shap_values[valid_mask],
        y=y_positions[valid_mask],
        c=feature_values[valid_mask],
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        s=config_dict["dot_size"],
        alpha=alpha,
        linewidth=0,
        rasterized=rasterized,
        zorder=2,
    )


def _add_colorbar(fig: Figure, ax: Axes, cmap: mcolors.Colormap) -> None:
    """Adds the "Low"/"High" feature-value colorbar next to ``ax``."""
    m = plt.cm.ScalarMappable(cmap=cmap)
    m.set_array([0, 1])
    cb = fig.colorbar(m, ax=ax, ticks=[0, 1], aspect=80)
    cb.set_ticklabels(["Low", "High"])
    cb.set_label("Feature value", size=12, labelpad=0)
    cb.ax.tick_params(labelsize=11, length=0)
    cb.set_alpha(1)
    cb.outline.set_visible(False)  # pyright: ignore[reportCallIssue]. TODO(advueu963): This seems to work but statically not safe


def beeswarm_plot(
    interaction_values_list: list[InteractionValues],
    data: pd.DataFrame | np.ndarray,
    *,
    max_display: int | None = 10,
    feature_names: list[str] | None = None,
    abbreviate: bool = True,
    alpha: float = 0.8,
    row_height: float = 0.4,
    ax: Axes | None = None,
    rng_seed: int | None = 42,
    show: bool = True,
) -> Axes | None:
    """Plots a beeswarm plot of SHAP-IQ interaction values. Based on the SHAP beeswarm plot[1]_.

    The beeswarm plot visualizes how the magnitude and direction of interaction effects are
    distributed across all samples in the data. It reveals dependencies between the feature's
    value and the strength of the interaction. Every interaction gets one row per participating
    feature. Each row shows the same interaction values, colored by that feature's value.

    Args:
        interaction_values_list: A list containing InteractionValues objects, one per row of
            ``data``.
        data: The input data used to compute the interaction values.
        max_display: Maximum number of interactions to display. ``None`` displays all of them.
            Defaults to 10.
        feature_names: Names of the features. If not given, feature indices will be used.
            Defaults to ``None``.
        abbreviate: Whether to abbreviate feature names. Defaults to ``True``.
        alpha: The transparency level for the plotted points, ranging from 0 (transparent) to
            1 (opaque). Defaults to 0.8.
        row_height: The height in inches allocated for each row on the plot. Only used when
            ``ax`` is ``None``. With a given ``ax``, the row height is derived from the height
            of ``ax``'s figure. Defaults to 0.4.
        ax: ``Matplotlib Axes`` object to plot on. If ``None``, a new figure and axes will be
            created.
        rng_seed: Random seed for reproducibility. Defaults to 42.
        show: Whether to show the plot. Defaults to ``True``. If ``False``, the function returns
            the axis of the plot.

    Returns:
        If ``show`` is ``False``, the function returns the axis of the plot. Otherwise, it
        returns ``None``.

    Raises:
        ValueError: If an argument is invalid, or if there is no interaction other than the
            baseline value to plot.
        TypeError: If ``data`` is neither a pandas DataFrame nor a numpy array.

    References:
        .. [1] SHAP is available at https://github.com/shap/shap
    """
    _validate_beeswarm_inputs(
        interaction_values_list,
        data,
        max_display=max_display,
        row_height=row_height,
        alpha=alpha,
    )

    n_samples = len(data)
    n_players = interaction_values_list[0].n_players

    if feature_names is None:
        feature_names = ["F" + str(i) for i in range(n_players)]
    elif abbreviate:
        feature_names = abbreviate_feature_names(feature_names)
    if len(feature_names) != n_players:
        error_message = "Length of feature_names must match n_players."
        raise ValueError(error_message)
    feature_mapping = dict(enumerate(feature_names))

    interactions_to_plot = _select_interactions(interaction_values_list, max_display)
    if not interactions_to_plot:
        error_message = "There are no interactions besides the baseline value to plot."
        raise ValueError(error_message)

    x_numpy = data.to_numpy(dtype=float) if isinstance(data, pd.DataFrame) else data.astype(float)

    # Fetch every sample's mapping once. ``dict_values`` may be rebuilt on each access, so
    # reading it in the inner comprehension would repeat that work per plotted interaction.
    sample_dicts = [sv.dict_values for sv in interaction_values_list]
    shap_values_dict = {
        interaction: np.array([sample_dict[interaction] for sample_dict in sample_dicts])
        for interaction in interactions_to_plot
    }

    total_sub_features = sum(len(interaction) for interaction in interactions_to_plot)
    if ax is None:
        fig_height = total_sub_features * row_height + 1.5
        longest_name = max(
            len(feature_mapping[f]) for interaction in interactions_to_plot for f in interaction
        )
        fig_width = 8 + 0.3 * longest_name
        ax = plt.gca()
        fig = plt.gcf()
        fig.set_size_inches(fig_width, fig_height)
    else:
        fig: Figure = ax.get_figure()  # pyright: ignore[reportAssignmentType]. Axes will always be a figure as Subfigure would not provide get_size_inches()
        # Derive the row height from the existing figure. Clamp it at 0 because figures shorter
        # than the 1.5 inches reserved for decorations would give a negative height, for which
        # the margin formulas in _get_config are not meant.
        row_height = max((fig.get_size_inches()[1] - 1.5) / total_sub_features, 0.0)
    config_dict = _get_config(row_height)

    cmap = _get_red_blue_cmap()
    margin_y = config_dict["margin_y"]
    # Presumably to keep vector (PDF/SVG) output small for large sample counts.
    rasterized = n_samples > 500

    y_level = 0  # rows are stacked bottom-up, one per sub-feature
    y_tick_positions: list[float] = []
    y_tick_texts: list[str] = []
    h_lines: list[float] = []  # dashed separators between adjacent interaction groups
    rectangles: list[tuple[float, float]] = []  # (bottom, height) of shaded groups

    # The most important interaction goes on top, so iterate in reverse and build bottom-up.
    for interaction_index, interaction in enumerate(reversed(interactions_to_plot)):
        num_sub_features = len(interaction)

        # Shade every other group. The bottom-most group also covers the lower axis margin.
        if interaction_index % 2 == 0:
            bottom_y = y_level - 0.5
            height = num_sub_features
            if interaction_index == 0:
                bottom_y -= margin_y
                height += margin_y
            rectangles.append((bottom_y, height))

        positions, texts = _interaction_tick_labels(
            interaction, y_level, config_dict["margin_label"], feature_mapping
        )
        y_tick_positions.extend(positions)
        y_tick_texts.extend(texts)

        # One separator at every boundary between two adjacent groups, including when only
        # two groups are shown.
        if interaction_index > 0:
            h_lines.append(y_level - 0.5)

        current_shap_values = shap_values_dict[interaction]

        # Beeswarm offsets are the same for every row of this group; normalize them once.
        ys_raw = _beeswarm(current_shap_values, rng=np.random.default_rng(rng_seed))
        raw_min = np.min(ys_raw)
        raw_max = np.max(ys_raw)
        range_y = raw_max - raw_min if raw_max != raw_min else 1.0
        offsets = ys_raw - raw_min

        # Fill rows bottom-up in reversed interaction order. _interaction_tick_labels places
        # the feature names in the same order, so every row sits next to its own label.
        for row_index, sub_feature_idx in enumerate(reversed(interaction)):
            y_min, y_max = _calculate_range(num_sub_features, row_index, config_dict["margin_plot"])
            ys = y_min + offsets * (y_max - y_min) / range_y
            _scatter_feature_row(
                ax,
                current_shap_values,
                y_level + ys,
                x_numpy[:, sub_feature_idx],
                cmap=cmap,
                config_dict=config_dict,
                alpha=alpha,
                rasterized=rasterized,
            )
            y_level += 1

    for y_line in h_lines:
        ax.axhline(
            y=y_line,
            color=config_dict["color_lines"],
            linestyle="--",
            linewidth=0.5,
            alpha=0.8,
            zorder=-1,
        )

    ax.xaxis.grid(
        visible=True,
        color=config_dict["color_lines"],
        linestyle="--",
        linewidth=0.5,
        alpha=0.8,
        zorder=-1,
    )

    ax.axvline(x=0, color="#999999", linestyle="-", linewidth=1, zorder=1)
    ax.set_axisbelow(True)

    ax.set_xlabel("SHAP-IQ Interaction Value (impact on model output)", fontsize=12)
    ax.set_ylabel("")

    ax.tick_params(axis="y", length=0)
    ax.tick_params(axis="x", labelsize=10)

    # The shading spans the x-limits established by everything drawn above.
    x_left, x_right = ax.get_xlim()
    for bottom_y, height in rectangles:
        rect = Rectangle(
            (x_left, bottom_y),
            x_right - x_left,
            height,
            facecolor=config_dict["color_rectangle"],
            edgecolor=config_dict["color_rectangle"],
            alpha=config_dict["alpha_rectangle"],
            zorder=-3,
        )
        ax.add_patch(rect)

    ax.set_yticks(y_tick_positions)
    ax.set_yticklabels(y_tick_texts, fontsize=config_dict["fontsize_ys"])

    ax.set_ylim(-0.5 - margin_y, y_level - 0.5 + margin_y)

    for side in ("top", "right", "left"):
        ax.spines[side].set_visible(False)

    _add_colorbar(fig, ax, cmap)

    # Lay out the figure that owns ``ax``. plt.tight_layout would act on the *current* figure,
    # which is a different one when the caller passes an axis from another figure.
    fig.tight_layout(rect=(0, 0, 0.95, 1))

    if not show:
        return ax
    plt.show()
    return None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc