
mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%

push

github

web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed introduced bugs in game and interaction_values

* Pyright Safe Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer: Approximator, Imputer and ExactComputer can either exist or not, depending on the dynamic type

* Bug fix caused through refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step to reinstall Python on Windows due to issues with tkinter. The linked GitHub issue has been resolved. Doing this as a first try.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests. The prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
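The bullet about introducing a generic `TypeVar` in the approximator refers to a standard Python typing pattern for removing `# type: ignore` comments. A minimal sketch of that pattern follows; every class and function name in it is hypothetical and it does not reproduce shapiq's actual `Approximator` code.

```python
from typing import Generic, TypeVar

# Hypothetical type parameter for whatever result an approximator produces.
ResultT = TypeVar("ResultT")


class BaseApproximator(Generic[ResultT]):
    """Hypothetical base class: each subclass fixes ResultT, so a static checker
    such as Pyright can infer precise return types without ignore comments."""

    def approximate(self, budget: int) -> ResultT:
        raise NotImplementedError


class CountApproximator(BaseApproximator[int]):
    def approximate(self, budget: int) -> int:
        return budget * 2


class LabelApproximator(BaseApproximator[str]):
    def approximate(self, budget: int) -> str:
        return f"budget={budget}"


def run(approximator: BaseApproximator[ResultT], budget: int) -> ResultT:
    # The inferred return type follows the subclass's type argument.
    return approximator.approximate(budget)
```

A checker infers `run(CountApproximator(), 5)` as `int` and `run(LabelApproximator(), 5)` as `str`, which is what makes casts or ignores at the call sites unnecessary.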

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line
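The summary figures above are plain ratios of covered to relevant lines. The following check of that arithmetic uses only the counts reported on this page:

```python
def coverage_percent(covered: int, relevant: int) -> float:
    """Return covered / relevant as a percentage."""
    return covered / relevant * 100


# Counts taken from the build summary above.
new_lines = coverage_percent(304, 360)    # new or added lines in this push
overall = coverage_percent(4987, 5374)    # all relevant lines in the build
previous = 93.522                         # coverage of the previous build, in %

print(f"new lines: {new_lines:.2f}%")         # new lines: 84.44%
print(f"overall:   {overall:.3f}%")           # overall:   92.799%
print(f"change:    {overall - previous:+.1f}%")  # change:    -0.7%
```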

Source File

/src/shapiq/plot/force.py (file coverage: 95.95%)
"""Wrapper for the force plot from the ``shap`` package.

Note:
    Code and implementation were taken and adapted from the [SHAP package](https://github.com/shap/shap)
    which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).

"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import lines
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.font_manager import FontProperties
from matplotlib.patches import PathPatch, Polygon
from matplotlib.path import Path

from .utils import abbreviate_feature_names, format_labels

if TYPE_CHECKING:
    from matplotlib.axes import Axes  # coverage: new line, not covered
    from matplotlib.figure import Figure  # coverage: new line, not covered

    from shapiq.interaction_values import InteractionValues  # coverage: newly uncovered


__all__ = ["force_plot"]


def _create_bars(
    out_value: float,
    features: np.ndarray,
    feature_type: str,
    width_separators: float,
    width_bar: float,
) -> tuple[list, list]:
    """Create the bar segments and separators for one side of the force plot.

    Args:
        out_value: The output value at which the bars start.
        features: Rows holding the cumulative plot position (column 0) and the label
            (column 1) of each feature, as returned by ``_split_features``.
        feature_type: ``"positive"`` or ``"negative"``.
        width_separators: Horizontal size of the arrow-shaped separators.
        width_bar: Height of the bars.

    Returns:
        A tuple containing the list of bar polygons and the list of separator polygons.
    """
    rectangle_list = []
    separator_list = []

    pre_val = out_value
    for index, feature_iteration in zip(range(len(features)), features, strict=False):
        if feature_type == "positive":
            left_bound = float(feature_iteration[0])
            right_bound = pre_val
            pre_val = left_bound

            separator_indent = np.abs(width_separators)
            separator_pos = left_bound
            colors = ["#FF0D57", "#FFC3D5"]
        else:
            left_bound = pre_val
            right_bound = float(feature_iteration[0])
            pre_val = right_bound

            separator_indent = -np.abs(width_separators)
            separator_pos = right_bound
            colors = ["#1E88E5", "#D1E6FA"]

        # Create rectangle
        if index == 0:
            if feature_type == "positive":
                points_rectangle = [
                    [left_bound, 0],
                    [right_bound, 0],
                    [right_bound, width_bar],
                    [left_bound, width_bar],
                    [left_bound + separator_indent, (width_bar / 2)],
                ]
            else:
                points_rectangle = [
                    [right_bound, 0],
                    [left_bound, 0],
                    [left_bound, width_bar],
                    [right_bound, width_bar],
                    [right_bound + separator_indent, (width_bar / 2)],
                ]

        else:
            points_rectangle = [
                [left_bound, 0],
                [right_bound, 0],
                [right_bound + separator_indent * 0.90, (width_bar / 2)],
                [right_bound, width_bar],
                [left_bound, width_bar],
                [left_bound + separator_indent * 0.90, (width_bar / 2)],
            ]

        line = Polygon(
            points_rectangle,
            closed=True,
            fill=True,
            facecolor=colors[0],
            linewidth=0,
        )
        rectangle_list += [line]

        # Create separator
        points_separator = [
            [separator_pos, 0],
            [separator_pos + separator_indent, (width_bar / 2)],
            [separator_pos, width_bar],
        ]

        line = Polygon(points_separator, closed=False, fill=None, edgecolor=colors[1], lw=3)
        separator_list += [line]

    return rectangle_list, separator_list


def _add_labels(
    fig: Figure,
    ax: Axes,
    out_value: float,
    features: np.ndarray,
    feature_type: str,
    offset_text: float,
    total_effect: float = 0,
    min_perc: float = 0.05,
    text_rotation: float = 0,
) -> tuple[Figure, Axes]:
    """Add labels to the plot.

    Args:
        fig: Figure of the plot.
        ax: Axes of the plot.
        out_value: The output value.
        features: The values and names of the features.
        feature_type: Whether the features are ``"positive"`` or ``"negative"``.
        offset_text: Horizontal offset between the feature labels and their bars.
        total_effect: Total value of all features. Used to filter out features that do not
            contribute at least ``min_perc`` to the total effect. Defaults to ``0``, indicating
            that all features are labeled.
        min_perc: Minimal fraction of the total effect that a feature must contribute to be
            labeled. Defaults to ``0.05``.
        text_rotation: Rotation of the label text in degrees. Defaults to ``0``.

    Returns:
        The figure and axes with the labels added.
    """
    start_text = out_value
    pre_val = out_value

    # Define variables specific to positive and negative effect features
    if feature_type == "positive":
        colors = ["#FF0D57", "#FFC3D5"]
        alignment = "right"
        sign = 1
    else:
        colors = ["#1E88E5", "#D1E6FA"]
        alignment = "left"
        sign = -1

    # Draw initial line
    if feature_type == "positive":
        x, y = np.array([[pre_val, pre_val], [0, -0.18]])
        line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0])
        line.set_clip_on(False)
        ax.add_line(line)
        start_text = pre_val

    box_end = out_value
    val = out_value
    for feature in features:
        # Stop at the first feature that contributes less than min_perc of the total effect;
        # features are sorted by decreasing magnitude, so no later feature contributes more
        feature_contribution = np.abs(float(feature[0]) - pre_val) / np.abs(total_effect)
        if feature_contribution < min_perc:
            break

        # Compute value for current feature
        val = float(feature[0])

        # Draw labels.
        text = feature[1]

        va_alignment = "top" if text_rotation != 0 else "baseline"

        text_out_val = plt.text(
            start_text - sign * offset_text,
            -0.15,
            text,
            fontsize=12,
            color=colors[0],
            horizontalalignment=alignment,
            va=va_alignment,
            rotation=text_rotation,
        )
        text_out_val.set_bbox({"facecolor": "none", "edgecolor": "none"})

        # We need to draw the plot to be able to get the size of the
        # text box
        fig.canvas.draw()
        box_size = text_out_val.get_bbox_patch().get_extents().transformed(ax.transData.inverted())  # pyright: ignore[reportOptionalMemberAccess]
        if feature_type == "positive":
            box_end_ = box_size.get_points()[0][0]
        else:
            box_end_ = box_size.get_points()[1][0]

        # Create end line
        if (sign * box_end_) > (sign * val):
            x, y = np.array([[val, val], [0, -0.18]])
            line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0])
            line.set_clip_on(False)
            ax.add_line(line)
            start_text = val
            box_end = val

        else:
            box_end = box_end_ - sign * offset_text
            x, y = np.array([[val, box_end, box_end], [0, -0.08, -0.18]])
            line = lines.Line2D(x, y, lw=1.0, alpha=0.5, color=colors[0])
            line.set_clip_on(False)
            ax.add_line(line)
            start_text = box_end

        # Update previous value
        pre_val = float(feature[0])

    # Create line for labels
    extent_shading = (out_value, box_end, 0, -0.31)
    path = [
        [out_value, 0],
        [pre_val, 0],
        [box_end, -0.08],
        [box_end, -0.2],
        [out_value, -0.2],
        [out_value, 0],
    ]

    path = Path(path)
    patch = PathPatch(path, facecolor="none", edgecolor="none")
    ax.add_patch(patch)

    # Extend axis if needed
    lower_lim, upper_lim = ax.get_xlim()
    if box_end < lower_lim:
        ax.set_xlim(box_end, upper_lim)  # coverage: not covered

    if box_end > upper_lim:
        ax.set_xlim(lower_lim, box_end)  # coverage: not covered

    # Create shading
    if feature_type == "positive":
        colors = np.array([(255, 13, 87), (255, 255, 255)]) / 255.0
    else:
        colors = np.array([(30, 136, 229), (255, 255, 255)]) / 255.0
    cm = LinearSegmentedColormap.from_list("cm", colors)

    _, z2 = np.meshgrid(np.linspace(0, 10), np.linspace(-10, 10))
    im = plt.imshow(
        z2,
        interpolation="quadric",
        cmap=cm,
        vmax=0.01,
        alpha=0.3,
        origin="lower",
        extent=extent_shading,
        clip_path=patch,
        clip_on=True,
        aspect="auto",
    )
    im.set_clip_path(patch)

    return fig, ax


def _add_output_element(out_name: str, out_value: float, ax: Axes) -> None:
    """Add a grey line and labels indicating the output value to the plot.

    Args:
        out_name: Name of the output value.
        out_value: Value of the output.
        ax: Axes of the plot.

    Returns: None

    """
    # Add output value
    x, y = np.array([[out_value, out_value], [0, 0.24]])
    line = lines.Line2D(x, y, lw=2.0, color="#F2F2F2")
    line.set_clip_on(False)
    ax.add_line(line)

    font0 = FontProperties()
    font = font0.copy()
    font.set_weight("bold")
    text_out_val = plt.text(
        out_value,
        0.25,
        f"{out_value:.2f}",
        fontproperties=font,
        fontsize=14,
        horizontalalignment="center",
    )
    text_out_val.set_bbox({"facecolor": "white", "edgecolor": "white"})

    text_out_val = plt.text(
        out_value,
        0.33,
        out_name,
        fontsize=12,
        alpha=0.5,
        horizontalalignment="center",
    )
    text_out_val.set_bbox({"facecolor": "white", "edgecolor": "white"})


def _add_base_value(base_value: float, ax: Axes) -> None:
    """Add base value to the plot.

    Args:
        base_value: The base value of the game.
        ax: Axes of the plot.

    Returns: None

    """
    x, y = np.array([[base_value, base_value], [0.13, 0.25]])
    line = lines.Line2D(x, y, lw=2.0, color="#F2F2F2")
    line.set_clip_on(False)
    ax.add_line(line)

    text_out_val = ax.text(
        base_value,
        0.25,
        "base value",
        fontsize=12,
        alpha=1,
        horizontalalignment="center",
    )
    text_out_val.set_bbox({"facecolor": "white", "edgecolor": "white"})


def update_axis_limits(
    ax: Axes,
    total_pos: float,
    pos_features: np.ndarray,
    total_neg: float,
    neg_features: np.ndarray,
    base_value: float,
    out_value: float,
) -> None:
    """Adjust the axis limits of the plot according to values.

    Args:
        ax: Axes of the plot.
        total_pos: Value of the total positive features.
        pos_features: Values and names of the positive features.
        total_neg: Value of the total negative features.
        neg_features: Values and names of the negative features.
        base_value: The base value of the game.
        out_value: The output value.

    Returns: None

    """
    ax.set_ylim(-0.5, 0.15)
    padding = np.max([np.abs(total_pos) * 0.2, np.abs(total_neg) * 0.2])

    if len(pos_features) > 0:
        min_x = min(np.min(pos_features[:, 0].astype(float)), base_value) - padding
    else:
        min_x = out_value - padding  # coverage: not covered
    if len(neg_features) > 0:
        max_x = max(np.max(neg_features[:, 0].astype(float)), base_value) + padding
    else:
        max_x = out_value + padding
    ax.set_xlim(min_x, max_x)

    plt.tick_params(
        top=True,
        bottom=False,
        left=False,
        right=False,
        labelleft=False,
        labeltop=True,
        labelbottom=False,
    )
    plt.locator_params(axis="x", nbins=12)

    for key, spine in zip(plt.gca().spines.keys(), plt.gca().spines.values(), strict=False):
        if key != "top":
            spine.set_visible(False)


def _split_features(
    interaction_dictionary: dict[tuple[int, ...], float],
    feature_to_names: dict[int, str],
    out_value: float,
) -> tuple[np.ndarray, np.ndarray, float, float]:
    """Splits the features into positive and negative values.

    Args:
        interaction_dictionary: Dictionary containing the interaction values mapping from
            feature indices to their values.
        feature_to_names: Dictionary mapping feature indices to feature names.
        out_value: The output value.

    Returns:
        tuple: A tuple containing the positive features, the negative features (object arrays
            whose rows hold the cumulative plot position and the label of each feature), and
            the total positive and negative extents (the distance between the first and last
            plot position on each side).

    """
    # split features into positive and negative values
    pos_features, neg_features = [], []
    for coalition, value in interaction_dictionary.items():
        if len(coalition) == 0:
            continue
        label = format_labels(feature_to_names, coalition)
        if value >= 0:
            pos_features.append([str(value), label])
        elif value < 0:
            neg_features.append([str(value), label])
    # sort features by decreasing absolute value
    pos_features = sorted(pos_features, key=lambda x: float(x[0]), reverse=True)
    neg_features = sorted(neg_features, key=lambda x: float(x[0]), reverse=False)
    pos_features = np.array(pos_features, dtype=object)
    neg_features = np.array(neg_features, dtype=object)

    # convert negative feature values to plot values
    neg_val = out_value
    for i in neg_features:
        val = float(i[0])
        neg_val = neg_val + np.abs(val)
        i[0] = neg_val
    if len(neg_features) > 0:
        total_neg = np.max(neg_features[:, 0].astype(float)) - np.min(
            neg_features[:, 0].astype(float),
        )
    else:
        total_neg = 0

    # convert positive feature values to plot values
    pos_val = out_value
    for i in pos_features:
        val = float(i[0])
        pos_val = pos_val - np.abs(val)
        i[0] = pos_val

    if len(pos_features) > 0:
        total_pos = np.max(pos_features[:, 0].astype(float)) - np.min(
            pos_features[:, 0].astype(float),
        )
    else:
        total_pos = 0  # coverage: not covered

    return pos_features, neg_features, total_pos, total_neg


def _add_bars(
    ax: Axes,
    out_value: float,
    pos_features: np.ndarray,
    neg_features: np.ndarray,
) -> None:
    """Add bars to the plot.

    Args:
        ax: Axes of the plot.
        out_value: The output value.
        pos_features: Positive features.
        neg_features: Negative features.
    """
    width_bar = 0.1
    width_separators = (ax.get_xlim()[1] - ax.get_xlim()[0]) / 200
    # Create bar for negative shap values
    rectangle_list, separator_list = _create_bars(
        out_value,
        neg_features,
        "negative",
        width_separators,
        width_bar,
    )
    for i in rectangle_list:
        ax.add_patch(i)

    for i in separator_list:
        ax.add_patch(i)

    # Create bar for positive shap values
    rectangle_list, separator_list = _create_bars(
        out_value,
        pos_features,
        "positive",
        width_separators,
        width_bar,
    )
    for i in rectangle_list:
        ax.add_patch(i)

    for i in separator_list:
        ax.add_patch(i)


def draw_higher_lower_element(
    out_value: float,
    offset_text: float,
) -> None:
    """Draw the ``higher``/``lower`` indicator (text and arrows) above the output value.

    Args:
        out_value: The output value around which the indicator is centered.
        offset_text: Horizontal offset of the ``higher`` and ``lower`` texts from ``out_value``.
    """
    plt.text(
        out_value - offset_text,
        0.35,
        "higher",
        fontsize=13,
        color="#FF0D57",
        horizontalalignment="right",
    )
    plt.text(
        out_value + offset_text,
        0.35,
        "lower",
        fontsize=13,
        color="#1E88E5",
        horizontalalignment="left",
    )
    plt.text(
        out_value,
        0.34,
        r"$\leftarrow$",
        fontsize=13,
        color="#1E88E5",
        horizontalalignment="center",
    )
    plt.text(
        out_value,
        0.36,
        r"$\rightarrow$",
        fontsize=13,
        color="#FF0D57",
        horizontalalignment="center",
    )


def _draw_force_plot(
    interaction_value: InteractionValues,
    feature_names: np.ndarray,
    *,
    figsize: tuple[int, int],
    min_perc: float = 0.05,
    draw_higher_lower: bool = True,
) -> Figure:
    """Draw the force plot.

    Note:
        The functionality was taken and adapted from the [SHAP package](https://github.com/shap/shap/blob/master/shap/plots/_force.py)
        which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).
        Do not use this function directly; use the ``force_plot`` function instead.

    Args:
        interaction_value: The interaction values to be plotted.
        feature_names: The names of the features.
        figsize: The size of the figure.
        min_perc: Minimal fraction of the total effect that a feature must contribute to be
            labeled. Defaults to ``0.05``.
        draw_higher_lower: Whether to draw the higher and lower indicator. Defaults to ``True``.

    Returns:
        The figure of the plot.

    """
    # turn off interactive plot
    plt.ioff()

    # compute overall metrics
    base_value = interaction_value.baseline_value
    out_value = np.sum(interaction_value.values)  # Sum of all values, including the baseline value

    # split features into positive and negative values
    features_to_names = {i: str(name) for i, name in enumerate(feature_names)}
    pos_features, neg_features, total_pos, total_neg = _split_features(
        interaction_value.dict_values,
        features_to_names,
        out_value,
    )

    # define plots
    offset_text = (np.abs(total_neg) + np.abs(total_pos)) * 0.04

    fig, ax = plt.subplots(figsize=figsize)

    # compute axis limit
    update_axis_limits(
        ax, total_pos, pos_features, total_neg, neg_features, float(base_value), out_value
    )

    # add the bars to the plot
    _add_bars(ax, out_value, pos_features, neg_features)

    # add labels
    total_effect = np.abs(total_neg) + total_pos
    fig, ax = _add_labels(
        fig,
        ax,
        out_value,
        neg_features,
        "negative",
        offset_text,
        total_effect,
        min_perc=min_perc,
        text_rotation=0,
    )

    fig, ax = _add_labels(
        fig,
        ax,
        out_value,
        pos_features,
        "positive",
        offset_text,
        total_effect,
        min_perc=min_perc,
        text_rotation=0,
    )

    # add higher and lower element
    if draw_higher_lower:
        draw_higher_lower_element(out_value, offset_text)

    # add label for base value
    _add_base_value(float(base_value), ax)

    # add output label
    out_names = ""
    _add_output_element(out_names, out_value, ax)

    # fix the whitespace around the plot
    plt.tight_layout()

    return plt.gcf()


def force_plot(
    interaction_values: InteractionValues,
    *,
    feature_names: np.ndarray | list[str] | None = None,
    abbreviate: bool = True,
    show: bool = False,
    figsize: tuple[int, int] = (15, 4),
    draw_higher_lower: bool = True,
    contribution_threshold: float = 0.05,
) -> Figure | None:
    """Draws a force plot for the given interaction values.

    Args:
        interaction_values: The ``InteractionValues`` to be plotted.
        feature_names: The names of the features. If ``None``, the features are named by their index.
        show: Whether to show or return the plot. Defaults to ``False`` and returns the plot.
        abbreviate: Whether to abbreviate the feature names. Defaults to ``True``.
        figsize: The size of the figure. Defaults to ``(15, 4)``.
        draw_higher_lower: Whether to draw the higher and lower indicator. Defaults to ``True``.
        contribution_threshold: Minimum fraction of the total effect that a feature must
            contribute to be labeled in the plot. Defaults to ``0.05``.

    Returns:
        plt.Figure: The figure of the plot if ``show`` is ``False``, otherwise ``None``.

    References:
        .. [1] SHAP is available at https://github.com/shap/shap

    """
    if feature_names is None:
        feature_names = [str(i) for i in range(interaction_values.n_players)]
    if abbreviate:
        feature_names = abbreviate_feature_names(feature_names)
    feature_names = np.array(feature_names)
    plot = _draw_force_plot(
        interaction_values,
        feature_names,
        figsize=figsize,
        draw_higher_lower=draw_higher_lower,
        min_perc=contribution_threshold,
    )
    if not show:
        return plot
    plt.show()  # coverage: not covered
    return None  # coverage: not covered
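Two bookkeeping steps in the listing can be reproduced without matplotlib: how `_split_features` turns interaction values into cumulative bar positions, and how `_add_labels` decides which labels to draw. The following is a simplified pure-Python sketch for illustration only, not shapiq code. It assumes the features are already keyed by label strings (the listing builds labels with `format_labels` and skips the empty coalition), it uses plain lists instead of NumPy object arrays, and the helper names `split_positions` and `visible_labels` are hypothetical.

```python
def split_positions(values: dict[str, float], out_value: float):
    """Simplified take on _split_features: returns (pos_rows, neg_rows, total_pos, total_neg)."""
    # Largest positive values first; most negative values first.
    pos = sorted(((v, k) for k, v in values.items() if v >= 0), key=lambda item: item[0], reverse=True)
    neg = sorted(((v, k) for k, v in values.items() if v < 0), key=lambda item: item[0])

    # Positive segments are stacked to the left of out_value, negative ones to the right;
    # each row stores the cumulative position where that feature's segment ends.
    pos_rows, cursor = [], out_value
    for v, name in pos:
        cursor -= abs(v)
        pos_rows.append((cursor, name))
    neg_rows, cursor = [], out_value
    for v, name in neg:
        cursor += abs(v)
        neg_rows.append((cursor, name))

    # As in the listing, each "total" is the spread between the extreme positions.
    def spread(rows):
        return max(p for p, _ in rows) - min(p for p, _ in rows) if rows else 0

    return pos_rows, neg_rows, spread(pos_rows), spread(neg_rows)


def visible_labels(rows, out_value: float, total_effect: float, min_perc: float = 0.05):
    """Simplified take on the label filter in _add_labels."""
    shown, pre_val = [], out_value
    for position, name in rows:
        # With a zero total, NumPy division in the listing yields inf/nan, which keeps the label.
        contribution = abs(position - pre_val) / abs(total_effect) if total_effect else float("inf")
        if contribution < min_perc:
            break  # rows are sorted by magnitude, so no later row contributes more
        shown.append(name)
        pre_val = position
    return shown


# Example: out_value = 2.0 with three positive features and one negative feature.
pos_rows, neg_rows, total_pos, total_neg = split_positions(
    {"age": 1.0, "income": 0.5, "height": 0.015625, "debt": -0.25}, out_value=2.0
)
total_effect = abs(total_neg) + total_pos  # computed the same way in _draw_force_plot
print(pos_rows)  # [(1.0, 'age'), (0.5, 'income'), (0.484375, 'height')]
print(visible_labels(pos_rows, 2.0, total_effect))  # ['age', 'income']
```

With these numbers `total_pos` is 0.515625, the distance between the first and last positive positions (the sum of the positive values except the largest one). `height` therefore contributes 0.015625 / 0.515625, about 3%, and falls below the default 5% threshold; `_add_bars` still draws its bar, and only its label is skipped.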