mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%
push

github

web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed introduced bugs in game and interaction_values

* Pyright Save Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer. Approximator, Imputer and ExactComputer can either exist or not, depending on dynamic Type

* Bug fix caused through refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step to reinstall Python on Windows due to issues with tkinter. The linked GitHub issue was solved. Doing this as a first try.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests; the prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line
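
The percentages reported above follow directly from the covered and relevant line counts. The sketch below recomputes them; `coverage_percent` is a hypothetical helper written for illustration and is not part of Coveralls or shapiq.

```python
def coverage_percent(covered: int, relevant: int) -> float:
    """Return the share of relevant lines that were executed, in percent."""
    return 100.0 * covered / relevant


# Counts taken from the build summary above.
new_lines = coverage_percent(304, 360)    # new or added lines in this push
overall = coverage_percent(4987, 5374)    # all relevant lines in the repo

print(f"new or added lines: {new_lines:.2f}%")  # 84.44%
print(f"overall coverage:   {overall:.3f}%")    # 92.799%
```

Rounded to one decimal place, the overall figure is the 92.8% shown in the summary.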

Source File

/src/shapiq/plot/stacked_bar.py (90.2% covered)
Uncovered in this build: the three imports under "if TYPE_CHECKING:" (the two matplotlib imports are new lines, the InteractionValues import is an existing line that is now uncovered) and the final "plt.show()" / "return None" branch that runs when show=True. All other executable lines were hit once.

"""This module contains functions to plot the n_sii stacked bar charts."""

from __future__ import annotations

import contextlib
from copy import deepcopy
from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch

from ._config import COLORS_K_SII

__all__ = ["stacked_bar_plot"]


if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from shapiq.interaction_values import InteractionValues


def stacked_bar_plot(
    interaction_values: InteractionValues,
    *,
    feature_names: list[Any] | None = None,
    max_order: int | None = None,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    show: bool = False,
) -> tuple[Figure, Axes] | None:
    """Plot the interaction scores as a stacked bar chart.

    This stacked bar plot can be used to visualize the amount of interaction between the features
    for a given instance. The interaction values are plotted as stacked bars with positive and
    negative parts stacked on top of each other. The colors represent the order of the
    interaction values. For a detailed explanation of this plot, we refer to Bordt and von Luxburg
    (2023) [1]_.

    An example of the plot is shown below.

    .. image:: /_static/stacked_bar_exampl.png
        :width: 400
        :align: center

    Args:
        interaction_values (InteractionValues): The n-SII values as an ``InteractionValues``
            object.
        feature_names: The feature names used for plotting. If no feature names are provided, the
            feature indices are used instead. Defaults to ``None``.
        max_order (int): The highest interaction order to plot. Defaults to ``None``, in which
            case ``interaction_values.max_order`` is used.
        title (str): The title of the plot. Defaults to ``None`` (no title).
        xlabel (str): The label of the x-axis. Defaults to ``None``, which uses ``"features"``.
        ylabel (str): The label of the y-axis. Defaults to ``None``, which uses ``"SI values"``.
        show (bool): Whether to show the plot. Defaults to ``False``.

    Returns:
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] | None: A tuple containing the figure
            and the axis of the plot, or ``None`` if ``show`` is ``True``.

    Note:
        To change the figure size, font size, etc., use the `matplotlib parameters <https://matplotlib.org/stable/users/explain/customizing.html>`_.

    Example:
        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> from shapiq.interaction_values import InteractionValues
        >>> from shapiq.plot import stacked_bar_plot
        >>> interaction_values = InteractionValues(
        ...     values=np.array([1, -1.5, 1.75, 0.25, -0.5, 0.75, 0.2]),
        ...     index="SII",
        ...     min_order=1,
        ...     max_order=3,
        ...     n_players=3,
        ...     baseline_value=0
        ... )
        >>> feature_names = ["a", "b", "c"]
        >>> fig, axes = stacked_bar_plot(
        ...     interaction_values=interaction_values,
        ...     feature_names=feature_names,
        ... )
        >>> plt.show()

    References:
        .. [1] Bordt, M., and von Luxburg, U. (2023). From Shapley Values to Generalized Additive Models and back. Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, PMLR 206:709-745. url: https://proceedings.mlr.press/v206/bordt23a.html

    """
    # sanitize inputs
    if max_order is None:
        max_order = interaction_values.max_order

    fig, axis = plt.subplots()

    # transform data to make plotting easier
    values_pos = np.array(
        [
            interaction_values.get_n_order_values(order)
            .clip(min=0)
            .sum(axis=tuple(range(1, order)))
            for order in range(1, max_order + 1)
        ],
    )
    values_neg = np.array(
        [
            interaction_values.get_n_order_values(order)
            .clip(max=0)
            .sum(axis=tuple(range(1, order)))
            for order in range(1, max_order + 1)
        ],
    )
    # get the number of features and the feature names
    n_features = len(values_pos[0])
    if feature_names is None:
        feature_names = [str(i + 1) for i in range(n_features)]
    x = np.arange(n_features)

    # get helper variables for plotting the bars
    min_max_values = [0, 0]  # to set the y-axis limits after all bars are plotted
    reference_pos = np.zeros(n_features)  # to plot the bars on top of each other
    reference_neg = deepcopy(values_neg[0])  # to plot the bars below each other

    # plot the bar segments
    for order in range(len(values_pos)):
        axis.bar(x, height=values_pos[order], bottom=reference_pos, color=COLORS_K_SII[order])
        axis.bar(x, height=abs(values_neg[order]), bottom=reference_neg, color=COLORS_K_SII[order])
        axis.axhline(y=0, color="black", linestyle="solid", linewidth=0.5)
        reference_pos += values_pos[order]
        with contextlib.suppress(IndexError):
            reference_neg += values_neg[order + 1]
        min_max_values[0] = min(min_max_values[0], *reference_neg)
        min_max_values[1] = max(min_max_values[1], *reference_pos)

    # add a legend to the plots
    legend_elements = [
        Patch(facecolor=COLORS_K_SII[order], edgecolor="black", label=f"Order {order + 1}")
        for order in range(max_order)
    ]
    axis.legend(handles=legend_elements, loc="upper center", ncol=min(max_order, 4))

    x_ticks_labels = list(feature_names)  # might be unnecessary
    axis.set_xticks(x)
    axis.set_xticklabels(x_ticks_labels, rotation=45, ha="right")

    axis.set_xlim(-0.5, n_features - 0.5)
    axis.set_ylim(
        min_max_values[0] - abs(min_max_values[1] - min_max_values[0]) * 0.02,
        min_max_values[1] + abs(min_max_values[1] - min_max_values[0]) * 0.3,
    )

    # set the title if provided and fall back to default axis labels otherwise
    if title is not None:
        axis.set_title(title)

    axis.set_xlabel("features") if xlabel is None else axis.set_xlabel(xlabel)
    axis.set_ylabel("SI values") if ylabel is None else axis.set_ylabel(ylabel)

    plt.tight_layout()

    if not show:
        return fig, axis
    plt.show()
    return None
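
The stacking bookkeeping in the plotting loop above (running reference levels for the positive and negative bars, plus the padded y-limits) may be easier to follow in isolation. The following is a minimal pure-Python sketch of that logic, assuming the per-order values have already been split into non-negative and non-positive parts. The helper names `stack_segments` and `padded_ylim` and the sample numbers are hypothetical and do not appear in shapiq.

```python
def stack_segments(values_pos, values_neg):
    """Return per-order bar geometry and the unpadded (y_min, y_max) range.

    Both inputs hold one row per interaction order and one column per
    feature; rows of values_pos are >= 0 and rows of values_neg are <= 0.
    """
    n_features = len(values_pos[0])
    pos_edge = [0.0] * n_features  # current upper edge of the positive stack
    neg_edge = [0.0] * n_features  # current lower edge of the negative stack
    segments = []
    for pos, neg in zip(values_pos, values_neg):
        # A negative segment is drawn upward from its own lower edge.
        neg_bottom = [edge + value for edge, value in zip(neg_edge, neg)]
        segments.append(
            {
                "pos_bottom": list(pos_edge),
                "pos_height": list(pos),
                "neg_bottom": neg_bottom,
                "neg_height": [abs(value) for value in neg],
            }
        )
        pos_edge = [edge + value for edge, value in zip(pos_edge, pos)]
        neg_edge = neg_bottom
    return segments, (min([0.0, *neg_edge]), max([0.0, *pos_edge]))


def padded_ylim(y_min, y_max):
    """Add 2% padding below and 30% above, leaving room for the legend."""
    span = abs(y_max - y_min)
    return y_min - span * 0.02, y_max + span * 0.3


# Two orders, two features (made-up numbers for illustration only).
values_pos = [[1.0, 0.0], [0.5, 0.25]]
values_neg = [[0.0, -1.5], [-0.25, 0.0]]

segments, y_range = stack_segments(values_pos, values_neg)
y_limits = padded_ylim(*y_range)
print(segments[1]["pos_bottom"], segments[1]["neg_bottom"], y_range)
```

Each entry in `segments` corresponds to one pair of `axis.bar` calls in the original loop: the positive bar starts at the top of the previous positive stack, and the negative bar starts at the new lower edge and extends upward by its absolute height.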