
mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%
18499875864

push

github

web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* Fixed bugs introduced in game and interaction_values

* Pyright Safe Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer: Approximator, Imputer and ExactComputer may or may not exist, depending on the dynamic type

* Fixed a bug introduced by refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step that reinstalls Python on Windows, which was related to issues with tkinter. The linked GitHub issue has been resolved. This is a first attempt.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests. The prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
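The commit bullet about introducing a generic `TypeVar` in the approximator refers to a standard typing pattern. Below is a minimal, hypothetical sketch of that pattern. It assumes the base class is parameterized by its result type. `BaseApproximator`, `FloatApproximator` and `ResultT` are illustrative names, not shapiq's actual classes or signatures.

```python
from typing import Generic, TypeVar

# ResultT is a hypothetical type parameter for whatever the approximator returns.
ResultT = TypeVar("ResultT")


class BaseApproximator(Generic[ResultT]):
    """Hypothetical base class parameterized by the type of its result."""

    def __init__(self, result: ResultT) -> None:
        self._result = result

    def approximate(self) -> ResultT:
        # Static checkers such as Pyright resolve ResultT from the subclass binding,
        # so callers get a precise return type without `# type: ignore` comments.
        return self._result


class FloatApproximator(BaseApproximator[float]):
    """Binds ResultT to float, so approximate() is typed as returning float."""


estimate = FloatApproximator(0.5).approximate()
```

Binding the type parameter in the subclass is what lets a checker drop ignores. Each subclass declares its concrete result type once, instead of casting at every call site.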

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line
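The percentages in this report follow directly from the raw line counts. Here is a quick arithmetic check in Python. It uses only the numbers shown above, including the previous build's 93.522% from the header line.

```python
# Counts copied from this coverage report.
new_covered, new_total = 304, 360
covered, relevant = 4987, 5374
previous_pct = 93.522

new_pct = 100 * new_covered / new_total  # coverage of new or added lines
total_pct = 100 * covered / relevant  # overall coverage of this build
delta = total_pct - previous_pct  # change relative to the previous build

print(f"new: {new_pct:.2f}%, total: {total_pct:.3f}%, change: {delta:+.1f}%")
# prints: new: 84.44%, total: 92.799%, change: -0.7%
```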

Source File

/src/shapiq/interaction_values.py (95.22% covered)
1
"""InteractionValues data-class, which is used to store the interaction scores."""
2

3
from __future__ import annotations
1✔
4

5
import contextlib
1✔
6
import copy
1✔
7
import json
1✔
8
import pickle
1✔
9
from pathlib import Path
1✔
10
from typing import TYPE_CHECKING
1✔
11
from warnings import warn
1✔
12

13
import numpy as np
1✔
14

15
from .game_theory.indices import (
1✔
16
    ALL_AVAILABLE_INDICES,
17
    get_index_from_computation_index,
18
    is_empty_value_the_baseline,
19
    is_index_aggregated,
20
    is_index_valid,
21
)
22
from .utils.errors import raise_deprecation_warning
1✔
23
from .utils.saving import safe_str_to_tuple, safe_tuple_to_str
1✔
24
from .utils.sets import generate_interaction_lookup
1✔
25

26
if TYPE_CHECKING:
1✔
27
    from collections.abc import Sequence
×
28
    from typing import Any
×
29

30
    from matplotlib.axes import Axes
×
31
    from matplotlib.figure import Figure
×
32

NEW
33
    from shapiq.typing import InteractionScores, JSONType
×
34

35
SAVE_JSON_DEPRECATION_MSG = (
1✔
36
    "Saving InteractionValues not as a JSON file is deprecated. "
37
    "The parameters `as_pickle` and `as_npz` will be removed in the future. "
38
)
39

40

41
class InteractionValues:
1✔
42
    """This class contains the interaction values as estimated by an approximator.
43

44
    Attributes:
45
        values: The interaction values of the model in vectorized form.
46
        index: The interaction index estimated. All available indices are defined in
47
            ``ALL_AVAILABLE_INDICES``.
48
        max_order: The order of the approximation.
49
        n_players: The number of players.
50
        min_order: The minimum order of the approximation.
51
        interaction_lookup: A dictionary that maps interactions to their index in the values
52
            vector. If ``interaction_lookup`` is not provided, it is computed from the ``n_players``,
53
            ``min_order``, and ``max_order`` parameters. Defaults to ``None``.
54
        estimated: Whether the interaction values are estimated or not. Defaults to ``True``.
55
        estimation_budget: The budget used for the estimation. Defaults to ``None``.
56
        baseline_value: The value of the baseline interaction also known as 'empty prediction' or
57
            ``'empty value'`` since it denotes the value of the empty coalition (empty set). If not
58
            provided it is searched for in the values vector (raising an Error if not found).
59
            Defaults to ``0.0``.
60

61
    Raises:
62
        UserWarning: If the index is not a valid index as defined in ``ALL_AVAILABLE_INDICES``.
63
        TypeError: If the baseline value is not a number.
64

65
    """
66

67
    interactions: InteractionScores
1✔
68
    """The interactions as a dictionary mapping interactions to their values."""
1✔
69

70
    def __init__(
1✔
71
        self,
72
        values: np.ndarray | InteractionScores,
73
        *,
74
        index: str,
75
        max_order: int,
76
        n_players: int,
77
        min_order: int,
78
        interaction_lookup: dict[tuple[int, ...], int] | None = None,  # type: ignore[assignment]
79
        estimated: bool = True,
80
        estimation_budget: int | None = None,  # type: ignore[assignment]
81
        baseline_value: float | np.number = 0.0,
82
        target_index: str | None = None,
83
    ) -> None:
84
        """Initialize the InteractionValues object.
85

86
        Args:
87
            values: The interaction values as a numpy array or a dictionary mapping interactions to their
88
                values.
89

90
            index: The index of the interaction values. This should be one of the indices defined in
91
                ``ALL_AVAILABLE_INDICES``. It determines how the interaction values are interpreted.
92
            max_order: The maximum order of the interactions.
93
            n_players: The number of players in the game.
94
            min_order: The minimum order of the interactions.
95
            interaction_lookup: A dictionary mapping interactions to their index in the values vector.
96
                Defaults to ``None``, in which case it is generated from ``n_players``, ``min_order``, and ``max_order``.
97
            estimated: Whether the interaction values are estimated or not. Defaults to True.
98
            estimation_budget: The budget used for the estimation. Defaults to None.
99
            baseline_value: The baseline value, also known as the empty prediction or empty value. Defaults to ``0.0``.
100
            target_index: The index to which the InteractionValues should be finalized. Defaults to ``None``, which means that
101
                ``target_index = index``.
102
        """
103
        if not isinstance(baseline_value, (int | float | np.number)):
1✔
104
            msg = f"Baseline value must be provided as a number. Got {type(baseline_value)}."
1✔
105
            raise TypeError(msg)
1✔
106
        self.baseline_value = baseline_value
1✔
107
        if not is_index_valid(index, raise_error=False):
1✔
108
            warn(
1✔
109
                f"Index `{index}` is not a valid interaction index. "
110
                f"Valid indices are: {', '.join(ALL_AVAILABLE_INDICES)}.",
111
                stacklevel=2,
112
            )
113
        index = get_index_from_computation_index(index, max_order)
1✔
114
        if target_index is None:
1✔
115
            target_index = index
1✔
116

117
        interactions = _validate_and_return_interactions(
1✔
118
            values=values,
119
            interaction_lookup=interaction_lookup,
120
            n_players=n_players,
121
            min_order=min_order,
122
            max_order=max_order,
123
            baseline_value=baseline_value,
124
        )
125

126
        interactions, index, min_order, baseline_value = _update_interactions_for_index(
1✔
127
            interactions=interactions,
128
            index=index,
129
            target_index=target_index,
130
            min_order=min_order,
131
            max_order=max_order,
132
            baseline_value=baseline_value,
133
        )
134

135
        self.interactions = interactions
1✔
136
        self.index = index
1✔
137
        self.max_order = max_order
1✔
138
        self.n_players = n_players
1✔
139
        self.min_order = min_order
1✔
140
        self.estimated = estimated
1✔
141
        self.estimation_budget = estimation_budget
1✔
142

143
    @property
1✔
144
    def dict_values(self) -> dict[tuple[int, ...], float]:
1✔
145
        """Getter for the dict directly mapping from all interactions to scores."""
146
        return self.interactions
1✔
147

148
    @property
1✔
149
    def values(self) -> np.ndarray:
1✔
150
        """Getter for the values of the InteractionValues object.
151

152
        Returns:
153
            The values of the InteractionValues object as a numpy array.
154

155
        """
156
        return np.array(list(self.interactions.values()))
1✔
157

158
    @property
1✔
159
    def interaction_lookup(self) -> dict[tuple[int, ...], int]:
1✔
160
        """Getter for the interaction lookup of the InteractionValues object.
161

162
        Returns:
163
            The interaction lookup of the InteractionValues object as a dictionary mapping interactions
164
            to their index in the values vector.
165

166
        """
167
        return {
1✔
168
            interaction: index for index, (interaction, _) in enumerate(self.interactions.items())
169
        }
170

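Both `values` and `interaction_lookup` are derived from the `interactions` dictionary. They rely on dictionaries preserving insertion order, which Python guarantees since 3.7. The sketch below shows the same derivation in plain Python, with a list standing in for the NumPy array.

```python
# Interactions in insertion order, as stored in InteractionValues.interactions.
interactions = {(0,): 0.1, (1,): 0.2, (0, 1): 0.3}

# Mirrors the `values` property (a list here instead of np.array).
values = list(interactions.values())

# Mirrors the `interaction_lookup` property: position of each interaction in `values`.
lookup = {interaction: pos for pos, interaction in enumerate(interactions)}

assert values[lookup[(0, 1)]] == 0.3
```

Both properties are rebuilt on every access. Code that reads `self.values` inside a loop, as `get_top_k` does, therefore reconstructs the array on each iteration.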
171
    def to_json_file(
1✔
172
        self,
173
        path: Path,
174
        *,
175
        desc: str | None = None,
176
        created_from: object | None = None,
177
        **kwargs: JSONType,
178
    ) -> None:
179
        """Saves the InteractionValues object to a JSON file.
180

181
        Args:
182
            path: The path to the JSON file.
183
            desc: A description of the InteractionValues object. Defaults to ``None``.
184
            created_from: An object from which the InteractionValues object was created. Defaults to
185
                ``None``.
186
            **kwargs: Additional parameters to store in the metadata of the JSON file.
187
        """
188
        from shapiq.utils.saving import (
1✔
189
            interactions_to_dict,
190
            make_file_metadata,
191
            save_json,
192
        )
193

194
        file_metadata = make_file_metadata(
1✔
195
            object_to_store=self,
196
            data_type="interaction_values",
197
            desc=desc,
198
            created_from=created_from,
199
            parameters=kwargs,
200
        )
201
        json_data = {
1✔
202
            **file_metadata,
203
            "metadata": {
204
                "n_players": self.n_players,
205
                "index": self.index,
206
                "max_order": self.max_order,
207
                "min_order": self.min_order,
208
                "estimated": self.estimated,
209
                "estimation_budget": self.estimation_budget,
210
                "baseline_value": self.baseline_value,
211
            },
212
            "data": interactions_to_dict(interactions=self.dict_values),
213
        }
214
        save_json(json_data, path)
1✔
215

216
    @classmethod
1✔
217
    def from_json_file(cls, path: Path) -> InteractionValues:
1✔
218
        """Loads an InteractionValues object from a JSON file.
219

220
        Args:
221
            path: The path to the JSON file. Note that the path must end with `'.json'`.
222

223
        Returns:
224
            The InteractionValues object loaded from the JSON file.
225

226
        Raises:
227
            ValueError: If the path does not end with `'.json'`.
228
        """
229
        from shapiq.utils.saving import dict_to_lookup_and_values
1✔
230

231
        if not path.name.endswith(".json"):
1✔
232
            msg = f"Path {path} does not end with .json. Cannot load InteractionValues."
×
233
            raise ValueError(msg)
×
234

235
        with path.open("r", encoding="utf-8") as file:
1✔
236
            json_data = json.load(file)
1✔
237

238
        metadata = json_data["metadata"]
1✔
239
        interaction_dict = json_data["data"]
1✔
240
        interaction_lookup, values = dict_to_lookup_and_values(interaction_dict)
1✔
241

242
        return cls(
1✔
243
            values=values,
244
            index=metadata["index"],
245
            max_order=metadata["max_order"],
246
            n_players=metadata["n_players"],
247
            min_order=metadata["min_order"],
248
            interaction_lookup=interaction_lookup,
249
            estimated=metadata["estimated"],
250
            estimation_budget=metadata["estimation_budget"],
251
            baseline_value=metadata["baseline_value"],
252
        )
253

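JSON object keys must be strings, so tuple-keyed interactions have to be encoded before saving and decoded after loading. In this file that work is delegated to `interactions_to_dict` and `dict_to_lookup_and_values` from `shapiq.utils.saving`, whose exact format is not shown here. The sketch below illustrates the round trip with a hypothetical comma-joined encoding. `encode_key` and `decode_key` are illustrative names, not shapiq functions.

```python
import json


def encode_key(interaction: tuple[int, ...]) -> str:
    # Hypothetical encoding: (0, 2) -> "0,2"; the empty interaction () -> "".
    return ",".join(str(player) for player in interaction)


def decode_key(key: str) -> tuple[int, ...]:
    # Inverse of encode_key; "" maps back to the empty interaction ().
    return tuple(int(part) for part in key.split(",")) if key else ()


interactions = {(): 0.5, (0,): 0.1, (0, 2): -0.3}
payload = json.dumps({encode_key(k): v for k, v in interactions.items()})
restored = {decode_key(k): v for k, v in json.loads(payload).items()}
```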
254
    def sparsify(self, threshold: float = 1e-3) -> None:
1✔
255
        """Manually sets values close to zero actually to zero (removing values).
256

257
        Args:
258
            threshold: The threshold on the absolute value below which interactions are removed. Defaults to
259
                1e-3.
260

261
        """
262
        # find interactions to remove in self.interactions
263
        sparse_interactions = copy.deepcopy(self.interactions)
1✔
264
        for interaction, value in self.interactions.items():
1✔
265
            if np.abs(value) < threshold:
1✔
266
                del sparse_interactions[interaction]
1✔
267
        self.interactions = sparse_interactions
1✔
268

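`sparsify` drops every interaction whose absolute value is strictly below `threshold`. The same filter can be written as a single dictionary comprehension over a plain dict. This is an equivalent sketch, not the library code.

```python
def sparsify(
    interactions: dict[tuple[int, ...], float], threshold: float = 1e-3
) -> dict[tuple[int, ...], float]:
    # Keep only interactions with |value| >= threshold, preserving insertion order.
    return {k: v for k, v in interactions.items() if abs(v) >= threshold}
```

Building the filtered dictionary directly avoids the deep copy followed by deletions. The result is the same set of surviving interactions.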
269
    def get_top_k_interactions(self, k: int) -> InteractionValues:
1✔
270
        """Returns the top k interactions.
271

272
        Args:
273
            k: The number of top interactions to return.
274

275
        Returns:
276
            The top k interactions as an InteractionValues object.
277

278
        """
279
        top_k_indices = np.argsort(np.abs(self.values))[::-1][:k]
1✔
280
        new_values = np.zeros(len(top_k_indices), dtype=float)
1✔
281
        new_interaction_lookup = {}
1✔
282
        for interaction_pos, interaction in enumerate(self.interaction_lookup):
1✔
283
            if interaction_pos in top_k_indices:
1✔
284
                new_position = len(new_interaction_lookup)
1✔
285
                new_values[new_position] = float(self[interaction_pos])
1✔
286
                new_interaction_lookup[interaction] = new_position
1✔
287
        return InteractionValues(
1✔
288
            values=new_values,
289
            index=self.index,
290
            max_order=self.max_order,
291
            n_players=self.n_players,
292
            min_order=self.min_order,
293
            interaction_lookup=new_interaction_lookup,
294
            estimated=self.estimated,
295
            estimation_budget=self.estimation_budget,
296
            baseline_value=self.baseline_value,
297
        )
298

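`get_top_k_interactions` selects the `k` interactions with the largest absolute values but keeps them in their original order. The pure-Python sketch below shows that selection. The function name is hypothetical. When `k` exceeds the number of stored interactions, this sketch simply returns all of them.

```python
def top_k_by_magnitude(
    interactions: dict[tuple[int, ...], float], k: int
) -> dict[tuple[int, ...], float]:
    # Rank interactions by |value|, largest first, and keep the first k.
    ranked = sorted(interactions, key=lambda i: abs(interactions[i]), reverse=True)
    selected = set(ranked[:k])
    # Rebuild in the original insertion order, like the loop over interaction_lookup.
    return {i: v for i, v in interactions.items() if i in selected}
```

Python's `sorted` is stable, whereas the `np.argsort` call above uses an unstable sort by default. As a result, ties in absolute value may be broken differently.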
299
    def get_top_k(
1✔
300
        self, k: int, *, as_interaction_values: bool = True
301
    ) -> InteractionValues | tuple[dict, list[tuple]]:
302
        """Returns the top k interactions.
303

304
        Args:
305
            k: The number of top interactions to return.
306
            as_interaction_values: Whether to return the top `k` interactions as an InteractionValues
307
                object. Defaults to ``True``.
308

309
        Returns:
310
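            The top k interactions as an InteractionValues object or, if ``as_interaction_values`` is ``False``, as a dictionary and a list of ``(interaction, value)`` tuples sorted by value.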
311

312
        Examples:
313
            >>> interaction_values = InteractionValues(
314
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
315
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
316
            ...     index="SII",
317
            ...     max_order=2,
318
            ...     n_players=3,
319
            ...     min_order=1,
320
            ...     baseline_value=0.0,
321
            ... )
322
            >>> top_k_interactions, sorted_top_k_interactions = interaction_values.get_top_k(2, as_interaction_values=False)
323
            >>> top_k_interactions
324
            {(0, 2): 0.5, (1, 2): 0.6}
325
            >>> sorted_top_k_interactions
326
            [((1, 2), 0.6), ((0, 2), 0.5)]
327

328
        """
329
        if as_interaction_values:
1✔
UNCOV
330
            return self.get_top_k_interactions(k)
×
331
        top_k_indices = np.argsort(np.abs(self.values))[::-1][:k]
1✔
332
        top_k_interactions = {}
1✔
333
        for interaction, index in self.interaction_lookup.items():
1✔
334
            if index in top_k_indices:
1✔
335
                top_k_interactions[interaction] = self.values[index]
1✔
336
        sorted_top_k_interactions = [
1✔
337
            (interaction, top_k_interactions[interaction])
338
            for interaction in sorted(
339
                top_k_interactions, key=lambda x: top_k_interactions[x], reverse=True
340
            )
341
        ]
342
        return top_k_interactions, sorted_top_k_interactions
1✔
343

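In `get_top_k`, interactions are selected by absolute value, but the returned list is sorted by signed value in descending order. A large negative interaction is therefore selected but ends up last. The plain-Python sketch below reproduces that behavior. The function name is hypothetical.

```python
def top_k_sorted(
    interactions: dict[tuple[int, ...], float], k: int
) -> list[tuple[tuple[int, ...], float]]:
    # Select the k largest interactions by magnitude ...
    ranked = sorted(interactions, key=lambda i: abs(interactions[i]), reverse=True)[:k]
    # ... then order them by signed value, largest first, as get_top_k does.
    return sorted(((i, interactions[i]) for i in ranked), key=lambda item: item[1], reverse=True)


print(top_k_sorted({(0,): -0.9, (1,): 0.5, (2,): 0.1}, k=2))
# prints: [((1,), 0.5), ((0,), -0.9)]
```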
344
    def __repr__(self) -> str:
1✔
345
        """Returns the representation of the InteractionValues object."""
346
        representation = "InteractionValues(\n"
1✔
347
        representation += (
1✔
348
            f"    index={self.index}, max_order={self.max_order}, min_order={self.min_order}"
349
            f", estimated={self.estimated}, estimation_budget={self.estimation_budget},\n"
350
            f"    n_players={self.n_players}, baseline_value={self.baseline_value}\n)"
351
        )
352
        return representation
1✔
353

354
    def __str__(self) -> str:
1✔
355
        """Returns the string representation of the InteractionValues object."""
356
        representation = self.__repr__()
1✔
357
        representation = representation[:-2]  # remove the last "\n)" and add values
1✔
358
        _, sorted_top_10_interactions = self.get_top_k(
1✔
359
            10, as_interaction_values=False
360
        )  # get top 10 interactions
361
        # add values to string representation
362
        representation += ",\n    Top 10 interactions:\n"
1✔
363
        for interaction, value in sorted_top_10_interactions:
1✔
364
            representation += f"        {interaction}: {value}\n"
1✔
365
        representation += ")"
1✔
366
        return representation
1✔
367

368
    def __len__(self) -> int:
1✔
369
        """Returns the length of the InteractionValues object."""
370
        return len(self.values)  # it might be better to return the theoretical number of interactions
1✔
371

372
    def __iter__(self) -> np.nditer:
1✔
373
        """Returns an iterator over the values of the InteractionValues object."""
374
        return np.nditer(self.values)
1✔
375

376
    def __getitem__(self, item: int | tuple[int, ...]) -> float:
1✔
377
        """Returns the score for the given interaction.
378

379
        Args:
380
            item: The interaction as a tuple of integers for which to return the score. If ``item`` is
381
                an integer it serves as the index to the values vector.
382

383
        Returns:
384
            The interaction value. If the interaction is not present, zero is returned.
385

386
        """
387
        if isinstance(item, int):
1✔
388
            return float(self.values[item])
1✔
389
        item = tuple(sorted(item))
1✔
390
        try:
1✔
391
            return float(self.interactions[item])
1✔
392
        except KeyError:
1✔
393
            return 0.0
1✔
394

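Tuple lookups are order-insensitive because the key is sorted first. Missing interactions yield `0.0` instead of raising. The stdlib-only sketch below follows the same rules. `get_score` is a hypothetical helper.

```python
from __future__ import annotations


def get_score(interactions: dict[tuple[int, ...], float], item: int | tuple[int, ...]) -> float:
    if isinstance(item, int):
        # An integer is a position in insertion order, like self.values[item].
        return float(list(interactions.values())[item])
    # Tuples are normalized, so (1, 0) and (0, 1) address the same interaction.
    return float(interactions.get(tuple(sorted(item)), 0.0))
```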
395
    def __setitem__(self, item: int | tuple[int, ...], value: float) -> None:
1✔
396
        """Sets the score for the given interaction.
397

398
        Args:
399
            item: The interaction as a tuple of integers for which to set the score. If ``item`` is an
400
                integer it serves as the index to the values vector.
401
            value: The value to set for the interaction.
402

403
        Raises:
404
            KeyError: If the interaction is not found in the InteractionValues object.
405

406
        """
407
        try:
1✔
408
            if isinstance(item, int):
1✔
409
                # dict.items() preserves the order of insertion, so we can use it to set the value
410
                for i, (interaction, _) in enumerate(self.interactions.items()):
1✔
411
                    if i == item:
1✔
412
                        self.interactions[interaction] = value
1✔
413
                        break
1✔
414
            else:
415
                item = tuple(sorted(item))
1✔
416
                if self.interactions[item] is not None:
1✔
417
                    # if the interaction is already present, update its value. Otherwise KeyError is raised
418
                    self.interactions[item] = value
1✔
419
        except Exception as e:
1✔
420
            msg = f"Interaction {item} not found in the InteractionValues. Unable to set a value."
1✔
421
            raise KeyError(msg) from e
1✔
422

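`__setitem__` only updates interactions that already exist. There is one caveat in the code above. For an integer `item` outside the valid range, the loop never matches, so nothing is set and no `KeyError` is raised, despite what the docstring says. The hypothetical stdlib sketch below raises in both cases. `set_score` is an illustrative name.

```python
from __future__ import annotations


def set_score(
    interactions: dict[tuple[int, ...], float], item: int | tuple[int, ...], value: float
) -> None:
    if isinstance(item, int):
        keys = list(interactions)
        # Reject out-of-range positions instead of silently doing nothing.
        if not 0 <= item < len(keys):
            msg = f"Position {item} is out of range for {len(keys)} interactions."
            raise KeyError(msg)
        interactions[keys[item]] = value
        return
    key = tuple(sorted(item))  # same normalization as __getitem__
    if key not in interactions:
        msg = f"Interaction {key} not found. Unable to set a value."
        raise KeyError(msg)
    interactions[key] = value
```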
423
    def __eq__(self, other: object) -> bool:
1✔
424
        """Checks if two InteractionValues objects are equal.
425

426
        Args:
427
            other: The other InteractionValues object.
428

429
        Returns:
430
            True if the two objects are equal, False otherwise.
431

432
        """
433
        if not isinstance(other, InteractionValues):
1✔
434
            msg = "Cannot compare InteractionValues with other types."
1✔
435
            raise TypeError(msg)
1✔
436
        if (
1✔
437
            self.index != other.index
438
            or self.max_order != other.max_order
439
            or self.min_order != other.min_order
440
            or self.n_players != other.n_players
441
            or not np.allclose(self.baseline_value, other.baseline_value)
442
        ):
443
            return False
1✔
444
        if not np.allclose(self.values, other.values):
1✔
445
            return False
1✔
446
        return self.interaction_lookup == other.interaction_lookup
1✔
447

448
    def __ne__(self, other: object) -> bool:
1✔
449
        """Checks if two InteractionValues objects are not equal.
450

451
        Args:
452
            other: The other InteractionValues object.
453

454
        Returns:
455
            True if the two objects are not equal, False otherwise.
456

457
        """
458
        return not self.__eq__(other)
1✔
459

460
    def __hash__(self) -> int:
1✔
461
        """Returns the hash of the InteractionValues object."""
462
        return hash(
1✔
463
            (
464
                self.index,
465
                self.max_order,
466
                self.min_order,
467
                self.n_players,
468
                tuple(self.values.flatten()),
469
            ),
470
        )
471

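`__eq__` compares values with `np.allclose`, whereas `__hash__` hashes the exact values. Two objects that compare equal can therefore still have different hashes. This breaks the usual rule that equal objects hash equally, which matters if `InteractionValues` are used in sets or as dictionary keys. The stdlib illustration below uses `math.isclose` as a stand-in for `np.allclose`, although their default tolerances differ.

```python
import math

a = (0.1 + 0.2,)  # 0.30000000000000004 due to floating-point rounding
b = (0.3,)

# "Equal" under a tolerance-based comparison, as in __eq__ ...
approx_equal = all(math.isclose(x, y) for x, y in zip(a, b))
# ... but the exact tuples, which are what __hash__ hashes, differ.
exact_equal = a == b
```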
472
    def __copy__(self) -> InteractionValues:
1✔
473
        """Returns a copy of the InteractionValues object."""
474
        return InteractionValues(
1✔
475
            values=copy.deepcopy(self.values),
476
            index=self.index,
477
            max_order=self.max_order,
478
            estimated=self.estimated,
479
            estimation_budget=self.estimation_budget,
480
            n_players=self.n_players,
481
            interaction_lookup=copy.deepcopy(self.interaction_lookup),
482
            min_order=self.min_order,
483
            baseline_value=self.baseline_value,
484
        )
485

486
    def __add__(self, other: InteractionValues | float) -> InteractionValues:
1✔
487
        """Adds two InteractionValues objects together or a scalar."""
488
        n_players, min_order, max_order = self.n_players, self.min_order, self.max_order
1✔
489
        if isinstance(other, InteractionValues):
1✔
490
            if self.index != other.index:  # different indices
1✔
491
                msg = (
1✔
492
                    f"The indices of the InteractionValues objects are different: "
493
                    f"{self.index} != {other.index}. Addition might not be meaningful."
494
                )
495
                warn(msg, stacklevel=2)
1✔
496
            if (
1✔
497
                self.interaction_lookup != other.interaction_lookup
498
                or self.n_players != other.n_players
499
                or self.min_order != other.min_order
500
                or self.max_order != other.max_order
501
            ):  # different interactions but addable
502
                added_interactions = self.interactions.copy()
1✔
503
                for interaction in other.interactions:
1✔
504
                    if interaction not in added_interactions:
1✔
505
                        added_interactions[interaction] = other.interactions[interaction]
1✔
506
                    else:
507
                        added_interactions[interaction] += other.interactions[interaction]
1✔
508
                interaction_lookup = {
1✔
509
                    interaction: i for i, interaction in enumerate(added_interactions)
510
                }
511
                # adjust n_players, min_order, and max_order
512
                n_players = max(self.n_players, other.n_players)
1✔
513
                min_order = min(self.min_order, other.min_order)
1✔
514
                max_order = max(self.max_order, other.max_order)
1✔
515
                baseline_value = self.baseline_value + other.baseline_value
1✔
516
            else:  # basic case with same interactions
517
                added_interactions = {
1✔
518
                    interaction: self.interactions[interaction] + other.interactions[interaction]
519
                    for interaction in self.interactions
520
                }
521
                interaction_lookup = self.interaction_lookup
1✔
522
                baseline_value = self.baseline_value + other.baseline_value
1✔
523
        elif isinstance(other, int | float):
1✔
524
            added_interactions = {
1✔
525
                interaction: self.interactions[interaction] + other
526
                for interaction in self.interactions
527
            }
528
            interaction_lookup = self.interaction_lookup.copy()
1✔
529
            baseline_value = self.baseline_value + other
1✔
530
        else:
531
            msg = f"Cannot add InteractionValues with object of type {type(other)}."
1✔
532
            raise TypeError(msg)
1✔
533

534
        return InteractionValues(
1✔
535
            values=added_interactions,
536
            index=self.index,
537
            max_order=max_order,
538
            n_players=n_players,
539
            min_order=min_order,
540
            interaction_lookup=interaction_lookup,
541
            estimated=self.estimated,
542
            estimation_budget=self.estimation_budget,
543
            baseline_value=baseline_value,
544
        )
545

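When the two operands hold different interactions, `__add__` takes the union of both dictionaries and sums the values of interactions present in both. The baselines are added as well. The pure-Python sketch below shows that merge. The helper name is hypothetical.

```python
def add_interactions(
    left: dict[tuple[int, ...], float], right: dict[tuple[int, ...], float]
) -> dict[tuple[int, ...], float]:
    # Start from the left operand so its interactions keep their positions.
    result = dict(left)
    for interaction, value in right.items():
        # Interactions missing on the left start from 0.0 and are appended at the end.
        result[interaction] = result.get(interaction, 0.0) + value
    return result
```

Adding a scalar instead shifts every interaction and the baseline by that scalar.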
546
    def __radd__(self, other: InteractionValues | float) -> InteractionValues:
1✔
547
        """Adds two InteractionValues objects together or a scalar."""
548
        return self.__add__(other)
1✔
549

550
    def __neg__(self) -> InteractionValues:
1✔
551
        """Negates the InteractionValues object."""
552
        return InteractionValues(
1✔
553
            values=-self.values,
554
            index=self.index,
555
            max_order=self.max_order,
556
            n_players=self.n_players,
557
            min_order=self.min_order,
558
            interaction_lookup=self.interaction_lookup,
559
            estimated=self.estimated,
560
            estimation_budget=self.estimation_budget,
561
            baseline_value=-self.baseline_value,
562
        )
563

564
    def __sub__(self, other: InteractionValues | float) -> InteractionValues:
1✔
565
        """Subtracts two InteractionValues objects or a scalar."""
566
        return self.__add__(-other)
1✔
567

568
    def __rsub__(self, other: InteractionValues | float) -> InteractionValues:
1✔
569
        """Subtracts two InteractionValues objects or a scalar."""
570
        return (-self).__add__(other)
1✔
571

572
    def __mul__(self, other: float) -> InteractionValues:
1✔
573
        """Multiplies an InteractionValues object by a scalar."""
574
        interactions = {
1✔
575
            interaction: value * other for interaction, value in self.interactions.items()
576
        }
577
        return InteractionValues(
1✔
578
            values=interactions,
579
            index=self.index,
580
            max_order=self.max_order,
581
            n_players=self.n_players,
582
            min_order=self.min_order,
583
            interaction_lookup=self.interaction_lookup,
584
            estimated=self.estimated,
585
            estimation_budget=self.estimation_budget,
586
            baseline_value=self.baseline_value * other,
587
        )
588

589
    def __rmul__(self, other: float) -> InteractionValues:
1✔
590
        """Multiplies an InteractionValues object by a scalar."""
591
        return self.__mul__(other)
1✔
592

593
    def __abs__(self) -> InteractionValues:
1✔
594
        """Returns the absolute values of the InteractionValues object."""
595
        interactions = {interaction: abs(value) for interaction, value in self.interactions.items()}
1✔
596
        return InteractionValues(
1✔
597
            values=interactions,
598
            index=self.index,
599
            max_order=self.max_order,
600
            n_players=self.n_players,
601
            min_order=self.min_order,
602
            interaction_lookup=self.interaction_lookup,
603
            estimated=self.estimated,
604
            estimation_budget=self.estimation_budget,
605
            baseline_value=self.baseline_value,
606
        )
607

608
    def get_n_order_values(self, order: int) -> np.ndarray:
1✔
609
        """Returns the interaction values of a specific order as a numpy array.
610

611
        Note:
612
            Depending on the order and number of players the resulting array might be sparse and
613
            very large.
614

615
        Args:
616
            order: The order of the interactions to return.
617

618
        Returns:
619
            The interaction values of the specified order as a numpy array of shape ``(n_players,)``
620
            for order ``1`` and ``(n_players, n_players)`` for order ``2``, etc.
621

622
        Raises:
623
            ValueError: If the order is less than ``1``.
624

625
        """
626
        from itertools import permutations
1✔
627

628
        if order < 1:
1✔
629
            msg = "Order must be greater or equal to 1."
1✔
630
            raise ValueError(msg)
1✔
631
        values_shape = tuple([self.n_players] * order)
1✔
632
        values = np.zeros(values_shape, dtype=float)
1✔
633
        for interaction in self.interaction_lookup:
1✔
634
            if len(interaction) != order:
1✔
635
                continue
1✔
636
            # get all orderings of the interaction (e.g. (0, 1) and (1, 0) for interaction (0, 1))
637
            for perm in permutations(interaction):
1✔
638
                values[perm] = self[interaction]
1✔
639

640
        return values
1✔
641

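`get_n_order_values` writes each interaction of the requested order into every permutation of its indices, which yields a symmetric tensor. For order 2 this is an `n_players` by `n_players` matrix. The sketch below builds it with nested lists instead of NumPy. The helper name is hypothetical.

```python
from itertools import permutations


def order_two_matrix(
    interactions: dict[tuple[int, ...], float], n_players: int
) -> list[list[float]]:
    matrix = [[0.0] * n_players for _ in range(n_players)]
    for interaction, value in interactions.items():
        if len(interaction) != 2:
            continue  # only pairwise interactions are placed in the matrix
        # (i, j) and (j, i) both receive the value, so the matrix is symmetric.
        for i, j in permutations(interaction):
            matrix[i][j] = value
    return matrix
```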
642
    def get_n_order(
        self,
        order: int | None = None,
        min_order: int | None = None,
        max_order: int | None = None,
    ) -> InteractionValues:
        """Select interactions of a particular order or range of orders.

        Creates a new InteractionValues object containing only the interactions within the
        specified order range.

        You can specify:
            - `order`: to select interactions of a single specific order (e.g., all pairwise
                interactions).
            - `min_order` and/or `max_order`: to select a range of interaction orders.
            - If `order` and `min_order`/`max_order` are both set, `min_order` and `max_order`
                override the `order` value.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([1, 2, 3, 4, 5, 6, 7]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5, (0, 1, 2): 6},
            ...     index="SII",
            ...     max_order=3,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_n_order(order=1).dict_values
            {(0,): 1.0, (1,): 2.0, (2,): 3.0}
            >>> interaction_values.get_n_order(min_order=1, max_order=2).dict_values
            {(0,): 1.0, (1,): 2.0, (2,): 3.0, (0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0}
            >>> interaction_values.get_n_order(min_order=2).dict_values
            {(0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0, (0, 1, 2): 7.0}

        Args:
            order: The order of the interactions to return. Defaults to ``None``, in which case
                ``min_order`` or ``max_order`` must be set.
            min_order: The minimum order of the interactions to return. Defaults to ``None``, in
                which case it is set to ``order`` if given and to the object's ``min_order``
                otherwise.
            max_order: The maximum order of the interactions to return. Defaults to ``None``, in
                which case it is set to ``order`` if given and to the object's ``max_order``
                otherwise.

        Returns:
            The interaction values within the specified order range.

        Raises:
            ValueError: If all three parameters are ``None`` or if the resulting ``min_order`` is
                greater than ``max_order``.
        """
        if order is None and min_order is None and max_order is None:
            msg = "Either order, min_order or max_order must be set."
            raise ValueError(msg)

        if order is not None:
            max_order = order if max_order is None else max_order
            min_order = order if min_order is None else min_order
        else:  # order is None
            min_order = self.min_order if min_order is None else min_order
            max_order = self.max_order if max_order is None else max_order

        if min_order > max_order:
            msg = f"min_order ({min_order}) must be less than or equal to max_order ({max_order})."
            raise ValueError(msg)

        new_values = []
        new_interaction_lookup = {}
        for interaction in self.interaction_lookup:
            if len(interaction) < min_order or len(interaction) > max_order:
                continue
            interaction_idx = len(new_interaction_lookup)
            new_values.append(self[interaction])
            new_interaction_lookup[interaction] = interaction_idx

        return InteractionValues(
            values=np.array(new_values),
            index=self.index,
            max_order=max_order,
            n_players=self.n_players,
            min_order=min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )

    def get_subset(self, players: list[int]) -> InteractionValues:
        """Selects a subset of players from the InteractionValues object.

        Only interactions whose players are all contained in ``players`` are kept. The selected
        players keep their original indices.

        Args:
            players: List of players to select from the InteractionValues object.

        Returns:
            InteractionValues: Filtered InteractionValues object containing only values related to
            selected players.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
            ...     index="SII",
            ...     max_order=2,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_subset([0, 1]).dict_values
            {(0,): 0.1, (1,): 0.2, (0, 1): 0.4}
            >>> interaction_values.get_subset([0, 2]).dict_values
            {(0,): 0.1, (2,): 0.3, (0, 2): 0.5}
            >>> interaction_values.get_subset([1]).dict_values
            {(1,): 0.2}

        """
        idx, keys_in_subset = [], []
        for key, index in self.interaction_lookup.items():
            if all(p in players for p in key):
                idx.append(index)  # position of the value in ``self.values``
                keys_in_subset.append(key)
        new_values = self.values[idx]
        new_interaction_lookup = {key: index for index, key in enumerate(keys_in_subset)}
        # the subset contains exactly the selected players (their indices are kept as-is)
        n_players = len(players)
        return InteractionValues(
            values=new_values,
            index=self.index,
            max_order=self.max_order,
            n_players=n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )

    def save(self, path: Path, *, as_pickle: bool = False, as_npz: bool = False) -> None:
        """Save the InteractionValues object to a file.

        By default, the InteractionValues object is saved as a JSON file.

        Args:
            path: The path to save the InteractionValues object to.
            as_pickle: Whether to save the InteractionValues object as a pickle file instead of
                JSON. Takes precedence over ``as_npz``. Defaults to ``False``.
            as_npz: Whether to save the InteractionValues object as an ``npz`` file instead of
                JSON. Defaults to ``False``.

        Raises:
            DeprecationWarning: If ``as_pickle`` or ``as_npz`` is set to ``True``, since both
                formats are deprecated in favor of JSON.
        """
        # create the parent directory if it does not exist yet
        directory = Path(path).parent
        if not Path(directory).exists():
            with contextlib.suppress(FileNotFoundError):
                Path(directory).mkdir(parents=True, exist_ok=True)
        if as_pickle:
            raise_deprecation_warning(
                message=SAVE_JSON_DEPRECATION_MSG,
                deprecated_in="1.3.1",
                removed_in="1.4.0",
            )
            with Path(path).open("wb") as file:
                pickle.dump(self, file)
        elif as_npz:
            raise_deprecation_warning(
                message=SAVE_JSON_DEPRECATION_MSG,
                deprecated_in="1.3.1",
                removed_in="1.4.0",
            )
            # save object as npz file; the tuple keys are stored as strings next to their indices
            interaction_keys = np.array(
                list(map(safe_tuple_to_str, self.interaction_lookup.keys()))
            )
            interaction_indices = np.array(list(self.interaction_lookup.values()))
            # -1 encodes a missing estimation budget
            estimation_budget = self.estimation_budget if self.estimation_budget is not None else -1

            np.savez(
                path,
                values=self.values,
                index=self.index,
                max_order=self.max_order,
                n_players=self.n_players,
                min_order=self.min_order,
                interaction_lookup_keys=interaction_keys,
                interaction_lookup_indices=interaction_indices,
                estimated=self.estimated,
                estimation_budget=estimation_budget,
                baseline_value=self.baseline_value,
            )
        else:
            self.to_json_file(path)

    @classmethod
    def load(cls, path: Path | str) -> InteractionValues:
        """Load an InteractionValues object from a file.

        Args:
            path: The path to load the InteractionValues object from. JSON files (``.json``) are
                preferred; ``.npz`` files are still supported but deprecated.

        Returns:
            The loaded InteractionValues object.

        Raises:
            ValueError: If the path ends neither with ``.json`` nor with ``.npz``.

        """
        path = Path(path)
        # JSON is the default format
        if path.name.endswith(".json"):
            return cls.from_json_file(path)

        raise_deprecation_warning(
            SAVE_JSON_DEPRECATION_MSG, deprecated_in="1.3.1", removed_in="1.4.0"
        )

        # try loading as npz file
        if path.name.endswith(".npz"):
            data = np.load(path, allow_pickle=True)
            try:
                # current npz format: string keys and integer indices stored as separate arrays
                interaction_lookup = {
                    safe_str_to_tuple(key): int(value)
                    for key, value in zip(
                        data["interaction_lookup_keys"],
                        data["interaction_lookup_indices"],
                        strict=False,
                    )
                }
            except KeyError:
                # fallback to the old format, which stored the lookup as a pickled dictionary
                interaction_lookup = data["interaction_lookup"].item()
            estimation_budget = data["estimation_budget"].item()
            if estimation_budget == -1:  # -1 encodes a missing estimation budget
                estimation_budget = None
            return InteractionValues(
                values=data["values"],
                index=str(data["index"]),
                max_order=int(data["max_order"]),
                n_players=int(data["n_players"]),
                min_order=int(data["min_order"]),
                interaction_lookup=interaction_lookup,
                estimated=bool(data["estimated"]),
                estimation_budget=estimation_budget,
                baseline_value=float(data["baseline_value"]),
            )
        msg = f"Path {path} does not end with .json or .npz. Cannot load InteractionValues."
        raise ValueError(msg)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InteractionValues:
        """Create an InteractionValues object from a dictionary.

        Args:
            data: The dictionary containing the data to create the InteractionValues object from.

        Returns:
            The InteractionValues object created from the dictionary.

        """
        return cls(
            values=data["values"],
            index=data["index"],
            max_order=data["max_order"],
            n_players=data["n_players"],
            min_order=data["min_order"],
            interaction_lookup=data["interaction_lookup"],
            estimated=data["estimated"],
            estimation_budget=data["estimation_budget"],
            baseline_value=data["baseline_value"],
        )

    def to_dict(self) -> dict:
        """Convert the InteractionValues object to a dictionary.

        Returns:
            The InteractionValues object as a dictionary.

        """
        return {
            "values": self.interactions,
            "index": self.index,
            "max_order": self.max_order,
            "n_players": self.n_players,
            "min_order": self.min_order,
            "interaction_lookup": self.interaction_lookup,
            "estimated": self.estimated,
            "estimation_budget": self.estimation_budget,
            "baseline_value": self.baseline_value,
        }

    def aggregate(
        self,
        others: Sequence[InteractionValues],
        aggregation: str = "mean",
    ) -> InteractionValues:
        """Aggregates this object with other InteractionValues objects.

        Args:
            others: The other InteractionValues objects to aggregate with this one.
            aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
                ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

        Returns:
            The aggregated InteractionValues object.

        Note:
            For documentation on the aggregation methods, see the ``aggregate_interaction_values()``
            function.

        """
        return aggregate_interaction_values([self, *others], aggregation)

    def plot_network(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues on a network graph.

        Note:
            For arguments, see :func:`shapiq.plot.network.network_plot` and
            :func:`shapiq.plot.si_graph.si_graph_plot`.

        Args:
            show: Whether to show the plot. Defaults to ``True``.
            **kwargs: Additional keyword arguments to pass to the plotting function.

        Returns:
            If ``show`` is ``False``, a tuple with the figure and the axis of the plot. Otherwise
            ``None``.

        Raises:
            ValueError: If ``max_order`` is ``1`` or lower, i.e. there are no second-order values.
        """
        from shapiq.plot.network import network_plot

        if self.max_order > 1:
            return network_plot(
                interaction_values=self,
                show=show,
                **kwargs,
            )
        msg = (
            "InteractionValues contains only first-order values, "
            "but the network plot also requires second-order values."
        )
        raise ValueError(msg)

    def plot_si_graph(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues as an SI graph.

        For arguments, see :func:`shapiq.plot.si_graph.si_graph_plot`.

        Returns:
            If ``show`` is ``False``, the SI graph as a tuple containing the figure and the axes.
            Otherwise ``None``.

        """
        from shapiq.plot.si_graph import si_graph_plot

        return si_graph_plot(self, show=show, **kwargs)

    def plot_stacked_bar(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues as a stacked bar plot.

        For arguments, see ``shapiq.stacked_bar_plot()``.

        Returns:
            If ``show`` is ``False``, the stacked bar plot as a tuple containing the figure and the
            axes. Otherwise ``None``.

        """
        from shapiq import stacked_bar_plot

        return stacked_bar_plot(self, show=show, **kwargs)

    def plot_force(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        contribution_threshold: float = 0.05,
    ) -> Figure | None:
        """Visualize InteractionValues on a force plot.

        For further details, see ``shapiq.plot.force_plot()``.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided, the
                feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            contribution_threshold: The threshold for contributions to be displayed, given as a
                fraction (``0.05`` corresponds to 5%). Defaults to ``0.05``.

        Returns:
            The force plot as a matplotlib figure (if ``show`` is ``False``).

        """
        from .plot import force_plot

        return force_plot(
            self,
            feature_names=feature_names,
            show=show,
            abbreviate=abbreviate,
            contribution_threshold=contribution_threshold,
        )

    def plot_waterfall(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        max_display: int = 10,
    ) -> Axes | None:
        """Draws interaction values on a waterfall plot.

        Note:
            Requires the ``shap`` Python package to be installed.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided, the
                feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            max_display: The maximum number of interactions to display. Defaults to ``10``.

        Returns:
            The axes of the waterfall plot if ``show`` is ``False``, otherwise ``None``.
        """
        from shapiq import waterfall_plot

        return waterfall_plot(
            self,
            feature_names=feature_names,
            show=show,
            max_display=max_display,
            abbreviate=abbreviate,
        )

    def plot_sentence(
        self,
        words: list[str],
        *,
        show: bool = True,
        **kwargs: Any,
    ) -> tuple[Figure, Axes] | None:
        """Plots the first-order effects (attributions) of a sentence or paragraph.

        For arguments, see ``shapiq.plot.sentence.sentence_plot()``.

        Returns:
            If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple
            with the figure and the axis of the plot.

        """
        from shapiq.plot.sentence import sentence_plot

        return sentence_plot(self, words, show=show, **kwargs)

    def plot_upset(self, *, show: bool = True, **kwargs: Any) -> Figure | None:
        """Visualize InteractionValues as an UpSet plot.

        For arguments, see ``shapiq.plot.upset.upset_plot()``.

        Returns:
            The upset plot as a matplotlib figure (if ``show`` is ``False``).

        """
        from shapiq.plot.upset import upset_plot

        return upset_plot(self, show=show, **kwargs)

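# Sketch (illustration only, not part of shapiq): a standalone version of the dense-array
# construction performed by the order-``k`` array method of ``InteractionValues`` above.
# Assumptions: the interactions are a plain dict mapping sorted player tuples to values, and
# ``_sketch_order_array`` is a hypothetical helper name. Every ordering of an interaction gets
# the same value, which makes the returned array symmetric.
from itertools import permutations

import numpy as np


def _sketch_order_array(
    interactions: dict[tuple[int, ...], float], n_players: int, order: int
) -> np.ndarray:
    """Return a symmetric array of shape ``(n_players,) * order`` (hypothetical helper)."""
    if order < 1:
        msg = "Order must be greater than or equal to 1."
        raise ValueError(msg)
    values = np.zeros((n_players,) * order, dtype=float)
    for interaction, value in interactions.items():
        if len(interaction) != order:
            continue  # only interactions of exactly ``order`` players are written
        for perm in permutations(interaction):
            values[perm] = value  # e.g. both (0, 1) and (1, 0) for interaction (0, 1)
    return values


# Usage: _sketch_order_array({(0, 1): 4.0, (1, 2): 6.0}, n_players=3, order=2)[1, 0] is 4.0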

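# Sketch (illustration only, not part of shapiq): ``get_n_order`` and ``get_subset`` above both
# keep a subset of the interactions and rebuild a dense lookup (0, 1, 2, ...) for the kept
# values. The hypothetical helper ``_sketch_filter_interactions`` shows that shared pattern on a
# plain NumPy array and lookup dict; it is an assumption-based illustration, not library code.
from collections.abc import Callable

import numpy as np


def _sketch_filter_interactions(
    values: np.ndarray,
    lookup: dict[tuple[int, ...], int],
    keep: Callable[[tuple[int, ...]], bool],
) -> tuple[np.ndarray, dict[tuple[int, ...], int]]:
    """Keep the interactions for which ``keep`` is true and re-index them densely."""
    new_values: list[float] = []
    new_lookup: dict[tuple[int, ...], int] = {}
    for interaction, index in lookup.items():
        if keep(interaction):
            new_lookup[interaction] = len(new_lookup)  # next free position in the new array
            new_values.append(float(values[index]))  # read via the old lookup position
    return np.array(new_values, dtype=float), new_lookup


# Usage, order range as in get_n_order(min_order=2, max_order=2):
#   _sketch_filter_interactions(values, lookup, lambda i: 2 <= len(i) <= 2)
# Usage, player subset as in get_subset([0, 1]):
#   _sketch_filter_interactions(values, lookup, lambda i: all(p in {0, 1} for p in i))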
def aggregate_interaction_values(
    interaction_values: Sequence[InteractionValues],
    aggregation: str = "mean",
) -> InteractionValues:
    """Aggregates InteractionValues objects using a specific aggregation method.

    Args:
        interaction_values: A list of InteractionValues objects to aggregate.
        aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
            ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

    Returns:
        The aggregated InteractionValues object.

    Example:
        >>> iv1 = InteractionValues(
        ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=0.0,
        ... )
        >>> iv2 = InteractionValues(
        ...     values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]),  # this iv is missing the (1, 2) value
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4},  # no (1, 2)
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=1.0,
        ... )
        >>> aggregate_interaction_values([iv1, iv2], "mean")
        InteractionValues(
            index=SII, max_order=2, min_order=1, estimated=True, estimation_budget=None,
            n_players=3, baseline_value=0.5,
            Top 10 interactions:
                (0, 2): 0.55
                (0, 1): 0.45
                (2,): 0.35
                (1, 2): 0.30
                (1,): 0.25
                (0,): 0.15
        )

    Note:
        The index and the estimation budget of the aggregated InteractionValues object are taken
        from the first InteractionValues object in the list. An interaction missing from one of
        the objects (such as ``(1, 2)`` in ``iv2`` above) counts as ``0.0`` in the aggregation.

    Raises:
        ValueError: If the aggregation method is not supported.

    """

    def _aggregate(vals: list[float], method: str) -> float:
        """Does the actual aggregation of the values."""
        if method == "mean":
            return float(np.mean(vals))
        if method == "median":
            return float(np.median(vals))
        if method == "sum":
            return float(np.sum(vals))
        if method == "max":
            return float(np.max(vals))
        if method == "min":
            return float(np.min(vals))
        msg = f"Aggregation method {method} is not supported."
        raise ValueError(msg)

    # get all keys from all InteractionValues objects
    all_keys = set()
    for iv in interaction_values:
        all_keys.update(iv.interaction_lookup.keys())
    all_keys = sorted(all_keys)

    # aggregate the values key by key over all objects
    new_values = np.zeros(len(all_keys), dtype=float)
    new_lookup = {}
    for i, key in enumerate(all_keys):
        new_lookup[key] = i
        values = [iv[key] for iv in interaction_values]
        new_values[i] = _aggregate(values, aggregation)

    max_order = max([iv.max_order for iv in interaction_values])
    min_order = min([iv.min_order for iv in interaction_values])
    n_players = max([iv.n_players for iv in interaction_values])
    baseline_value = _aggregate(
        [float(iv.baseline_value) for iv in interaction_values], aggregation
    )
    estimation_budget = interaction_values[0].estimation_budget

    return InteractionValues(
        values=new_values,
        index=interaction_values[0].index,
        max_order=max_order,
        n_players=n_players,
        min_order=min_order,
        interaction_lookup=new_lookup,
        estimated=True,
        estimation_budget=estimation_budget,
        baseline_value=baseline_value,
    )

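# Sketch (illustration only, not part of shapiq): the core of ``aggregate_interaction_values``
# above on plain dicts. Assumption: an interaction missing from one of the inputs contributes
# ``0.0``, which is what ``iv[key]`` is taken to return for absent interactions.
# ``_sketch_aggregate`` and ``_SKETCH_AGGREGATIONS`` are hypothetical names.
import numpy as np

_SKETCH_AGGREGATIONS = {
    "mean": np.mean,
    "median": np.median,
    "sum": np.sum,
    "max": np.max,
    "min": np.min,
}


def _sketch_aggregate(
    dicts: list[dict[tuple[int, ...], float]], aggregation: str = "mean"
) -> dict[tuple[int, ...], float]:
    """Aggregate interaction dicts key by key, filling missing interactions with ``0.0``."""
    if aggregation not in _SKETCH_AGGREGATIONS:
        msg = f"Aggregation method {aggregation} is not supported."
        raise ValueError(msg)
    func = _SKETCH_AGGREGATIONS[aggregation]
    all_keys = sorted(set().union(*dicts))  # union of all interactions in a stable order
    return {key: float(func([d.get(key, 0.0) for d in dicts])) for key in all_keys}


# Usage: with the two objects from the docstring example, the mean for (1, 2) is
# (0.6 + 0.0) / 2 = 0.3, because the second object does not contain (1, 2).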

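# Sketch (illustration only, not part of shapiq): the deprecated ``npz`` branch of
# ``InteractionValues.save``/``load`` in the class above stores the lookup as two parallel
# arrays, string keys and integer indices, so it can be read back without unpickling a
# dictionary. The encoding used here ("0,1" for (0, 1), "" for the empty interaction) is a
# hypothetical stand-in for ``safe_tuple_to_str``/``safe_str_to_tuple``, whose exact format is
# not shown in this file.
import io

import numpy as np


def _sketch_tuple_to_str(interaction: tuple[int, ...]) -> str:
    return ",".join(str(player) for player in interaction)


def _sketch_str_to_tuple(text: str) -> tuple[int, ...]:
    return tuple(int(player) for player in text.split(",")) if text else ()


def _sketch_lookup_roundtrip(lookup: dict[tuple[int, ...], int]) -> dict[tuple[int, ...], int]:
    """Write ``lookup`` to an in-memory ``npz`` archive and read it back (hypothetical helper)."""
    buffer = io.BytesIO()
    np.savez(
        buffer,
        interaction_lookup_keys=np.array([_sketch_tuple_to_str(key) for key in lookup]),
        interaction_lookup_indices=np.array(list(lookup.values()), dtype=int),
    )
    buffer.seek(0)
    # plain string and integer arrays load without ``allow_pickle=True``
    with np.load(buffer) as data:
        return {
            _sketch_str_to_tuple(str(key)): int(index)
            for key, index in zip(
                data["interaction_lookup_keys"], data["interaction_lookup_indices"], strict=True
            )
        }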
def _validate_and_return_interactions(
    values: np.ndarray | dict[tuple[int, ...], float],
    interaction_lookup: dict[tuple[int, ...], int] | None,
    n_players: int,
    min_order: int,
    max_order: int,
    baseline_value: float | np.number,
) -> dict[tuple[int, ...], float]:
    """Check the interactions for validity and consistency.

    Args:
        values (np.ndarray | dict[tuple[int, ...], float]): The interaction values, either as an
            array indexed via ``interaction_lookup`` or as a mapping from interactions to values.
        interaction_lookup (dict[tuple[int, ...], int] | None): A mapping from interactions to
            their indices. If ``None``, a lookup for all interactions from ``min_order`` to
            ``max_order`` is generated.
        n_players (int): The number of players.
        min_order (int): The minimum order of interactions.
        max_order (int): The maximum order of interactions.
        baseline_value (float | np.number): The baseline value to use for empty interactions.

    Returns:
        A mapping from interactions to their values. If ``min_order`` is ``0`` and the empty
        interaction ``()`` is missing, it is added with the baseline value.

    Raises:
        TypeError: If ``interaction_lookup`` is not a dictionary.
    """
    interactions: dict[tuple[int, ...], float] = {}
    if interaction_lookup is None:
        interaction_lookup = generate_interaction_lookup(
            players=n_players,
            min_order=min_order,
            max_order=max_order,
        )
    if interaction_lookup is not None and not isinstance(interaction_lookup, dict):
        msg = f"Interaction lookup must be a dictionary. Got {type(interaction_lookup)}."
        raise TypeError(msg)

    if isinstance(values, dict):
        interactions = copy.deepcopy(values)
    else:
        interactions = {
            interaction: values[index].item() for interaction, index in interaction_lookup.items()
        }

    if min_order == 0 and () not in interactions:
        interactions[()] = float(baseline_value)
    return interactions

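# Sketch (illustration only, not part of shapiq): the conversion done by
# ``_validate_and_return_interactions`` above. Array values are mapped to interactions through
# the lookup, dict values are copied, and the empty interaction ``()`` receives the baseline when
# ``min_order == 0`` and it is missing. ``_sketch_lookup`` is a hypothetical stand-in for
# ``generate_interaction_lookup``, assumed here to enumerate all player subsets by increasing size.
from itertools import combinations

import numpy as np


def _sketch_lookup(n_players: int, min_order: int, max_order: int) -> dict[tuple[int, ...], int]:
    subsets = (
        subset
        for size in range(min_order, max_order + 1)
        for subset in combinations(range(n_players), size)
    )
    return {subset: position for position, subset in enumerate(subsets)}


def _sketch_to_interactions(
    values: np.ndarray | dict[tuple[int, ...], float],
    lookup: dict[tuple[int, ...], int] | None,
    n_players: int,
    min_order: int,
    max_order: int,
    baseline_value: float,
) -> dict[tuple[int, ...], float]:
    if lookup is None:
        lookup = _sketch_lookup(n_players, min_order, max_order)
    if isinstance(values, dict):
        interactions = dict(values)  # copy so the caller's dict stays unchanged
    else:
        interactions = {interaction: float(values[pos]) for interaction, pos in lookup.items()}
    if min_order == 0 and () not in interactions:
        interactions[()] = float(baseline_value)  # the empty interaction carries the baseline
    return interactions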

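def _update_interactions_for_index(
    interactions: InteractionScores,
    index: str,
    target_index: str,
    max_order: int,
    min_order: int,
    baseline_value: float | np.number,
) -> tuple[InteractionScores, str, int, float]:
    """Adapt the interactions to the target index and reconcile the empty value with the baseline.

    If ``target_index`` is an aggregated index that differs from ``index``, the base attributions
    are aggregated first. Afterwards, if the empty interaction ``()`` is present and differs from
    ``baseline_value`` (for indices other than ``"SII"``), either the empty value is overwritten
    with the baseline (for indices whose empty value is the baseline) or the baseline is set to
    the empty value. If ``()`` is missing and ``min_order`` is ``0``, it is added with the
    baseline value.

    Args:
        interactions: The interaction scores.
        index: The index of the interaction scores.
        target_index: The index to convert the interaction scores to.
        max_order: The maximum order of the interactions.
        min_order: The minimum order of the interactions.
        baseline_value: The baseline value.

    Returns:
        A tuple of the (possibly aggregated) interactions, the resulting index, the resulting
        minimum order, and the resulting baseline value.
    """
    from .game_theory.aggregation import aggregate_base_attributions

    if is_index_aggregated(target_index) and target_index != index:
        interactions, index, min_order = aggregate_base_attributions(
            interactions=interactions,
            index=index,
            order=max_order,
            min_order=min_order,
            baseline_value=float(baseline_value),
        )
    if () in interactions:
        empty_value = interactions[()]
        if empty_value != baseline_value and index != "SII":
            if is_empty_value_the_baseline(index):
                # insert the empty value given in baseline into the values
                interactions[()] = float(baseline_value)
            else:  # manually set baseline to the empty value
                baseline_value = interactions[()]
    elif min_order == 0:
        # TODO(mmshlk): this might not be what we really want to do always: what if empty and baseline are different?
        # https://github.com/mmschlk/shapiq/issues/385
        interactions[()] = float(baseline_value)
    return interactions, index, min_order, float(baseline_value)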
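

# Sketch (illustration only, not part of shapiq): the empty-interaction/baseline reconciliation at
# the end of ``_update_interactions_for_index`` above, on a plain dict. Assumptions: the flag
# ``empty_is_baseline`` stands in for ``is_empty_value_the_baseline(index)``, whose per-index
# rules are not shown in this file, and the aggregation step for aggregated target indices is
# left out. ``_sketch_reconcile_baseline`` is a hypothetical name.
def _sketch_reconcile_baseline(
    interactions: dict[tuple[int, ...], float],
    index: str,
    min_order: int,
    baseline_value: float,
    *,
    empty_is_baseline: bool,
) -> tuple[dict[tuple[int, ...], float], float]:
    interactions = dict(interactions)  # work on a copy
    if () in interactions:
        if interactions[()] != baseline_value and index != "SII":
            if empty_is_baseline:
                interactions[()] = float(baseline_value)  # keep the given baseline
            else:
                baseline_value = interactions[()]  # adopt the empty value as the baseline
    elif min_order == 0:
        interactions[()] = float(baseline_value)  # add the missing empty interaction
    return interactions, float(baseline_value)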