• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18471670957

13 Oct 2025 04:05PM UTC coverage: 93.111% (-0.7%) from 93.845%
18471670957

Pull #430

github

web-flow
Merge b2a7359a1 into dede390c9
Pull Request #430: Enhance type safety and fix bugs across the codebase

305 of 361 new or added lines in 51 files covered. (84.49%)

12 existing lines in 9 files now uncovered.

4987 of 5356 relevant lines covered (93.11%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.2
/src/shapiq/plot/stacked_bar.py
1
"""This module contains functions to plot the n_sii stacked bar charts."""
2

3
from __future__ import annotations
1✔
4

5
import contextlib
1✔
6
from copy import deepcopy
1✔
7
from typing import TYPE_CHECKING, Any
1✔
8

9
import matplotlib.pyplot as plt
1✔
10
import numpy as np
1✔
11
from matplotlib.patches import Patch
1✔
12

13
from ._config import COLORS_K_SII
1✔
14

15
__all__ = ["stacked_bar_plot"]
1✔
16

17

18
if TYPE_CHECKING:
1✔
NEW
19
    from matplotlib.axes import Axes
×
NEW
20
    from matplotlib.figure import Figure
×
21

UNCOV
22
    from shapiq.interaction_values import InteractionValues
×
23

24

25
def stacked_bar_plot(
    interaction_values: InteractionValues,
    *,
    feature_names: list[Any] | None = None,
    max_order: int | None = None,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    show: bool = False,
) -> tuple[Figure, Axes] | None:
    """The stacked bar plot interaction scores.

    This stacked bar plot can be used to visualize the amount of interaction between the features
    for a given instance. The interaction values are plotted as stacked bars with positive and
    negative parts stacked on top of each other. The colors represent the order of the
    interaction values. For a detailed explanation of this plot, we refer to Bordt and von Luxburg
    (2023)[1]_.

    An example of the plot is shown below.

    .. image:: /_static/stacked_bar_exampl.png
        :width: 400
        :align: center

    Args:
        interaction_values: n-SII values as an ``InteractionValues`` object.
        feature_names: The feature names used for plotting. If no feature names are provided, the
            feature indices (starting at ``1``) are used instead. Defaults to ``None``.
        max_order: The highest interaction order to plot. Defaults to ``None``, in which case
            ``interaction_values.max_order`` is used.
        title: The title of the plot. If ``None`` (default), no title is set.
        xlabel: The label of the x-axis. If ``None`` (default), ``"features"`` is used.
        ylabel: The label of the y-axis. If ``None`` (default), ``"SI values"`` is used.
        show: Whether to show the plot. Defaults to ``False``.

    Returns:
        If ``show`` is ``False``, a tuple ``(fig, axis)`` containing the
        ``matplotlib.figure.Figure`` and ``matplotlib.axes.Axes`` of the plot. If ``show`` is
        ``True``, the plot is displayed with ``plt.show()`` and ``None`` is returned.

    Note:
        To change the figure size, font size, etc., use the
        `matplotlib parameters <https://matplotlib.org/stable/users/explain/customizing.html>`_.

    Example:
        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> from shapiq.interaction_values import InteractionValues
        >>> from shapiq.plot import stacked_bar_plot
        >>> interaction_values = InteractionValues(
        ...    values=np.array([1, -1.5, 1.75, 0.25, -0.5, 0.75,0.2]),
        ...    index="SII",
        ...    min_order=1,
        ...    max_order=3,
        ...    n_players=3,
        ...    baseline_value=0
        ... )
        >>> feature_names = ["a", "b", "c"]
        >>> fig, axes = stacked_bar_plot(
        ...     interaction_values=interaction_values,
        ...     feature_names=feature_names,
        ... )
        >>> plt.show()

    References:
        .. [1] Bordt, M., and von Luxburg, U. (2023). From Shapley Values to Generalized Additive Models and back. Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, PMLR 206:709-745. url: https://proceedings.mlr.press/v206/bordt23a.html

    """
    # sanitize inputs
    if max_order is None:
        max_order = interaction_values.max_order

    fig, axis = plt.subplots()

    # Reduce each order's values to one number per feature by summing over axes 1..order-1.
    # NOTE(review): this assumes get_n_order_values(order) returns an `order`-dimensional array
    # whose first axis indexes the features — confirm against InteractionValues.
    values_pos = np.array(
        [
            interaction_values.get_n_order_values(order)
            .clip(min=0)
            .sum(axis=tuple(range(1, order)))
            for order in range(1, max_order + 1)
        ],
    )
    values_neg = np.array(
        [
            interaction_values.get_n_order_values(order)
            .clip(max=0)
            .sum(axis=tuple(range(1, order)))
            for order in range(1, max_order + 1)
        ],
    )

    # get the number of features and the feature names
    n_features = len(values_pos[0])
    if feature_names is None:
        feature_names = [str(i + 1) for i in range(n_features)]
    x = np.arange(n_features)

    # get helper variables for plotting the bars
    min_max_values = [0, 0]  # to set the y-axis limits after all bars are plotted
    reference_pos = np.zeros(n_features)  # top of the positive stack so far
    # The negative stack grows downwards: the bar of an order spans from the running negative
    # sum down to that sum plus its own (negative) value. That is why this starts at order 1's
    # values instead of at zero.
    reference_neg = deepcopy(values_neg[0])

    # the zero line only needs to be drawn once (previously it was redrawn for every order)
    axis.axhline(y=0, color="black", linestyle="solid", linewidth=0.5)

    # plot the bar segments, one positive and one negative segment per order
    for order in range(len(values_pos)):
        color = COLORS_K_SII[order]
        axis.bar(x, height=values_pos[order], bottom=reference_pos, color=color)
        axis.bar(x, height=abs(values_neg[order]), bottom=reference_neg, color=color)
        reference_pos += values_pos[order]
        # after the last order there is no further negative segment to account for
        with contextlib.suppress(IndexError):
            reference_neg += values_neg[order + 1]
        # `initial=` keeps this well-defined even when there are no features
        min_max_values[0] = np.min(reference_neg, initial=min_max_values[0])
        min_max_values[1] = np.max(reference_pos, initial=min_max_values[1])

    # add a legend to the plots
    legend_elements = [
        Patch(facecolor=COLORS_K_SII[order], edgecolor="black", label=f"Order {order + 1}")
        for order in range(max_order)
    ]
    axis.legend(handles=legend_elements, loc="upper center", ncol=min(max_order, 4))

    axis.set_xticks(x)
    axis.set_xticklabels(list(feature_names), rotation=45, ha="right")

    axis.set_xlim(-0.5, n_features - 0.5)
    y_min, y_max = min_max_values
    y_range = abs(y_max - y_min)
    if y_range == 0:
        # all bars have zero height; avoid setting identical lower and upper y-limits
        y_range = 1.0
    # small margin below, larger margin above to leave room for the legend
    axis.set_ylim(y_min - y_range * 0.02, y_max + y_range * 0.3)

    # set title and labels, falling back to default axis labels
    if title is not None:
        axis.set_title(title)
    axis.set_xlabel("features" if xlabel is None else xlabel)
    axis.set_ylabel("SI values" if ylabel is None else ylabel)

    # lay out the figure created above, not whichever figure pyplot considers current
    fig.tight_layout()

    if not show:
        return fig, axis
    plt.show()
    return None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc