mmschlk/shapiq, build 18499875864

14 Oct 2025, 02:26 PM UTC: coverage 92.799% (-0.7% from 93.522%)

Triggered by a push (GitHub), committer: web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* Type-checked game

* Fixed bugs introduced in game and interaction_values

* Pyright-safe sampling

* Type-safe Approximator

* Type-checked Datasets

* Explainer folder type-checked

* GameTheory type-checked

* Imputer type-checked

* Plot type-checked

* Added static type checking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* Made fixture imports local

* Introduced a generic TypeVar in Approximator, reducing the number of type-ignore comments

* Introduced generic types for Explainer: the Approximator, Imputer and ExactComputer may or may not exist, depending on the dynamic type

* Fixed a bug introduced by refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step that reinstalls Python on Windows, which had been added because of issues with tkinter. The linked GitHub issue has been resolved, so this is a first attempt.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for the Windows tests; the prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
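The last workflow bullets drop the Tkinter installation on Windows CI and adjust matplotlib usage for headless environments, but the commit does not say exactly how. A common approach, shown here only as an assumed sketch, is to select matplotlib's non-interactive Agg backend before any figure is created, so rendering never needs a GUI toolkit such as Tk. The helper `render_to_png` is hypothetical.

```python
import matplotlib

# Assumption: pick the non-interactive Agg backend before pyplot draws anything,
# so no GUI toolkit (e.g., Tk) is needed on a headless CI runner.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402


def render_to_png(path: str) -> str:
    """Hypothetical helper: draw a tiny figure and save it instead of opening a window."""
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    fig.savefig(path)
    plt.close(fig)  # release the figure; nothing is displayed under Agg
    return matplotlib.get_backend()
```

Under this setup, figures are written to files or inspected through the returned objects, in the same spirit as calling `sentence_plot(..., show=False)` to get `fig, ax` back.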

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line
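The summary figures are plain ratios of the reported counts. This small check recomputes them; `percent` is an illustrative helper, not part of Coveralls or shapiq.

```python
def percent(covered: int, relevant: int) -> float:
    """Return covered / relevant as a percentage."""
    return 100 * covered / relevant


new_lines = percent(304, 360)      # new or added lines in this commit
overall = percent(4987, 5374)      # all relevant lines after the commit
delta = round(92.799 - 93.522, 3)  # change in overall coverage, in percentage points

print(f"{new_lines:.2f}% {overall:.3f}% {delta}")
```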

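Source file: /src/shapiq/plot/sentence.py (93.02% covered)

Uncovered: the four imports in the if TYPE_CHECKING: block, which only run under static type checking (the Axes and Figure imports are new in this commit; the InteractionValues import is newly uncovered), and the final plt.show() / return None branch.
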
"""This module contains the sentence plot."""

from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING

from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.patches import FancyBboxPatch, PathPatch
from matplotlib.textpath import TextPath

from ._config import BLUE, RED

if TYPE_CHECKING:
    from collections.abc import Sequence

    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from shapiq.interaction_values import InteractionValues


def _get_color_and_alpha(max_value: float, value: float) -> tuple[str, float]:
    """Gets the color and alpha value for an interaction value."""
    color = RED.hex if value >= 0 else BLUE.hex
    ratio = abs(value / max_value)
    ratio = min(ratio, 1.0)  # make ratio at most 1
    return color, ratio


def sentence_plot(
    interaction_values: InteractionValues,
    words: Sequence[str],
    *,
    connected_words: Sequence[tuple[str, str]] | None = None,
    chars_per_line: int = 35,
    font_family: str = "sans-serif",
    show: bool = False,
    max_score: float | None = None,
) -> tuple[Figure, Axes] | None:
    """Plots the first-order effects (attributions) of a sentence or paragraph.

    An example of the plot is shown below.

    .. image:: /_static/sentence_plot_example.png
        :width: 400
        :align: center

    Args:
        interaction_values: The interaction values as an ``InteractionValues`` object. Only the
            first-order values ``interaction_values[(i,)]`` are plotted.
        words: The words of the sentence or paragraph of text.
        connected_words: A list of tuples with connected words. Defaults to ``None``. If two words
            are connected, the plot does not add a space between them (e.g., the parts "enjoy" and
            "able" would be joined into "enjoyable", with potentially different attributions for
            each part).
        chars_per_line: The maximum number of characters per line, after which the text is
            wrapped to the next line. Defaults to ``35``. A connected word that is wrapped onto a
            new line receives a '-' in front of it.
        font_family: The font family used for the plot. Defaults to ``sans-serif``. For a list of
            available font families, see the matplotlib documentation of
            ``matplotlib.font_manager.FontProperties``. Note that the plot is optimized for
            sans-serif fonts.
        show: Whether to show the plot. Defaults to ``False``.
        max_score: The maximum absolute score used to scale the colors and alpha values. This is
            useful for comparing the attributions of different sentences on the same color scale.
            Defaults to ``None``, in which case the largest absolute attribution is used.

    Returns:
        If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple with
        the figure and the axis of the plot.

    Example:
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from shapiq.interaction_values import InteractionValues
        >>> from shapiq.plot import sentence_plot
        >>> iv = InteractionValues(
        ...    values=np.array([0.45, 0.01, 0.67, -0.2, -0.05, 0.7, 0.1, -0.04, 0.56, 0.7]),
        ...    index="SV",
        ...    n_players=10,
        ...    min_order=1,
        ...    max_order=1,
        ...    estimated=False,
        ...    baseline_value=0.0,
        ... )
        >>> words = ["I", "really", "enjoy", "working", "with", "Shapley", "values", "in", "Python", "!"]
        >>> connected_words = [("Shapley", "values")]
        >>> fig, ax = sentence_plot(
        ...    iv, words, connected_words=connected_words, show=False, chars_per_line=100
        ... )
        >>> plt.show()

    .. image:: /_static/sentence_plot_connected_example.png
        :width: 300
        :align: center

    """
    # set all the size parameters
    fontsize = 20
    word_spacing = 15
    line_spacing = 10
    height_padding = 5
    width_padding = 5

    # clean the input
    connected_words = [] if connected_words is None else connected_words
    words = [word.strip() for word in words]
    attributions = [interaction_values[(i,)] for i in range(len(words))]

    # get the maximum score
    if max_score is None:
        max_abs_attribution = max([abs(value) for value in attributions])
    else:
        max_abs_attribution = max_score

    # create plot
    fig, ax = plt.subplots()

    max_x_pos = 0
    x_pos, y_pos = word_spacing, 0
    lines, chars_in_line = 0, 0
    for i, (_word, attribution) in enumerate(zip(words, attributions, strict=False)):
        word = _word
        # check if the word is connected
        is_word_connected_first = False
        is_word_connected_second = (words[i - 1], word) in connected_words
        with contextlib.suppress(IndexError):
            is_word_connected_first = (word, words[i + 1]) in connected_words

        # check if the line is too long and needs to be wrapped
        chars_in_line += len(word)
        if chars_in_line > chars_per_line:
            lines += 1
            chars_in_line = 0
            x_pos = word_spacing
            y_pos -= fontsize + line_spacing
            if is_word_connected_second:
                word = "-" + word

        # adjust the x position for connected words
        if is_word_connected_second:
            x_pos += 2

        # set the position of the word in the plot
        position = (x_pos, y_pos)

        # get the color and alpha value
        color, alpha = _get_color_and_alpha(max_abs_attribution, attribution)

        # get the text
        text_color = "black" if alpha < 2 / 3 else "white"
        fp = FontProperties(family=font_family, style="normal", size=fontsize, weight="normal")
        text_path = TextPath(position, word, prop=fp)
        text_path = PathPatch(text_path, facecolor=text_color, edgecolor="none")
        width_of_text = text_path.get_window_extent().width

        # get dimensions for the explanation patch
        height_patch = fontsize + height_padding
        width_patch = width_of_text + 1
        y_pos_patch = y_pos - height_padding
        x_pos_patch = x_pos + 1
        if is_word_connected_first:
            x_pos_patch -= width_padding / 2
            width_patch += width_padding / 2
        elif is_word_connected_second:
            width_patch += width_padding / 2
        else:
            x_pos_patch -= width_padding / 2
            width_patch += width_padding

        # create the explanation patch
        patch = FancyBboxPatch(
            xy=(x_pos_patch, y_pos_patch),
            width=width_patch,
            height=height_patch,
            color=color,
            alpha=alpha,
            zorder=-1,
            boxstyle="Round, pad=0, rounding_size=3",
        )

        # draw elements for the word
        ax.add_patch(patch)
        ax.add_artist(text_path)

        # update the x position
        x_pos += width_of_text + word_spacing
        max_x_pos = max(max_x_pos, x_pos)
        if is_word_connected_first:
            x_pos -= word_spacing

    # fix up the dimensions of the plot
    ax.set_xlim(0, max_x_pos)
    ax.set_ylim(y_pos - fontsize / 2, fontsize + fontsize / 2)
    width = max_x_pos
    height = fontsize + fontsize / 2 + abs(y_pos - fontsize / 2)
    fig.set_size_inches(width / 100, height / 100)

    # clean up the plot
    ax.axis("off")
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # draw the plot
    if not show:
        return fig, ax
    plt.show()
    return None
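The drawing loop in sentence_plot applies two small rules per word. First, the background is red for non-negative and blue for negative attributions, with alpha equal to |value| / max_score clipped at 1, and the text turns white once alpha reaches 2/3. Second, words are counted into lines by characters, and a connected word that wraps gets a '-' prefix. The matplotlib-free sketch below reproduces only those rules so they can be inspected without rendering. Assumptions: the strings "red" and "blue" stand in for shapiq's RED.hex and BLUE.hex, whose values are not shown here, and `layout_words` is a hypothetical helper. Like the listing, it resets the character counter to 0 on a wrap, so the wrapped word's own length is not counted toward the new line. Unlike the listing, it only looks at the previous word when i > 0, so the first word is never compared against the last one.

```python
from __future__ import annotations

from collections.abc import Sequence


def color_and_alpha(max_value: float, value: float) -> tuple[str, float]:
    """Mirror of _get_color_and_alpha; 'red'/'blue' are placeholders for RED.hex/BLUE.hex."""
    color = "red" if value >= 0 else "blue"
    return color, min(abs(value / max_value), 1.0)


def layout_words(
    words: Sequence[str],
    attributions: Sequence[float],
    connected_words: Sequence[tuple[str, str]] = (),
    chars_per_line: int = 35,
    max_score: float | None = None,
) -> list[tuple[int, str, str, float, str]]:
    """Hypothetical helper: return (line, word, color, alpha, text_color) per word, no drawing."""
    words = [w.strip() for w in words]
    max_abs = max(abs(a) for a in attributions) if max_score is None else max_score
    rows: list[tuple[int, str, str, float, str]] = []
    line, chars_in_line = 0, 0
    for i, (word, attribution) in enumerate(zip(words, attributions)):
        # the word is the second part of a pair if (previous word, word) is listed
        connected_second = i > 0 and (words[i - 1], word) in connected_words
        chars_in_line += len(word)
        if chars_in_line > chars_per_line:
            # same rule as the listing: start a new line and reset the counter to 0
            line += 1
            chars_in_line = 0
            if connected_second:
                word = "-" + word
        color, alpha = color_and_alpha(max_abs, attribution)
        text_color = "black" if alpha < 2 / 3 else "white"
        rows.append((line, word, color, alpha, text_color))
    return rows


words = ["I", "really", "enjoy", "working", "with", "Shapley", "values", "in", "Python", "!"]
values = [0.45, 0.01, 0.67, -0.2, -0.05, 0.7, 0.1, -0.04, 0.56, 0.7]
for row in layout_words(words, values, [("Shapley", "values")], chars_per_line=10):
    print(row)
```

With chars_per_line=10, "values" lands at the start of a new line and is printed as "-values", while "Shapley" (the largest attribution) gets alpha 1.0 and white text.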