
mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%
push · github · web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed introduced bugs in game and interaction_values

* Pyright Save Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer. Approximator, Imputer and ExactComputer can either exist or not, depending on dynamic Type

* Bug fix caused through refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step to reinstall Python on Windows due to issues with tkinter. The linked GitHub issue was solved. Doing this as a first try.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests. The prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line

Source File

/src/shapiq/plot/si_graph.py (91.55% covered)

"""Module for plotting the explanation graph of interaction values."""

from __future__ import annotations

import math
from typing import TYPE_CHECKING
from warnings import warn

import matplotlib.patches as mpatches
import matplotlib.path as mpath
import networkx as nx
import numpy as np
from matplotlib import pyplot as plt

from ._config import get_color
from .utils import add_image_in_center

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from matplotlib.legend import Legend
    from PIL.Image import Image

    from shapiq.interaction_values import InteractionValues

NORMAL_NODE_SIZE = 0.125  # diameter of a regular graph node (in data coordinates)
BASE_ALPHA_VALUE = 1.0  # the transparency level for the highest interaction
BASE_SIZE = 0.05  # the size of the highest interaction edge (with scale factor 1)
ADJUST_NODE_ALPHA = True  # scale the node transparency with the interaction strength
LABEL_OFFSET = 0.07  # extra distance between a node and its label

__all__ = ["get_legend", "si_graph_plot"]


def si_graph_plot(
    interaction_values: InteractionValues,
    *,
    show: bool = False,
    n_interactions: int | None = None,
    draw_threshold: float = 0.0,
    interaction_direction: str | None = None,
    min_max_order: tuple[int, int] = (1, -1),
    size_factor: float = 1.0,
    node_size_scaling: float = 1.0,
    min_max_interactions: tuple[float, float] | None = None,
    feature_names: list | dict | None = None,
    graph: list[tuple] | nx.Graph | None = None,
    plot_original_nodes: bool = False,
    plot_explanation: bool = True,
    pos: dict | None = None,
    circular_layout: bool = True,
    random_seed: int = 42,
    adjust_node_pos: bool = False,
    spring_k: float | None = None,
    compactness: float = 1e10,
    center_image: Image | np.ndarray | None = None,
    center_image_size: float = 0.4,
    feature_image_patches: dict[int, Image] | list[Image] | None = None,
    feature_image_patches_size: float = 0.2,
) -> tuple[Figure, Axes] | None:
    """Plots the interaction values as an explanation graph.

    An explanation graph is an undirected graph where the nodes represent players and the edges
    represent interactions between the players. The sizes of the nodes and edges represent the
    strength of the interaction values. The color of the edges represents the sign of the
    interaction values (red for positive and blue for negative). The SI-graph plot is presented in
    :footcite:t:`Muschalik.2024b`.

    Args:
        interaction_values: The interaction values to plot.

        show: Whether to show the plot (``True``) or return the figure and axis (``False``).
            Defaults to ``False``.

        n_interactions: The number of top interactions to plot. If ``None``, all interactions
            are plotted according to ``draw_threshold``. Defaults to ``None``.

        draw_threshold: The threshold to draw an explanation, i.e. only interactions whose
            absolute value is strictly greater than this threshold are drawn. Defaults to
            ``0.0``.

        interaction_direction: The sign of the interaction values to plot. If ``None``, all
            interactions are plotted. Possible values are ``"positive"`` and
            ``"negative"``. Defaults to ``None``.

        min_max_order: Only show interactions whose order (size) lies in ``[min, max]``. The
            minimum is clamped to at least ``1``, so first-order interactions are hidden if
            ``min > 1``. To use the maximum order of the interaction values, set max to ``-1``.
            Defaults to ``(1, -1)``.

        size_factor: The factor to scale the explanations by (a higher value will make the
            interactions and main effects larger). Defaults to ``1.0``.

        node_size_scaling: The scaling factor for the node sizes (it also scales the base size
            of the explanations). This can be used to make the nodes larger or smaller depending
            on how the graph looks. Defaults to ``1.0`` (no scaling). Values between ``0.0`` and
            ``1.0`` make the nodes smaller, higher values make them larger.

        min_max_interactions: The minimum and maximum interaction values to use for scaling the
            interactions as a tuple ``(min, max)``. If ``None``, the minimum and maximum absolute
            interaction values (within ``min_max_order``) are used. Defaults to ``None``.

        feature_names: The feature names used for plotting, as a list or dict mapping the
            player index (list index or dict key) to its name. If no feature names are provided,
            the feature indices are used instead. Defaults to ``None``.

        graph: The underlying graph structure as a list of edge tuples or a networkx graph. If a
            networkx graph is provided, the nodes are used as the players and the edges are used as
            the connections between the players. Defaults to ``None``, which creates a graph with
            all nodes from the interaction values without any edges between them.

        plot_original_nodes: If ``True``, nodes are shown as white circles with the label
            inside, and large first-order effects appear as halos around the node. If ``False``,
            only the explanation nodes are shown, with their labels next to them. Defaults to
            ``False``.

        plot_explanation: Whether to plot the explanation or only the original graph. Defaults to
            ``True``.

        pos: The positions of the nodes in the graph. If given, each coordinate is rescaled to
            ``[0, 1]``. If ``None``, the players are placed on a circle (see ``circular_layout``)
            or, otherwise, by a force-directed layout. Defaults to ``None``.

        circular_layout: Whether to place the players on a circle according to their order. Set
            to ``False`` automatically if ``graph`` is provided. Defaults to ``True``.

        random_seed: The random seed for the spring layouts. Defaults to ``42``.

        adjust_node_pos: Whether to adjust the node positions such that the nodes are at least
            ``NORMAL_NODE_SIZE`` apart. Defaults to ``False``.

        spring_k: The spring constant for the spring layout. If ``None``, the spring constant is
            calculated based on the number of nodes in the graph. Defaults to ``None``.

        compactness: A scaling factor for the underlying spring layout. A higher compactness value
            will move the interactions closer to the graph nodes. If the layout looks off, try
            adjusting this value, e.g. ``[0.1, 1.0, 10.0, 100.0, 1000.0]``. Defaults to ``1e10``.

        center_image: An optional image to be displayed in the center of the graph. If provided,
            the image is displayed with size ``center_image_size``. If the number of features is
            a perfect square, a vision-transformer-style grid is assumed and the image is
            overlaid with a grid of feature image patches. Defaults to ``None``.

        center_image_size: The size of the center image. Defaults to ``0.4``. Adjust this value
            to make the image larger or smaller in the center of the graph.

        feature_image_patches: A dict or list containing the image patches to be displayed
            instead of the feature labels in the network. The dict keys (or list indices) are
            the feature indices and the values are the feature images. The labels are drawn on
            top of the images only if ``plot_original_nodes`` is ``True``. Defaults to ``None``.

        feature_image_patches_size: The size of the feature image patches. Defaults to ``0.2``.

    Returns:
        The figure and axis of the plot if ``show`` is ``False``. Otherwise, ``None``.

    References:
        .. footbibliography::
    """
    if interaction_values is None:
        msg = "Interaction_values must be provided."
        raise ValueError(msg)

    normal_node_size = NORMAL_NODE_SIZE * node_size_scaling
    base_size = BASE_SIZE * node_size_scaling

    label_mapping = None
    if isinstance(feature_names, list):
        label_mapping = {i: feature_names[i] for i in range(len(feature_names))}
    else:
        label_mapping = feature_names

    player_ids = {
        interaction[0]
        for interaction in interaction_values.interaction_lookup
        if len(interaction) == 1
    }

    # fill the original graph with the edges and nodes
    if isinstance(graph, nx.Graph):
        original_graph = graph
        graph_nodes = list(original_graph.nodes)
        # add labels if the nodes do not have any yet (checked on the first node only)
        if "label" not in original_graph.nodes[graph_nodes[0]]:
            for node in graph_nodes:
                node_label = label_mapping.get(node, node) if label_mapping is not None else node
                original_graph.nodes[node]["label"] = node_label
    elif isinstance(graph, list):
        circular_layout = False
        original_graph, graph_nodes = nx.Graph(), []
        for edge in graph:
            original_graph.add_edge(*edge)
            node_labels = [edge[0], edge[1]]
            if label_mapping is not None:
                node_labels = [label_mapping.get(node, node) for node in node_labels]
            original_graph.add_node(edge[0], label=node_labels[0])
            original_graph.add_node(edge[1], label=node_labels[1])
            graph_nodes.extend([edge[0], edge[1]])
    else:  # graph is considered None
        original_graph = nx.Graph()
        graph_nodes = list(player_ids)
        for node in graph_nodes:
            node_label = label_mapping.get(node, node) if label_mapping is not None else node
            original_graph.add_node(node, label=node_label)

    # warn (once) if a player of the interaction values is missing from the graph
    for player_id in player_ids:
        if player_id not in original_graph.nodes:
            msg = (
                f"The given graph does not contain player {player_id}, which can lead to misattributions in the plot.\n"
                f"The given graph: {graph} and the players of the given interaction values: {player_ids}"
            )
            warn(msg, stacklevel=2)
            break
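
The label and player bookkeeping above can be sketched in plain Python, without networkx or shapiq. The helpers `node_labels` and `first_missing_player` are hypothetical names that mirror the `label_mapping.get(node, node)` lookups and the missing-player check; this is an illustration, not part of the shapiq API.

```python
def node_labels(nodes, feature_names=None):
    """Map each node to its display label (hypothetical helper mirroring si_graph_plot)."""
    # a list of names becomes an index -> name mapping; dicts (or None) are used as-is
    mapping = dict(enumerate(feature_names)) if isinstance(feature_names, list) else feature_names
    # nodes without a name fall back to their own id
    return {node: (mapping.get(node, node) if mapping is not None else node) for node in nodes}


def first_missing_player(player_ids, graph_nodes):
    """Return a player absent from the graph (the plot warns once in that case), else None."""
    for player in player_ids:
        if player not in graph_nodes:
            return player
    return None


print(node_labels([0, 1, 2], ["age", "income"]))  # {0: 'age', 1: 'income', 2: 2}
print(first_missing_player({0, 1, 3}, {0, 1, 2}))  # 3
```

Players without a name keep their index as label, which matches the fallback described for `feature_names`.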

    if n_interactions is not None:
        # get the top n interactions
        interaction_values = interaction_values.get_top_k_interactions(
            n_interactions
        )  # TODO(advueu963): Was get_top_k(n_interactions) which should be wrong. # noqa: TD003

    min_order, max_order = min_max_order
    min_order = max(1, min_order)
    if max_order == -1:
        max_order = interaction_values.max_order

    # get the interactions to plot (sufficiently large, right order)
    interactions_to_plot = {}
    min_interaction, max_interaction = 1e10, 0.0
    for interaction, interaction_pos in interaction_values.interaction_lookup.items():
        if len(interaction) < min_order or len(interaction) > max_order:
            continue
        interaction_value = interaction_values.values[interaction_pos]
        min_interaction = min(abs(interaction_value), min_interaction)
        max_interaction = max(abs(interaction_value), max_interaction)
        if abs(interaction_value) > draw_threshold:
            if interaction_direction == "positive" and interaction_value < 0:
                continue
            if interaction_direction == "negative" and interaction_value > 0:
                continue
            interactions_to_plot[interaction] = interaction_value

    if min_max_interactions is not None:
        min_interaction, max_interaction = min_max_interactions
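
The selection loop can be sketched as a pure function. It assumes a plain dict from interaction tuples to values as a stand-in for shapiq's `interaction_lookup`/`values` pair, and it derives the maximum order from the longest tuple instead of `interaction_values.max_order`; `select_interactions` is a hypothetical name. The second call shows that raising the minimum order to 2 also drops the first-order effects.

```python
def select_interactions(values, min_max_order=(1, -1), draw_threshold=0.0, direction=None):
    """Filter interactions by order, absolute-value threshold and sign (sketch).

    `values` maps interaction tuples to values. Returns the kept interactions and the
    largest absolute value seen within the order range.
    """
    min_order, max_order = min_max_order
    min_order = max(1, min_order)  # orders below 1 are not meaningful
    if max_order == -1:
        # assumption: use the longest interaction present as the maximum order
        max_order = max(len(interaction) for interaction in values)
    selected, max_abs = {}, 0.0
    for interaction, value in values.items():
        if not min_order <= len(interaction) <= max_order:
            continue
        max_abs = max(max_abs, abs(value))  # tracked before thresholding, as in the plot code
        if abs(value) <= draw_threshold:  # keep only values strictly above the threshold
            continue
        if direction == "positive" and value < 0:
            continue
        if direction == "negative" and value > 0:
            continue
        selected[interaction] = value
    return selected, max_abs


values = {(0,): 0.5, (1,): -0.2, (0, 1): 0.3, (0, 1, 2): -0.05}
print(select_interactions(values, draw_threshold=0.1))
# ({(0,): 0.5, (1,): -0.2, (0, 1): 0.3}, 0.5)
print(select_interactions(values, min_max_order=(2, -1)))
# ({(0, 1): 0.3, (0, 1, 2): -0.05}, 0.3)
```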

    # create explanation graph
    explanation_graph, explanation_nodes, explanation_edges = nx.Graph(), [], []
    for interaction, interaction_value in interactions_to_plot.items():
        interaction_size = len(interaction)
        interaction_strength = abs(interaction_value)

        attributes = {
            "color": get_color(interaction_value),
            "alpha": _normalize_value(interaction_value, max_interaction, BASE_ALPHA_VALUE),
            "interaction": interaction,
            "weight": interaction_strength * compactness,
            "size": _normalize_value(interaction_value, max_interaction, base_size * size_factor),
        }

        # add main effect explanations as nodes
        if interaction_size == 1:
            player = interaction[0]
            explanation_graph.add_node(player, **attributes)
            explanation_nodes.append(player)

        # add interactions of order >= 2 as (hyper-)edges
        if interaction_size >= 2:
            explanation_edges.append(interaction)
            player_last = interaction[-1]
            if interaction_size > 2:
                # a hyper-edge is anchored at an artificial "center" node keyed by the tuple
                dummy_node = tuple(interaction)
                explanation_graph.add_node(dummy_node, **attributes)
                player_last = dummy_node
            # connect the players to the last player (2-way) or to the center node (hyper-edge);
            # note that for hyper-edges the last player is not linked to the center node
            for player in interaction[:-1]:
                explanation_graph.add_edge(player, player_last, **attributes)
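
How interactions become nodes and edges of the explanation graph can be sketched without networkx. `explanation_structure` is a hypothetical helper that returns plain lists where the plotting code adds networkx nodes and edges; the color, alpha, weight and size attributes are omitted here.

```python
def explanation_structure(interactions):
    """Split interactions into player nodes, hyper-edge center nodes and edges (sketch)."""
    player_nodes, center_nodes, edges = [], [], []
    for interaction in interactions:
        if len(interaction) == 1:  # main effect -> node
            player_nodes.append(interaction[0])
            continue
        anchor = interaction[-1]  # 2-way: connect the players directly
        if len(interaction) > 2:  # hyper-edge: anchor at an artificial center node
            anchor = tuple(interaction)
            center_nodes.append(anchor)
        # like `for player in interaction[:-1]`, the last player gets no edge to the center
        edges.extend((player, anchor) for player in interaction[:-1])
    return player_nodes, center_nodes, edges


print(explanation_structure([(0,), (0, 1), (0, 1, 2)]))
# ([0], [(0, 1, 2)], [(0, 1), (0, (0, 1, 2)), (1, (0, 1, 2))])
```

Because the center node is keyed by the interaction tuple itself, the drawing code can later look it up with `pos[hyper_edge]`.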

    # first, lay out the original graph structure (a given graph disables the circular layout)
    if isinstance(graph, nx.Graph | list):
        circular_layout = False

    adjusted_pos: dict = {}
    if pos is None:
        # TODO(advueu963): pos is statically just a Mapping. Forcing it to be dict is way stronger but necessary as far as I can see # noqa: TD003
        if circular_layout:
            adjusted_pos = nx.circular_layout(original_graph)  # pyright: ignore[reportAssignmentType]
        else:
            adjusted_pos = nx.spring_layout(original_graph, seed=random_seed, k=spring_k)  # pyright: ignore[reportAssignmentType]
            # note: pos is None here, so Kamada-Kawai starts from its default (circular) initial
            # layout and replaces the spring layout computed in the line above
            adjusted_pos = nx.kamada_kawai_layout(original_graph, scale=1, pos=pos)  # pyright: ignore[reportAssignmentType]
    else:
        # pos is given: rescale each coordinate to [0, 1] (min-max scaling)
        min_pos = np.min(list(pos.values()), axis=0)
        max_pos = np.max(list(pos.values()), axis=0)
        adjusted_pos = {node: (pos[node] - min_pos) / (max_pos - min_pos) for node in pos}

    # adjust pos such that the nodes are at least NORMAL_NODE_SIZE apart
    if adjust_node_pos:
        adjusted_pos = _adjust_position(adjusted_pos, original_graph)
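
The min-max scaling applied to a user-supplied `pos` can be written in plain Python as below (a sketch; `rescale_positions` is a hypothetical name). If all nodes share an x or a y coordinate, the denominator is zero: this sketch raises `ZeroDivisionError`, whereas the NumPy version in the plot produces non-finite values with a runtime warning.

```python
def rescale_positions(pos):
    """Min-max scale 2D node positions to [0, 1] on each axis (sketch of the `pos` branch)."""
    xs = [x for x, _ in pos.values()]
    ys = [y for _, y in pos.values()]
    min_x, max_x, min_y, max_y = min(xs), max(xs), min(ys), max(ys)
    return {
        node: ((x - min_x) / (max_x - min_x), (y - min_y) / (max_y - min_y))
        for node, (x, y) in pos.items()
    }


print(rescale_positions({"a": (-1.0, 2.0), "b": (1.0, 4.0), "c": (0.0, 3.0)}))
# {'a': (0.0, 0.0), 'b': (1.0, 1.0), 'c': (0.5, 0.5)}
```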

    # create the plot
    fig, ax = plt.subplots(figsize=(7, 7))
    if plot_explanation:
        # place the explanation nodes (incl. hyper-edge center nodes) with a spring layout that
        # uses the "weight" edge attribute, keeping the original graph nodes fixed
        pos_explain = nx.spring_layout(
            explanation_graph,
            weight="weight",
            seed=random_seed,
            pos=adjusted_pos,
            fixed=graph_nodes,
        )
        adjusted_pos.update(pos_explain)
        _draw_fancy_hyper_edges(ax, adjusted_pos, explanation_graph, hyper_edges=explanation_edges)
        _draw_explanation_nodes(
            ax,
            adjusted_pos,
            explanation_graph,
            nodes=explanation_nodes,
            normal_node_size=normal_node_size,
        )

    # add the original graph structure on top
    if plot_original_nodes or not plot_explanation:
        _draw_graph_nodes(ax, adjusted_pos, original_graph, normal_node_size=normal_node_size)
    _draw_graph_edges(ax, adjusted_pos, original_graph, normal_node_size=normal_node_size)

    # add the feature image patches
    if feature_image_patches is not None:
        _draw_feature_images(
            ax,
            adjusted_pos,
            original_graph,
            feature_image_patches,
            feature_image_patches_size,
        )

    if feature_image_patches is None or plot_original_nodes:
        _draw_graph_labels(
            ax,
            adjusted_pos,
            original_graph,
            normal_node_size=normal_node_size,
            plot_white_nodes=plot_original_nodes,
        )

    # add the center image
    if center_image is not None:
        n_features = interaction_values.n_players
        if feature_image_patches is not None:
            n_features = len(feature_image_patches)
        # draw a patch grid only if the number of features is a perfect square
        if math.isqrt(n_features) ** 2 != n_features:
            n_features = None
        add_image_in_center(
            image=center_image,
            axis=ax,
            size=center_image_size,
            n_features=n_features,
        )

    # tidy up the plot
    ax.set_aspect("equal", adjustable="datalim")  # make y- and x-axis scales equal
    ax.axis("off")  # remove axis

    if not show:
        return fig, ax
    plt.show()
    return None
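
The perfect-square test that decides whether a grid is overlaid on the center image can be isolated as below. `center_grid_features` is a hypothetical helper, and 196 (a 14 x 14 grid) is only an illustrative input. `math.isqrt` works on integers exactly, which avoids the rounding issues of `int(math.sqrt(n)) ** 2` for large `n`.

```python
import math


def center_grid_features(n_features):
    """Return n_features if it is a perfect square (grid overlay), otherwise None."""
    return n_features if math.isqrt(n_features) ** 2 == n_features else None


print(center_grid_features(196))  # 196 -> a 14 x 14 patch grid
print(center_grid_features(10))  # None -> no grid is drawn
```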


# TODO(advueu963): This function is not used at all. If not given an axis it will also crash. What is the meaning of this function? # noqa: TD003
def get_legend(axis: Axes) -> tuple[Legend, Legend]:
    """Gets the legend for the SI graph plot.

    Returns a tuple of two legends: one for first-order interactions (nodes) and one for
    higher-order interactions (edges). If an axis is provided, both legends are also added to
    it as artists.

    Args:
        axis (plt.Axes): The axis to add the legends to.

    Returns:
        A tuple of two legend objects: the first is the legend for the first-order interactions,
        the second the legend for the higher-order interactions.
    """
    interaction_values = [1.0, 0.4, -0.4, -1]
    labels = ["high pos.", "low pos.", "low neg.", "high neg."]

    plot_circles = []
    plot_edges = []
    for value in interaction_values:
        color = get_color(value)
        node_size = abs(value) / 2 + 1 / 2
        edge_size = abs(value) / 2
        alpha = _normalize_value(value, 1, BASE_ALPHA_VALUE)
        # empty plots that only serve as legend handles
        circle = axis.plot(
            [], [], c=color, marker="o", markersize=node_size * 8, linestyle="None", alpha=alpha
        )
        plot_circles.append(circle[0])
        line = axis.plot([], [], c=color, linewidth=edge_size * 6, alpha=alpha)
        plot_edges.append(line[0])

    font_size = plt.rcParams["legend.fontsize"]

    # note: both legends are created on the current axes (plt.gca()), not necessarily on `axis`
    legend1 = plt.gca().legend(
        plot_circles,
        labels,
        frameon=True,
        framealpha=0.5,
        facecolor="white",
        title=r"$\bf{First\ Order}$",
        fontsize=font_size,
        labelspacing=0.5,
        handletextpad=0.5,
        borderpad=0.5,
        handlelength=1.5,
        title_fontsize=font_size,
        loc="upper left",
    )

    legend2 = plt.legend(
        plot_edges,
        labels,
        frameon=True,
        framealpha=0.5,
        facecolor="white",
        title=r"$\bf{Higher\ Order}$",
        fontsize=font_size,
        labelspacing=0.5,
        handletextpad=0.5,
        borderpad=0.5,
        handlelength=1.5,
        title_fontsize=font_size,
        loc="upper right",
    )
    if axis:
        axis.add_artist(legend1)
        axis.add_artist(legend2)
    return legend1, legend2
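
The legend handles depend only on the magnitude of the four example values, so positive and negative entries of equal magnitude get identical sizes and differ only in color. A plain-Python sketch of those numbers follows; `legend_handle_sizes` is a hypothetical name and the colors from `get_color` are left out.

```python
def legend_handle_sizes(values=(1.0, 0.4, -0.4, -1.0)):
    """(markersize, linewidth, alpha) of the legend handles, as computed in get_legend."""
    sizes = []
    for value in values:
        node_size = abs(value) / 2 + 1 / 2  # node markers never shrink below half size
        edge_size = abs(value) / 2
        alpha = abs(value) / 1 * 1.0  # _normalize_value(value, 1, BASE_ALPHA_VALUE)
        sizes.append((node_size * 8, edge_size * 6, alpha))
    return sizes


for entry in legend_handle_sizes():
    print(entry)
```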


def _normalize_value(
    value: float | np.ndarray,
    max_value: float,
    base_value: float,
) -> float | np.ndarray:
    """Scale the absolute value relative to the maximum value and multiply it by a base value.

    Args:
        value: The value to normalize/scale.
        max_value: The maximum value to normalize/scale the value by.
        base_value: The base value to scale the value by. For example, the alpha value for the
            highest interaction (as defined in ``BASE_ALPHA_VALUE``) or the size of the highest
            interaction edge (as defined in ``BASE_SIZE``).

    Returns:
        The normalized/scaled value, i.e. ``abs(value) / abs(max_value) * base_value``.

    """
    ratio = abs(value) / abs(max_value)  # non-negative; in [0, 1] if abs(value) <= abs(max_value)
    return ratio * base_value
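
For a concrete sense of the scaling: alpha and size are both the interaction's magnitude relative to `max_interaction`, times a base value. The sketch below uses hypothetical names and copies the constants from this module. It also shows that a user-supplied `min_max_interactions` smaller than the data can push the ratio above 1; the drawing functions then clamp alpha to [0, 1], but not the size.

```python
BASE_ALPHA_VALUE = 1.0  # copied from the module constants above
BASE_SIZE = 0.05


def normalize_value(value, max_value, base_value):
    """abs(value) / abs(max_value) * base_value, mirroring `_normalize_value`."""
    return abs(value) / abs(max_value) * base_value


def interaction_attributes(value, max_interaction, size_factor=1.0, node_size_scaling=1.0):
    """Alpha and size attributes assigned to one interaction in `si_graph_plot` (sketch)."""
    base_size = BASE_SIZE * node_size_scaling
    return {
        "alpha": normalize_value(value, max_interaction, BASE_ALPHA_VALUE),
        "size": normalize_value(value, max_interaction, base_size * size_factor),
    }


print(interaction_attributes(-0.5, 2.0))  # {'alpha': 0.25, 'size': 0.0125}
print(normalize_value(3.0, 2.0, 1.0))  # 1.5 -> alpha is clamped to 1.0 when drawn
```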


def _draw_fancy_hyper_edges(
    axis: Axes,
    pos: dict,
    graph: nx.Graph,
    hyper_edges: list[tuple],
) -> None:
    """Draws each hyper-edge of a collection as a fancy (curved, filled) shape on the graph.

    Note:
        This is also used to draw normal 2-way edges in a fancy way.

    Args:
        axis: The axis to draw the hyper-edges on.
        pos: The positions of the nodes.
        graph: The graph to draw the hyper-edges on.
        hyper_edges: The hyper-edges to draw.

    """
    for hyper_edge in hyper_edges:
        # store all paths for the hyper-edge to combine them later
        all_paths = []

        # make normal (2-way) edges plottable as well -> node v becomes the "center" node
        is_hyper_edge = True
        if len(hyper_edge) == 2:
            u, v = hyper_edge
            center_pos = pos[v]
            node_size = graph[u][v]["size"]
            color = graph[u][v]["color"]
            alpha = graph[u][v]["alpha"]
            is_hyper_edge = False
        else:  # a hyper-edge encodes its information in an artificial "center" node
            center_pos = pos[hyper_edge]
            # TODO(advueu963): Technically, there is no guarantee that the hyper_edge node exists # noqa: TD003
            node_size = graph.nodes.get(hyper_edge)["size"]  # pyright: ignore[reportOptionalSubscript]
            color = graph.nodes.get(hyper_edge)["color"]  # pyright: ignore[reportOptionalSubscript]
            alpha = graph.nodes.get(hyper_edge)["alpha"]  # pyright: ignore[reportOptionalSubscript]

        alpha = min(1.0, max(0.0, alpha))

        # draw the connection point of the hyper-edge
        circle = mpath.Path.circle(center_pos, radius=node_size / 2)
        all_paths.append(circle)
        axis.scatter(center_pos[0], center_pos[1], s=0, c="none", lw=0)  # add empty point for limit

        # draw the fancy connections from the other nodes to the center node
        for player in hyper_edge:
            player_pos = pos[player]

            circle_p = mpath.Path.circle(player_pos, radius=node_size / 2)
            all_paths.append(circle_p)
            axis.scatter(player_pos[0], player_pos[1], s=0, c="none", lw=0)  # for axis limits

            # get the unit direction of the connection
            direction = (center_pos[0] - player_pos[0], center_pos[1] - player_pos[1])
            direction = np.array(direction) / np.linalg.norm(direction)

            # get the direction rotated by 90 degrees (perpendicular to the connection)
            direction_90 = np.array([-direction[1], direction[0]])

            # get the distance between the player and the center node
            distance = np.linalg.norm(center_pos - player_pos)

            # get the position of the start and end of the connection
            start_pos = player_pos - direction_90 * (node_size / 2)
            middle_pos = player_pos + direction * distance / 2
            end_pos_one = center_pos - direction_90 * (node_size / 2)
            end_pos_two = center_pos + direction_90 * (node_size / 2)
            start_pos_two = player_pos + direction_90 * (node_size / 2)

            # create the connection: two quadratic Bezier curves (CURVE3) with the midpoint as
            # control point make the band narrow towards the middle
            connection = mpath.Path(
                [
                    start_pos,
                    middle_pos,
                    end_pos_one,
                    end_pos_two,
                    middle_pos,
                    start_pos_two,
                    start_pos,
                ],
                [
                    mpath.Path.MOVETO,
                    mpath.Path.CURVE3,
                    mpath.Path.CURVE3,
                    mpath.Path.LINETO,
                    mpath.Path.CURVE3,
                    mpath.Path.CURVE3,
                    mpath.Path.LINETO,
                ],
            )

            # add the connection to the list of all paths
            all_paths.append(connection)

            # a plain 2-way edge needs only one connection (from u to the center node v)
            if not is_hyper_edge:
                break

        # combine all paths into one patch
        combined_path = mpath.Path.make_compound_path(*all_paths)
        patch = mpatches.PathPatch(combined_path, facecolor=color, lw=0, alpha=alpha)

        axis.add_patch(patch)
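
The geometry of a single connection can be checked by hand with a plain-Python version of the point computations. This is a sketch: `connection_points` and `quad_bezier` are hypothetical names, and `math.hypot` stands in for `np.linalg.norm`. Evaluating the first quadratic Bezier segment at its midpoint shows that the band narrows to half its width halfway between the player and the center.

```python
import math


def connection_points(player_pos, center_pos, node_size):
    """Key points of one fancy connection from a player to the center node (sketch)."""
    dx, dy = center_pos[0] - player_pos[0], center_pos[1] - player_pos[1]
    distance = math.hypot(dx, dy)
    ux, uy = dx / distance, dy / distance  # unit direction towards the center
    px, py = -uy, ux  # direction rotated by 90 degrees
    half = node_size / 2
    return {
        "start": (player_pos[0] - px * half, player_pos[1] - py * half),
        "start_two": (player_pos[0] + px * half, player_pos[1] + py * half),
        "middle": (player_pos[0] + ux * distance / 2, player_pos[1] + uy * distance / 2),
        "end_one": (center_pos[0] - px * half, center_pos[1] - py * half),
        "end_two": (center_pos[0] + px * half, center_pos[1] + py * half),
    }


def quad_bezier(p0, p1, p2, t):
    """Point at parameter t of the quadratic Bezier curve with control point p1."""
    return tuple((1 - t) ** 2 * a + 2 * (1 - t) * t * b + t**2 * c for a, b, c in zip(p0, p1, p2))


pts = connection_points((0.0, 0.0), (2.0, 0.0), 0.2)
print(quad_bezier(pts["start"], pts["middle"], pts["end_one"], 0.5))  # approx. (1.0, -0.05)
```

Coincident player and center positions would divide by zero, which the plotting code does not guard against either.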


def _draw_graph_nodes(
    ax: Axes,
    pos: dict,
    graph: nx.Graph,
    nodes: list | None = None,
    normal_node_size: float = NORMAL_NODE_SIZE,
) -> None:
    """Draws the nodes of the graph as circles with a fixed size.

    Args:
        ax: The axis to draw the nodes on.
        pos: The positions of the nodes.
        graph: The graph to draw the nodes on.
        nodes: The nodes to draw. If ``None``, all nodes are drawn. Defaults to ``None``.
        normal_node_size: The size of the nodes. Defaults to ``NORMAL_NODE_SIZE``.

    """
    for node in graph.nodes:
        if nodes is not None and node not in nodes:
            continue

        position = pos[node]
        circle = mpath.Path.circle(position, radius=normal_node_size / 2)
        patch = mpatches.PathPatch(circle, facecolor="white", lw=1, alpha=1, edgecolor="black")
        ax.add_patch(patch)

        # add empty scatter for the axis to adjust the limits later
        ax.scatter(position[0], position[1], s=0, c="none", lw=0)


def _draw_explanation_nodes(
    ax: Axes,
    pos: dict,
    graph: nx.Graph,
    nodes: list | None = None,
    normal_node_size: float = NORMAL_NODE_SIZE,
) -> None:
    """Adds the node-level explanations to the graph as circles with varying sizes.

    Args:
        ax: The axis to draw the nodes on.
        pos: The positions of the nodes.
        graph: The graph to draw the nodes on.
        nodes: The nodes to draw. If ``None``, all nodes are drawn. Defaults to ``None``.
        normal_node_size: The size of the nodes. Defaults to ``NORMAL_NODE_SIZE``.

    """
    for node in graph.nodes:
        if isinstance(node, tuple):  # skip the artificial center nodes of hyper-edges
            continue
        if nodes is not None and node not in nodes:
            continue
        position = pos[node]
        # TODO(advueu963): Statically, it is not clear that we get a subscriptable object # noqa: TD003
        color = graph.nodes.get(node)["color"]  # pyright: ignore[reportOptionalSubscript]
        explanation_size = graph.nodes.get(node)["size"]  # pyright: ignore[reportOptionalSubscript]
        alpha = 1.0
        if ADJUST_NODE_ALPHA:
            alpha = graph.nodes.get(node)["alpha"]  # pyright: ignore[reportOptionalSubscript]

        alpha = min(1.0, max(0.0, alpha))

        radius = normal_node_size / 2 + explanation_size / 2
        circle = mpath.Path.circle(position, radius=radius)
        # opaque white base circle, so that the (semi-transparent) colored circle on top does
        # not blend with whatever lies underneath
        patch = mpatches.PathPatch(circle, facecolor="white", lw=1, edgecolor="white", alpha=1.0)
        ax.add_patch(patch)
        patch = mpatches.PathPatch(circle, facecolor=color, lw=1, edgecolor="white", alpha=alpha)
        ax.add_patch(patch)

        ax.scatter(position[0], position[1], s=0, c="none", lw=0)  # add empty point for limits
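
Each explanation node is drawn as a circle whose radius adds half the explanation size to the regular node radius, with alpha clamped to [0, 1]. A small sketch with hypothetical names (`explanation_circle`, and `adjust_alpha` standing in for `ADJUST_NODE_ALPHA`):

```python
def explanation_circle(explanation_size, alpha, normal_node_size=0.125, adjust_alpha=True):
    """Radius and clamped alpha of an explanation node, following `_draw_explanation_nodes`."""
    radius = normal_node_size / 2 + explanation_size / 2  # the node grows with its effect
    if not adjust_alpha:  # mirrors ADJUST_NODE_ALPHA = False
        alpha = 1.0
    return radius, min(1.0, max(0.0, alpha))


print(explanation_circle(0.05, 1.5))  # radius approx. 0.0875, alpha clamped to 1.0
```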


def _draw_graph_edges(
    ax: Axes,
    pos: dict,
    graph: nx.Graph,
    edges: list[tuple] | None = None,
    normal_node_size: float = NORMAL_NODE_SIZE,
) -> None:
    """Draws black lines between the nodes.

    Args:
        ax: The axis to draw the edges on.
        pos: The positions of the nodes.
        graph: The graph to draw the edges on.
        edges: The edges to draw. If ``None`` (default), all edges are drawn.
        normal_node_size: The size of the nodes. Defaults to ``NORMAL_NODE_SIZE``.

    """
    for u, v in graph.edges:
        if edges is not None and (u, v) not in edges and (v, u) not in edges:
            continue

        u_pos = pos[u]
        v_pos = pos[v]

        direction = v_pos - u_pos
        direction = direction / np.linalg.norm(direction)

        # shorten the line by the node radius at both ends so it stops at the node borders
        start_point = u_pos + direction * normal_node_size / 2
        end_point = v_pos - direction * normal_node_size / 2

        connection = mpath.Path(
            [start_point, end_point],
            [mpath.Path.MOVETO, mpath.Path.LINETO],
        )

        patch = mpatches.PathPatch(connection, facecolor="none", lw=1, edgecolor="black")
        ax.add_patch(patch)
670

671

672
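# Sketch (hypothetical helper, not part of shapiq): the endpoint arithmetic used by
# _draw_graph_edges, isolated so that it can be checked by hand. Each edge is shortened
# by half of normal_node_size at both ends along the unit vector from u to v, so the line
# stops at the boundary of a node with that diameter. Unlike the loop in
# _draw_graph_edges, this sketch returns None when u and v coincide, instead of dividing
# by a zero norm.
import numpy as np


def _sketch_edge_endpoints(
    u_pos: np.ndarray, v_pos: np.ndarray, normal_node_size: float
) -> tuple[np.ndarray, np.ndarray] | None:
    """Return the (start, end) points of a shortened edge, or None if u and v coincide."""
    u = np.asarray(u_pos, dtype=float)
    v = np.asarray(v_pos, dtype=float)
    direction = v - u
    length = np.linalg.norm(direction)
    if length == 0:
        return None  # a zero-length edge has no direction
    direction = direction / length  # unit vector pointing from u to v
    start_point = u + direction * normal_node_size / 2
    end_point = v - direction * normal_node_size / 2
    return start_point, end_point


# Example: u at (0, 0), v at (3, 4) and a node diameter of 2 give the unit direction
# (0.6, 0.8), so the line runs from (0.6, 0.8) to (2.4, 3.2).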
def _draw_graph_labels(
    ax: Axes,
    pos: dict,
    graph: nx.Graph,
    *,
    nodes: list | None = None,
    normal_node_size: float = 1.0,
    plot_white_nodes: bool = False,
) -> None:
    """Adds labels to the nodes of the graph.

    Args:
        ax: The axis to draw the labels on.
        pos: The positions of the nodes.
        graph: The graph whose nodes are labeled.
        nodes: The nodes to label. If ``None`` (default), all nodes are labeled.
        normal_node_size: The size of the nodes. Defaults to ``1.0``.
        plot_white_nodes: If set to ``True``, each label is centered on its node (for plots in
            which the nodes are drawn as white circles with the label inside). If set to
            ``False``, the labels are drawn next to the nodes. Defaults to ``False``.

    """
    for node in graph.nodes:
        if nodes is not None and node not in nodes:
            continue
        label = graph.nodes.get(node)["label"]  # pyright: ignore[reportOptionalSubscript]
        position = pos[node]
        if plot_white_nodes:
            offset = (0, 0)
        else:
            # push the label radially outward, away from the origin, next to the node
            offset_norm = np.sqrt(position[0] ** 2 + position[1] ** 2)
            offset = (
                (LABEL_OFFSET + normal_node_size) * position[0] / offset_norm,
                (LABEL_OFFSET + normal_node_size) * position[1] / offset_norm,
            )
        ax.text(
            position[0] + offset[0],
            position[1] + offset[1],
            label,
            fontsize=plt.rcParams["font.size"] + 1,
            ha="center",
            va="center",
            color="black",
        )


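# Sketch (hypothetical helper, not part of shapiq): the radial label offset computed in
# _draw_graph_labels when plot_white_nodes is False. The label is moved away from the
# origin along the node's own direction by (label_offset + normal_node_size). Here
# label_offset stands in for the module constant LABEL_OFFSET, whose value is not shown.
# A node exactly at the origin has no direction, so this sketch returns a zero offset for
# it instead of dividing by zero.
import math


def _sketch_label_offset(
    position: tuple[float, float], normal_node_size: float, label_offset: float
) -> tuple[float, float]:
    """Return the (dx, dy) that is added to a node position to place its label."""
    x, y = position
    norm = math.hypot(x, y)  # distance of the node from the origin
    if norm == 0:
        return (0.0, 0.0)
    scale = (label_offset + normal_node_size) / norm
    return (scale * x, scale * y)


# Example: a node at (3, 4) with normal_node_size=1.0 and label_offset=0.5 gets the offset
# (0.9, 1.2), so its label sits 1.5 units further out along the same ray.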
def _draw_feature_images(
    ax: Axes,
    pos: dict,
    graph: nx.Graph,
    feature_image_patches: dict[int, Image] | list[Image],
    patch_size: float,
) -> None:
    """Draws the feature images next to the nodes.

    Args:
        ax: The axis to draw the images on.
        pos: The positions of the nodes.
        graph: The graph whose nodes receive the images.
        feature_image_patches: The images of the players, as a dict or list indexed by the
            player (node) index.
        patch_size: The side length of the feature images as a fraction of the width of the
            x-axis range.

    """
    x_min, x_max = ax.get_xlim()
    img_scale = x_max - x_min
    extend = img_scale * patch_size / 2  # half the side length of each image
    for node in graph.nodes:
        if node < len(feature_image_patches):
            image = feature_image_patches[node]
            x, y = pos[node]
            offset_norm = np.sqrt(x**2 + y**2)
            # 1.55 is slightly more than sqrt(2) (~1.414); it scales the radial offset that
            # moves the center of the image away from the node
            offset = (
                1.55 * patch_size * x / offset_norm,
                1.55 * patch_size * y / offset_norm,
            )
            # (x, y) is now the center of the image
            x, y = x + offset[0], y + offset[1]
            ax.imshow(image, extent=(x - extend, x + extend, y - extend, y + extend))
    # widen the limits by img_scale * patch_size on each side so the images fit in the plot
    x_min -= img_scale * patch_size
    x_max += img_scale * patch_size
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(x_min, x_max)  # the same limits are applied to both axes


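# Sketch (hypothetical helper, not part of shapiq): the placement arithmetic used by
# _draw_feature_images for a single node. The image center is moved outward from the node
# by 1.55 * patch_size along the node's direction, and the image is a square of half-width
# img_scale * patch_size / 2 around that center, where img_scale is the width of the
# x-axis range. The returned tuple uses the (left, right, bottom, top) order of the
# `extent` argument of matplotlib's Axes.imshow.
import math


def _sketch_image_extent(
    x: float, y: float, patch_size: float, img_scale: float
) -> tuple[float, float, float, float]:
    """Return the imshow extent (left, right, bottom, top) for a node at (x, y)."""
    norm = math.hypot(x, y)  # as in the original, (x, y) must not be the origin
    cx = x + 1.55 * patch_size * x / norm  # image center, pushed radially outward
    cy = y + 1.55 * patch_size * y / norm
    half = img_scale * patch_size / 2  # named `extend` in _draw_feature_images
    return (cx - half, cx + half, cy - half, cy + half)


# Example: a node at (3, 4) with patch_size=0.1 on an x-axis spanning 2 units
# (img_scale=2) gets an image centered at (3.093, 4.124) with half-width 0.1.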
def _adjust_position(
    pos: dict,
    graph: nx.Graph,
    normal_node_size: float = NORMAL_NODE_SIZE,
) -> dict:
    """Moves the nodes further apart if connected nodes are too close together.

    If the shortest edge is shorter than ``1.5 * normal_node_size``, all positions are scaled
    by the same factor (about the origin) so that this edge reaches that length. ``pos`` is
    modified in place and returned.

    """
    # find the shortest edge, i.e. the minimum distance between two connected nodes
    min_distance = 1e10
    for u, v in graph.edges:
        distance = np.linalg.norm(pos[u] - pos[v]).item()
        min_distance = min(min_distance, distance)

    # scale all positions uniformly if the shortest edge is too short
    min_edge_distance = normal_node_size + normal_node_size / 2
    if min_distance < min_edge_distance:
        for node, position in pos.items():
            pos[node] = position * min_edge_distance / min_distance

    return pos
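

# Sketch (hypothetical helper, not part of shapiq): the rescaling rule of _adjust_position,
# written against a plain edge list instead of a networkx graph and returning a new dict
# instead of modifying `pos` in place. If the shortest edge is shorter than
# 1.5 * normal_node_size, every position is multiplied by the same factor so that this edge
# reaches exactly that length. Because the scaling is about the origin, the shape of the
# layout is preserved. This sketch also skips rescaling when two connected nodes coincide,
# to avoid dividing by zero.
import numpy as np


def _sketch_adjust_position(pos: dict, edges: list[tuple], normal_node_size: float) -> dict:
    """Return a rescaled copy of `pos` so that no edge is shorter than 1.5 node sizes."""
    lengths = [float(np.linalg.norm(pos[u] - pos[v])) for u, v in edges]
    if not lengths:
        return dict(pos)  # no edges: nothing to compare against, keep the layout
    min_distance = min(lengths)
    min_edge_distance = 1.5 * normal_node_size
    if 0 < min_distance < min_edge_distance:
        factor = min_edge_distance / min_distance
        return {node: position * factor for node, position in pos.items()}
    return dict(pos)


# Example: with normal_node_size=1.0, a layout whose shortest edge has length 1 is scaled
# by 1.5, so a node at (1, 0) moves to (1.5, 0) and a node at (0, 2) moves to (0, 3).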