mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%

push · github · web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed bugs introduced in game and interaction_values

* Pyright Safe Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced generic types for Explainer: Approximator, Imputer and ExactComputer can either exist or not, depending on the dynamic type

* Bug fix caused through refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the Windows step that reinstalled Python to work around tkinter issues, since the linked GitHub issue has been resolved. This is a first attempt.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for the Windows tests; the prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
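The commit message says matplotlib usage was adjusted for headless environments, but it does not say how. One common approach is to select matplotlib's non-interactive Agg backend before importing pyplot. The sketch below is an assumption for illustration, not the change made in this commit, and `render_offscreen` is a hypothetical helper:

```python
import matplotlib

# Assumption for illustration: choose the non-interactive Agg backend before
# pyplot is imported, so no GUI toolkit such as Tk is needed on the test runner.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402  (import after selecting the backend)


def render_offscreen() -> str:
    """Draw a small horizontal bar chart without opening a window.

    Returns the name of the active backend.
    """
    fig, ax = plt.subplots()
    ax.barh([1, 2], [0.5, -0.25])
    fig.canvas.draw()  # renders into an in-memory buffer, no window is shown
    plt.close(fig)
    return matplotlib.get_backend()
```

In a CI job, setting the `MPLBACKEND` environment variable to `Agg` has the same effect without touching the code.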

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line
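The summary percentages follow directly from the line counts above. Here is a quick arithmetic check; the helper `percent` is illustrative only:

```python
def percent(covered: int, relevant: int, digits: int) -> float:
    """Return covered/relevant as a percentage rounded to `digits` decimal places."""
    return round(100 * covered / relevant, digits)


new_lines = percent(304, 360, 2)     # coverage of new or added lines
overall = percent(4987, 5374, 3)     # overall coverage of relevant lines
change = round(overall - 93.522, 1)  # difference from the previous build

print(new_lines, overall, change)  # 84.44 92.799 -0.7
```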

Source File

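/src/shapiq/plot/bar.py: 97.17% covered. Uncovered: the Axes import under TYPE_CHECKING (a new line) and the final plt.show() / return None branch of bar_plot.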
"""Wrapper for the bar plot from the ``shap`` package.

Note:
    Code and implementation were taken and adapted from the [SHAP package](https://github.com/shap/shap),
    which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).

"""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from shapiq.interaction_values import InteractionValues, aggregate_interaction_values

from ._config import BLUE, RED
from .utils import abbreviate_feature_names, format_labels, format_value

if TYPE_CHECKING:
    from matplotlib.axes import Axes


__all__ = ["bar_plot"]
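The Axes import is reported as uncovered because `typing.TYPE_CHECKING` is `False` at runtime. Only static checkers such as Pyright follow that branch. Combined with `from __future__ import annotations`, an annotation like `Axes | None` is stored as a string and never evaluated, so the module imports fine without the name. The self-contained illustration below uses the standard-library `Decimal` in place of `Axes`, and the function `describe` is hypothetical:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # False at runtime; only type checkers "execute" this import
    from decimal import Decimal


def describe(value: Decimal | None = None) -> str:
    """Annotations are kept as strings, so Decimal is never needed at runtime."""
    return "missing" if value is None else str(value)


print(TYPE_CHECKING)                      # False
print(describe.__annotations__["value"])  # Decimal | None
```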

def _bar(
    values: np.ndarray,
    feature_names: np.ndarray | list[str],
    max_display: int | None = 10,
    ax: Axes | None = None,
) -> Axes:
    """Create a bar plot of a set of SHAP values.

    Note:
        This function was taken and adapted from the [SHAP package](https://github.com/shap/shap/blob/master/shap/plots/_bar.py),
        which is licensed under the [MIT license](https://github.com/shap/shap/blob/master/LICENSE).
        Do not use this function directly; use the ``bar_plot`` function instead.

    Args:
        values: The explanation values to plot as a 2D array of shape ``(n_groups, n_features)``.
            Each row is a different group of values to plot; each column corresponds to one
            feature. The array is modified in place if features are aggregated.
        feature_names: The names of the features to display.
        max_display: The maximum number of features to display. Any remaining features are
            summed into one additional "Sum of N other features" bar. If ``None``, all features
            are displayed. Defaults to ``10``.
        ax: The axis to plot on. If ``None``, the current axis (``plt.gca()``) is used and its
            figure is resized to fit the displayed features. Defaults to ``None``.

    Returns:
        The axis of the plot.

    """
    # determine how many top features we will plot
    num_features = len(values[0])
    if max_display is None:
        max_display = num_features
    max_display = min(max_display, num_features)
    num_cut = max(num_features - max_display, 0)  # number of features that are not displayed

    # order the features by their mean value across groups, largest first
    feature_order = np.argsort(np.mean(values, axis=0))[::-1]

    # if there are more features than we are displaying then we aggregate the features not shown
    if num_cut > 0:
        cut_feature_values = values[:, feature_order[max_display:]]
        # sum the cut features separately for each group (row)
        sum_of_remaining = np.sum(cut_feature_values, axis=1)
        index_of_last = feature_order[max_display]
        values[:, index_of_last] = sum_of_remaining
        max_display += 1  # include the sum of the remaining in the display

    # get the top features and their names
    feature_inds = feature_order[:max_display]
    y_pos = np.arange(len(feature_inds), 0, -1)
    yticklabels: list[str] = [str(feature_names[i]) for i in feature_inds]
    if num_cut > 0:
        yticklabels[-1] = f"Sum of {int(num_cut)} other features"

    # use (or create) the current axis if none was provided
    if ax is None:
        ax = plt.gca()
        # only modify the figure size if ax was not passed in
        # compute our figure size based on how many features we are showing
        fig = plt.gcf()
        row_height = 0.5
        fig.set_size_inches(
            8 + 0.3 * max([len(feature_name) for feature_name in feature_names]),
            max_display * row_height * np.sqrt(len(values)) + 1.5,
        )

    # if negative values are present, we draw a vertical line to mark 0
    negative_values_present = np.sum(values[:, feature_order[:max_display]] < 0) > 0
    if negative_values_present:
        ax.axvline(0, 0, 1, color="#000000", linestyle="-", linewidth=1, zorder=1)

    # draw the bars
    patterns = (None, "\\\\", "++", "xx", "////", "*", "o", "O", ".", "-")
    total_width = 0.7
    bar_width = total_width / len(values)
    for i in range(len(values)):
        ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2)
        ax.barh(
            y_pos + ypos_offset,
            values[i, feature_inds],
            bar_width,
            align="center",
            color=[
                BLUE.hex if values[i, feature_inds[j]] <= 0 else RED.hex for j in range(len(y_pos))
            ],
            hatch=patterns[i % len(patterns)],  # cycle the hatches if there are many groups
            edgecolor=(1, 1, 1, 0.8),
            label="Group " + str(i + 1),
        )

    # draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks)
    ax.set_yticks(
        list(y_pos) + list(y_pos + 1e-8),
        yticklabels + [t.split("=")[-1] for t in yticklabels],
        fontsize=13,
    )

    # conversion factor from inches to x-axis data units, used to offset the bar labels
    xlen = ax.get_xlim()[1] - ax.get_xlim()[0]
    bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted())
    width = bbox.width
    bbox_to_xscale = xlen / width

    # draw the bar labels as text next to the bars, offset by 5 points (5/72 inch)
    for i in range(len(values)):
        ypos_offset = -((i - len(values) / 2) * bar_width + bar_width / 2)
        for j in range(len(y_pos)):
            ind = feature_inds[j]
            if values[i, ind] < 0:
                ax.text(
                    values[i, ind] - (5 / 72) * bbox_to_xscale,
                    float(y_pos[j] + ypos_offset),
                    format_value(values[i, ind], "%+0.02f"),
                    horizontalalignment="right",
                    verticalalignment="center",
                    color=BLUE.hex,
                    fontsize=12,
                )
            else:
                ax.text(
                    values[i, ind] + (5 / 72) * bbox_to_xscale,
                    float(y_pos[j] + ypos_offset),
                    format_value(values[i, ind], "%+0.02f"),
                    horizontalalignment="left",
                    verticalalignment="center",
                    color=RED.hex,
                    fontsize=12,
                )

    # put horizontal lines for each feature row
    for i in range(max_display):
        ax.axhline(i + 1, color="#888888", lw=0.5, dashes=(1, 5), zorder=-1)

    # remove plot frame and y-axis ticks
    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("none")
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if negative_values_present:
        ax.spines["left"].set_visible(False)
    ax.tick_params("x", labelsize=11)

    # set the x-axis limits to cover the data
    xmin, xmax = ax.get_xlim()
    x_buffer = (xmax - xmin) * 0.05
    if negative_values_present:
        ax.set_xlim(xmin - x_buffer, xmax + x_buffer)
    else:
        ax.set_xlim(xmin, xmax + x_buffer)

    ax.set_xlabel("Attribution", fontsize=13)

    if len(values) > 1:
        ax.legend(fontsize=12, loc="lower right")

    # color the y tick labels that have the feature values as gray
    # (these fall behind the black ones with just the feature name)
    tick_labels = ax.yaxis.get_majorticklabels()
    for i in range(max_display):
        tick_labels[i].set_color("#999999")

    return ax
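The feature-selection step at the top of `_bar` can be checked in isolation. The sketch below mirrors that logic without any plotting, using the hypothetical helper `select_top_features`. Features are ranked by their mean value across groups, and everything beyond `max_display` is folded into one extra column. It sums per row (`axis=1`), so each group keeps its own total, and it works on a copy instead of modifying the caller's array:

```python
from __future__ import annotations

import numpy as np


def select_top_features(
    values: np.ndarray, max_display: int | None = 10
) -> tuple[np.ndarray, np.ndarray, int]:
    """Return (feature indices to show, possibly aggregated values, number of cut features).

    Hypothetical helper mirroring the first steps of ``_bar``; ``values`` has shape
    (n_groups, n_features).
    """
    values = values.copy()  # leave the caller's array untouched
    num_features = values.shape[1]
    if max_display is None:
        max_display = num_features
    max_display = min(max_display, num_features)
    num_cut = num_features - max_display

    # rank features by their mean value across groups, largest first
    feature_order = np.argsort(np.mean(values, axis=0))[::-1]

    if num_cut > 0:
        # fold the hidden features into the first hidden feature's column,
        # summing separately for every group (row)
        cut = feature_order[max_display:]
        values[:, cut[0]] = values[:, cut].sum(axis=1)
        max_display += 1  # the aggregated column is displayed as well

    return feature_order[:max_display], values, num_cut


vals = np.array([[0.5, -0.1, 0.3, 0.05], [0.4, 0.2, 0.1, 0.0]])
inds, agg, n_cut = select_top_features(vals, max_display=2)
print(inds, n_cut)  # [0 2 1] 2
```

Here features 1 and 3 are hidden, so column 1 becomes the per-group sums -0.05 and 0.2.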

def bar_plot(
    list_of_interaction_values: list[InteractionValues],
    *,
    feature_names: np.ndarray | list[str] | None = None,
    show: bool = False,
    abbreviate: bool = True,
    max_display: int | None = 10,
    global_plot: bool = True,
    plot_base_value: bool = False,
) -> Axes | None:
    """Draws interaction values as a SHAP bar plot [1]_.

    The function draws the interaction values on a bar plot. The interaction values can be
    aggregated into a global explanation or plotted separately.

    Args:
        list_of_interaction_values: A list containing InteractionValues objects.
        feature_names: The feature names used for plotting. If no feature names are provided,
            labels built from the feature indices (``F0``, ``F1``, and so on) are used instead.
            Defaults to ``None``.
        show: Whether ``matplotlib.pyplot.show()`` is called before returning. Defaults to
            ``False``. Setting this to ``False`` allows the plot to be customized further after
            it has been created.
        abbreviate: Whether to abbreviate the feature names. Defaults to ``True``.
        max_display: The maximum number of features to display. Defaults to ``10``. If set to
            ``None``, all features are displayed.
        global_plot: Whether to aggregate the values of the different InteractionValues objects
            into a global explanation (``True``) or to plot them as separate bars (``False``).
            The global explanation is the mean of the absolute values. Defaults to ``True``. If
            only one InteractionValues object is provided, this parameter is ignored.
        plot_base_value: Whether to include the base value in the plot. Defaults to ``False``.

    Returns:
        If ``show`` is ``False``, the function returns the axis of the plot. Otherwise, it returns
        ``None``.

    References:
        .. [1] SHAP is available at https://github.com/shap/shap

    """
    n_players = list_of_interaction_values[0].n_players

    if feature_names is not None:
        if abbreviate:
            feature_names = abbreviate_feature_names(feature_names)
        feature_mapping = {i: feature_names[i] for i in range(n_players)}
    else:
        feature_mapping = {i: "F" + str(i) for i in range(n_players)}

    # aggregate the interaction values if global_plot is True
    if global_plot and len(list_of_interaction_values) > 1:
        # The aggregation of the global values will be done on the absolute values
        list_of_interaction_values = [abs(iv) for iv in list_of_interaction_values]
        global_values = aggregate_interaction_values(list_of_interaction_values, aggregation="mean")
        values = np.expand_dims(global_values.values, axis=0)
        interaction_list = global_values.interaction_lookup.keys()
    else:  # plot the interaction values separately (also includes the case of a single object)
        all_interactions = set()
        for iv in list_of_interaction_values:
            all_interactions.update(iv.interaction_lookup.keys())
        all_interactions = sorted(all_interactions)
        interaction_list = []
        values = np.zeros((len(list_of_interaction_values), len(all_interactions)))
        for j, interaction in enumerate(all_interactions):
            interaction_list.append(interaction)
            for i, iv in enumerate(list_of_interaction_values):
                values[i, j] = iv[interaction]

    # drop the base value (the empty interaction, expected as the first entry) unless requested
    if not plot_base_value:
        values = values[:, 1:]
        interaction_list = list(interaction_list)[1:]

    # format the labels
    labels = [format_labels(feature_mapping, interaction) for interaction in interaction_list]

    ax = _bar(values=values, feature_names=labels, max_display=max_display)
    if not show:
        return ax
    plt.show()
    return None
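In the non-global branch, `bar_plot` aligns several explanations on the union of their interactions and then drops the first column, the base value. The sketch below reproduces that alignment with plain dictionaries that map interaction tuples to values. These dictionaries are a hypothetical stand-in for `InteractionValues`. Treating interactions missing from one explanation as 0 is an assumption, not something the listing above shows. The global branch is approximated as the mean absolute value per interaction, following the comment in the code:

```python
from __future__ import annotations

import numpy as np

Interaction = tuple[int, ...]


def align_values(
    explanations: list[dict[Interaction, float]], *, plot_base_value: bool = False
) -> tuple[np.ndarray, list[Interaction]]:
    """Build an (n_explanations, n_interactions) matrix over the union of interactions.

    Hypothetical stand-in for the separate-plot branch of ``bar_plot``; missing
    interactions are assumed to count as 0.
    """
    # tuples sort lexicographically, so the empty interaction () comes first
    interactions = sorted(set().union(*explanations))
    values = np.array([[expl.get(s, 0.0) for s in interactions] for expl in explanations])
    if not plot_base_value:
        # drop the base value, i.e. the first column
        values, interactions = values[:, 1:], interactions[1:]
    return values, interactions


def global_values(values: np.ndarray) -> np.ndarray:
    """Global explanation: mean absolute value per interaction, kept as a single row."""
    return np.mean(np.abs(values), axis=0, keepdims=True)


a = {(): 1.0, (0,): 0.5, (1,): -0.2, (0, 1): 0.1}
b = {(): 1.0, (0,): -0.3, (1,): 0.4}
vals, labels = align_values([a, b])
print(labels)  # [(0,), (0, 1), (1,)]
```

Keeping the base value as a regular column and slicing it off at the end keeps both branches on the same matrix layout before `_bar` is called.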