• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18449618793

12 Oct 2025 09:16PM UTC coverage: 93.266% (-0.6%) from 93.845%
18449618793

Pull #430

github

web-flow
Merge 4a26a5ad3 into dede390c9
Pull Request #430: Enhance type safety and fix bugs across the codebase

278 of 326 new or added lines in 46 files covered. (85.28%)

12 existing lines in 9 files now uncovered.

4986 of 5346 relevant lines covered (93.27%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.22
/src/shapiq/interaction_values.py
1
"""InteractionValues data-class, which is used to store the interaction scores."""
2

3
from __future__ import annotations
1✔
4

5
import contextlib
1✔
6
import copy
1✔
7
import json
1✔
8
import pickle
1✔
9
from pathlib import Path
1✔
10
from typing import TYPE_CHECKING
1✔
11
from warnings import warn
1✔
12

13
import numpy as np
1✔
14

15
from .game_theory.indices import (
1✔
16
    ALL_AVAILABLE_INDICES,
17
    get_index_from_computation_index,
18
    is_empty_value_the_baseline,
19
    is_index_aggregated,
20
    is_index_valid,
21
)
22
from .utils.errors import raise_deprecation_warning
1✔
23
from .utils.saving import safe_str_to_tuple, safe_tuple_to_str
1✔
24
from .utils.sets import generate_interaction_lookup
1✔
25

26
if TYPE_CHECKING:
1✔
27
    from collections.abc import Sequence
×
28
    from typing import Any
×
29

30
    from matplotlib.axes import Axes
×
31
    from matplotlib.figure import Figure
×
32

33
    from shapiq.typing import JSONType
×
34

35

36
# Deprecation message for saving InteractionValues in a non-JSON format, i.e. via the
# `as_pickle` or `as_npz` parameters.
SAVE_JSON_DEPRECATION_MSG = (
    "Saving InteractionValues not as a JSON file is deprecated. "
    "The parameters `as_pickle` and `as_npz` will be removed in the future. "
)
40

41

42
class InteractionValues:
1✔
43
    """This class contains the interaction values as estimated by an approximator.
44

45
    Attributes:
46
        values: The interaction values of the model in vectorized form.
47
        index: The interaction index estimated. All available indices are defined in
48
            ``ALL_AVAILABLE_INDICES``.
49
        max_order: The order of the approximation.
50
        n_players: The number of players.
51
        min_order: The minimum order of the approximation. Defaults to ``0``.
52
        interaction_lookup: A dictionary that maps interactions to their index in the values
53
            vector. If ``interaction_lookup`` is not provided, it is computed from the ``n_players``,
54
            ``min_order``, and `max_order` parameters. Defaults to ``None``.
55
        estimated: Whether the interaction values are estimated or not. Defaults to ``True``.
56
        estimation_budget: The budget used for the estimation. Defaults to ``None``.
57
        baseline_value: The value of the baseline interaction also known as 'empty prediction' or
58
            ``'empty value'`` since it denotes the value of the empty coalition (empty set). If not
59
            provided it is searched for in the values vector (raising an Error if not found).
60
            Defaults to ``None``.
61

62
    Raises:
63
        UserWarning: If the index is not a valid index as defined in ``ALL_AVAILABLE_INDICES``.
64
        TypeError: If the baseline value is not a number.
65

66
    """
67

68
    interactions: dict[tuple[int, ...], float]
1✔
69
    """The interactions as a dictionary mapping interactions to their values."""
1✔
70

71
    def __init__(
        self,
        values: np.ndarray | dict[tuple[int, ...], float],
        *,
        index: str,
        max_order: int,
        n_players: int,
        min_order: int,
        interaction_lookup: dict[tuple[int, ...], int] | None = None,  # type: ignore[assignment]
        estimated: bool = True,
        estimation_budget: int | None = None,  # type: ignore[assignment]
        baseline_value: float | np.number = 0.0,
        target_index: str | None = None,
    ) -> None:
        """Initialize the InteractionValues object.

        Args:
            values: The interaction values as a numpy array or a dictionary mapping interactions
                to their values.
            index: The index of the interaction values. This should be one of the indices
                defined in ``ALL_AVAILABLE_INDICES``; other values only trigger a warning.
            max_order: The maximum order of the interactions.
            n_players: The number of players in the game.
            min_order: The minimum order of the interactions. Required keyword argument.
            interaction_lookup: A dictionary mapping interactions to their index in the values
                vector. Defaults to ``None``, which means it will be generated from the
                ``n_players``, ``min_order``, and ``max_order`` parameters.
            estimated: Whether the interaction values are estimated or not. Defaults to ``True``.
            estimation_budget: The budget used for the estimation. Defaults to ``None``.
            baseline_value: The baseline value of the interaction values, also known as the
                empty prediction or empty value. Must be an ``int``, ``float`` or NumPy number.
                Defaults to ``0.0``. The attribute stores the baseline returned by the
                finalization step towards ``target_index``.
            target_index: The index to which the InteractionValues should be finalized.
                Defaults to ``None``, which means that ``target_index = index``.

        Raises:
            TypeError: If ``baseline_value`` is not a number.

        Warns:
            UserWarning: If ``index`` is not a valid interaction index.
        """
        if not isinstance(baseline_value, (int | float | np.number)):
            msg = f"Baseline value must be provided as a number. Got {type(baseline_value)}."
            raise TypeError(msg)
        if not is_index_valid(index, raise_error=False):
            warn(
                f"Index `{index}` is not a valid interaction index. "
                f"Valid indices are: {', '.join(ALL_AVAILABLE_INDICES)}.",
                stacklevel=2,
            )
        index = get_index_from_computation_index(index, max_order)
        if target_index is None:
            target_index = index

        interactions = _validate_and_return_interactions(
            values=values,
            interaction_lookup=interaction_lookup,
            n_players=n_players,
            min_order=min_order,
            max_order=max_order,
            baseline_value=baseline_value,
        )

        interactions, index, min_order, baseline_value = _update_interactions_for_index(
            interactions=interactions,
            index=index,
            target_index=target_index,
            min_order=min_order,
            max_order=max_order,
            baseline_value=baseline_value,
        )

        self.interactions = interactions
        self.index = index
        self.max_order = max_order
        self.n_players = n_players
        self.min_order = min_order
        self.estimated = estimated
        self.estimation_budget = estimation_budget
        # Store the baseline returned by the finalization helper together with the interactions
        # and index it returned. Keeping the value passed in (as before) discarded that result
        # and could leave ``baseline_value`` out of sync with ``interactions``/``index``.
        self.baseline_value = baseline_value
143

144
    @property
1✔
145
    def dict_values(self) -> dict[tuple[int, ...], float]:
1✔
146
        """Getter for the dict directly mapping from all interactions to scores."""
147
        return self.interactions
1✔
148

149
    @property
1✔
150
    def values(self) -> np.ndarray:
1✔
151
        """Getter for the values of the InteractionValues object.
152

153
        Returns:
154
            The values of the InteractionValues object as a numpy array.
155

156
        """
157
        return np.array(list(self.interactions.values()))
1✔
158

159
    @property
1✔
160
    def interaction_lookup(self) -> dict[tuple[int, ...], int]:
1✔
161
        """Getter for the interaction lookup of the InteractionValues object.
162

163
        Returns:
164
            The interaction lookup of the InteractionValues object as a dictionary mapping interactions
165
            to their index in the values vector.
166

167
        """
168
        return {
1✔
169
            interaction: index for index, (interaction, _) in enumerate(self.interactions.items())
170
        }
171

172
    def to_json_file(
        self,
        path: Path,
        *,
        desc: str | None = None,
        created_from: object | None = None,
        **kwargs: JSONType,
    ) -> None:
        """Write this object to ``path`` as a JSON document.

        The document consists of the file metadata produced by ``make_file_metadata``, a
        ``"metadata"`` section with the attributes needed to rebuild the object, and a
        ``"data"`` section with the interactions serialized by ``interactions_to_dict``.

        Args:
            path: Destination of the JSON file.
            desc: Optional free-text description stored with the file. Defaults to ``None``.
            created_from: Optional object the values were created from; forwarded to
                ``make_file_metadata``. Defaults to ``None``.
            **kwargs: Additional parameters stored in the file metadata (forwarded to
                ``make_file_metadata`` as ``parameters``).
        """
        from shapiq.utils.saving import interactions_to_dict, make_file_metadata, save_json

        # NOTE(review): ``baseline_value`` may be a NumPy scalar; whether ``save_json`` can
        # serialize every NumPy number type is not visible here — confirm.
        object_metadata = {
            "n_players": self.n_players,
            "index": self.index,
            "max_order": self.max_order,
            "min_order": self.min_order,
            "estimated": self.estimated,
            "estimation_budget": self.estimation_budget,
            "baseline_value": self.baseline_value,
        }
        document = {
            **make_file_metadata(
                object_to_store=self,
                data_type="interaction_values",
                desc=desc,
                created_from=created_from,
                parameters=kwargs,
            ),
            "metadata": object_metadata,
            "data": interactions_to_dict(interactions=self.dict_values),
        }
        save_json(document, path)
216

217
    @classmethod
1✔
218
    def from_json_file(cls, path: Path) -> InteractionValues:
1✔
219
        """Loads an InteractionValues object from a JSON file.
220

221
        Args:
222
            path: The path to the JSON file. Note that the path must end with `'.json'`.
223

224
        Returns:
225
            The InteractionValues object loaded from the JSON file.
226

227
        Raises:
228
            ValueError: If the path does not end with `'.json'`.
229
        """
230
        from shapiq.utils.saving import dict_to_lookup_and_values
1✔
231

232
        if not path.name.endswith(".json"):
1✔
233
            msg = f"Path {path} does not end with .json. Cannot load InteractionValues."
×
234
            raise ValueError(msg)
×
235

236
        with path.open("r", encoding="utf-8") as file:
1✔
237
            json_data = json.load(file)
1✔
238

239
        metadata = json_data["metadata"]
1✔
240
        interaction_dict = json_data["data"]
1✔
241
        interaction_lookup, values = dict_to_lookup_and_values(interaction_dict)
1✔
242

243
        return cls(
1✔
244
            values=values,
245
            index=metadata["index"],
246
            max_order=metadata["max_order"],
247
            n_players=metadata["n_players"],
248
            min_order=metadata["min_order"],
249
            interaction_lookup=interaction_lookup,
250
            estimated=metadata["estimated"],
251
            estimation_budget=metadata["estimation_budget"],
252
            baseline_value=metadata["baseline_value"],
253
        )
254

255
    def sparsify(self, threshold: float = 1e-3) -> None:
1✔
256
        """Manually sets values close to zero actually to zero (removing values).
257

258
        Args:
259
            threshold: The threshold value below which interactions are zeroed out. Defaults to
260
                1e-3.
261

262
        """
263
        # find interactions to remove in self.interactions
264
        sparse_interactions = copy.deepcopy(self.interactions)
1✔
265
        for interaction, value in self.interactions.items():
1✔
266
            if np.abs(value) < threshold:
1✔
267
                del sparse_interactions[interaction]
1✔
268
        self.interactions = sparse_interactions
1✔
269

270
    def get_top_k_interactions(self, k: int) -> InteractionValues:
        """Returns the top k interactions, ranked by absolute score.

        The returned object keeps the relative order the selected interactions have in this
        object and copies the index, orders, number of players, estimation flags and baseline
        value.

        Args:
            k: The number of top interactions to return. If ``k`` exceeds the number of stored
                interactions, all of them are returned.

        Returns:
            The top k interactions as an InteractionValues object.

        Raises:
            ValueError: If ``k`` is negative.

        """
        if k < 0:
            msg = f"k must be non-negative. Got {k}."
            raise ValueError(msg)
        # ``self.values`` rebuilds the array on every access, so read it exactly once.
        scores = self.values
        # Clamp k: allocating k slots for fewer interactions left zero-padded values that did
        # not match the lookup passed to the constructor.
        n_selected = min(k, len(scores))
        top_k_positions = set(np.argsort(np.abs(scores))[::-1][:n_selected].tolist())
        new_values = np.zeros(n_selected, dtype=float)
        new_interaction_lookup: dict[tuple[int, ...], int] = {}
        for position, interaction in enumerate(self.interactions):
            if position in top_k_positions:
                new_position = len(new_interaction_lookup)
                new_values[new_position] = float(scores[position])
                new_interaction_lookup[interaction] = new_position
        return InteractionValues(
            values=new_values,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
299

300
    def get_top_k(
1✔
301
        self, k: int, *, as_interaction_values: bool = True
302
    ) -> InteractionValues | tuple[dict, list[tuple]]:
303
        """Returns the top k interactions.
304

305
        Args:
306
            k: The number of top interactions to return.
307
            as_interaction_values: Whether to return the top `k` interactions as an InteractionValues
308
                object. Defaults to ``False``.
309

310
        Returns:
311
            The top k interactions as a dictionary and a sorted list of tuples.
312

313
        Examples:
314
            >>> interaction_values = InteractionValues(
315
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
316
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
317
            ...     index="SII",
318
            ...     max_order=2,
319
            ...     n_players=3,
320
            ...     min_order=1,
321
            ...     baseline_value=0.0,
322
            ... )
323
            >>> top_k_interactions, sorted_top_k_interactions = interaction_values.get_top_k(2, False)
324
            >>> top_k_interactions
325
            {(0, 2): 0.5, (1, 0): 0.6}
326
            >>> sorted_top_k_interactions
327
            [((1, 0), 0.6), ((0, 2), 0.5)]
328

329
        """
330
        if as_interaction_values:
1✔
UNCOV
331
            return self.get_top_k_interactions(k)
×
332
        top_k_indices = np.argsort(np.abs(self.values))[::-1][:k]
1✔
333
        top_k_interactions = {}
1✔
334
        for interaction, index in self.interaction_lookup.items():
1✔
335
            if index in top_k_indices:
1✔
336
                top_k_interactions[interaction] = self.values[index]
1✔
337
        sorted_top_k_interactions = [
1✔
338
            (interaction, top_k_interactions[interaction])
339
            for interaction in sorted(
340
                top_k_interactions, key=lambda x: top_k_interactions[x], reverse=True
341
            )
342
        ]
343
        return top_k_interactions, sorted_top_k_interactions
1✔
344

345
    def __repr__(self) -> str:
1✔
346
        """Returns the representation of the InteractionValues object."""
347
        representation = "InteractionValues(\n"
1✔
348
        representation += (
1✔
349
            f"    index={self.index}, max_order={self.max_order}, min_order={self.min_order}"
350
            f", estimated={self.estimated}, estimation_budget={self.estimation_budget},\n"
351
            f"    n_players={self.n_players}, baseline_value={self.baseline_value}\n)"
352
        )
353
        return representation
1✔
354

355
    def __str__(self) -> str:
1✔
356
        """Returns the string representation of the InteractionValues object."""
357
        representation = self.__repr__()
1✔
358
        representation = representation[:-2]  # remove the last "\n)" and add values
1✔
359
        _, sorted_top_10_interactions = self.get_top_k(
1✔
360
            10, as_interaction_values=False
361
        )  # get top 10 interactions
362
        # add values to string representation
363
        representation += ",\n    Top 10 interactions:\n"
1✔
364
        for interaction, value in sorted_top_10_interactions:
1✔
365
            representation += f"        {interaction}: {value}\n"
1✔
366
        representation += ")"
1✔
367
        return representation
1✔
368

369
    def __len__(self) -> int:
1✔
370
        """Returns the length of the InteractionValues object."""
371
        return len(self.values)  # might better to return the theoretical no. of interactions
1✔
372

373
    def __iter__(self) -> np.nditer:
1✔
374
        """Returns an iterator over the values of the InteractionValues object."""
375
        return np.nditer(self.values)
1✔
376

377
    def __getitem__(self, item: int | tuple[int, ...]) -> float:
1✔
378
        """Returns the score for the given interaction.
379

380
        Args:
381
            item: The interaction as a tuple of integers for which to return the score. If ``item`` is
382
                an integer it serves as the index to the values vector.
383

384
        Returns:
385
            The interaction value. If the interaction is not present zero is returned.
386

387
        """
388
        if isinstance(item, int):
1✔
389
            return float(self.values[item])
1✔
390
        item = tuple(sorted(item))
1✔
391
        try:
1✔
392
            return float(self.interactions[item])
1✔
393
        except KeyError:
1✔
394
            return 0.0
1✔
395

396
    def __setitem__(self, item: int | tuple[int, ...], value: float) -> None:
1✔
397
        """Sets the score for the given interaction.
398

399
        Args:
400
            item: The interaction as a tuple of integers for which to set the score. If ``item`` is an
401
                integer it serves as the index to the values vector.
402
            value: The value to set for the interaction.
403

404
        Raises:
405
            KeyError: If the interaction is not found in the InteractionValues object.
406

407
        """
408
        try:
1✔
409
            if isinstance(item, int):
1✔
410
                # dict.items() preserves the order of insertion, so we can use it to set the value
411
                for i, (interaction, _) in enumerate(self.interactions.items()):
1✔
412
                    if i == item:
1✔
413
                        self.interactions[interaction] = value
1✔
414
                        break
1✔
415
            else:
416
                item = tuple(sorted(item))
1✔
417
                if self.interactions[item] is not None:
1✔
418
                    # if the interaction is already present, update its value. Otherwise KeyError is raised
419
                    self.interactions[item] = value
1✔
420
        except Exception as e:
1✔
421
            msg = f"Interaction {item} not found in the InteractionValues. Unable to set a value."
1✔
422
            raise KeyError(msg) from e
1✔
423

424
    def __eq__(self, other: object) -> bool:
1✔
425
        """Checks if two InteractionValues objects are equal.
426

427
        Args:
428
            other: The other InteractionValues object.
429

430
        Returns:
431
            True if the two objects are equal, False otherwise.
432

433
        """
434
        if not isinstance(other, InteractionValues):
1✔
435
            msg = "Cannot compare InteractionValues with other types."
1✔
436
            raise TypeError(msg)
1✔
437
        if (
1✔
438
            self.index != other.index
439
            or self.max_order != other.max_order
440
            or self.min_order != other.min_order
441
            or self.n_players != other.n_players
442
            or not np.allclose(self.baseline_value, other.baseline_value)
443
        ):
444
            return False
1✔
445
        if not np.allclose(self.values, other.values):
1✔
446
            return False
1✔
447
        return self.interaction_lookup == other.interaction_lookup
1✔
448

449
    def __ne__(self, other: object) -> bool:
1✔
450
        """Checks if two InteractionValues objects are not equal.
451

452
        Args:
453
            other: The other InteractionValues object.
454

455
        Returns:
456
            True if the two objects are not equal, False otherwise.
457

458
        """
459
        return not self.__eq__(other)
1✔
460

461
    def __hash__(self) -> int:
1✔
462
        """Returns the hash of the InteractionValues object."""
463
        return hash(
1✔
464
            (
465
                self.index,
466
                self.max_order,
467
                self.min_order,
468
                self.n_players,
469
                tuple(self.values.flatten()),
470
            ),
471
        )
472

473
    def __copy__(self) -> InteractionValues:
        """Return a copy of this object, built through the regular constructor.

        The scores array and the interaction lookup handed to the constructor are deep copies;
        all other attributes are passed through as they are.
        """
        constructor_kwargs = {
            "index": self.index,
            "max_order": self.max_order,
            "estimated": self.estimated,
            "estimation_budget": self.estimation_budget,
            "n_players": self.n_players,
            "interaction_lookup": copy.deepcopy(self.interaction_lookup),
            "min_order": self.min_order,
            "baseline_value": self.baseline_value,
        }
        copied_scores = copy.deepcopy(self.values)
        return InteractionValues(values=copied_scores, **constructor_kwargs)
486

487
    def __add__(self, other: InteractionValues | float) -> InteractionValues:
1✔
488
        """Adds two InteractionValues objects together or a scalar."""
489
        n_players, min_order, max_order = self.n_players, self.min_order, self.max_order
1✔
490
        if isinstance(other, InteractionValues):
1✔
491
            if self.index != other.index:  # different indices
1✔
492
                msg = (
1✔
493
                    f"The indices of the InteractionValues objects are different: "
494
                    f"{self.index} != {other.index}. Addition might not be meaningful."
495
                )
496
                warn(msg, stacklevel=2)
1✔
497
            if (
1✔
498
                self.interaction_lookup != other.interaction_lookup
499
                or self.n_players != other.n_players
500
                or self.min_order != other.min_order
501
                or self.max_order != other.max_order
502
            ):  # different interactions but addable
503
                added_interactions = self.interactions.copy()
1✔
504
                for interaction in other.interactions:
1✔
505
                    if interaction not in added_interactions:
1✔
506
                        added_interactions[interaction] = other.interactions[interaction]
1✔
507
                    else:
508
                        added_interactions[interaction] += other.interactions[interaction]
1✔
509
                interaction_lookup = {
1✔
510
                    interaction: i for i, interaction in enumerate(added_interactions)
511
                }
512
                # adjust n_players, min_order, and max_order
513
                n_players = max(self.n_players, other.n_players)
1✔
514
                min_order = min(self.min_order, other.min_order)
1✔
515
                max_order = max(self.max_order, other.max_order)
1✔
516
                baseline_value = self.baseline_value + other.baseline_value
1✔
517
            else:  # basic case with same interactions
518
                added_interactions = {
1✔
519
                    interaction: self.interactions[interaction] + other.interactions[interaction]
520
                    for interaction in self.interactions
521
                }
522
                interaction_lookup = self.interaction_lookup
1✔
523
                baseline_value = self.baseline_value + other.baseline_value
1✔
524
        elif isinstance(other, int | float):
1✔
525
            added_interactions = {
1✔
526
                interaction: self.interactions[interaction] + other
527
                for interaction in self.interactions
528
            }
529
            interaction_lookup = self.interaction_lookup.copy()
1✔
530
            baseline_value = self.baseline_value + other
1✔
531
        else:
532
            msg = f"Cannot add InteractionValues with object of type {type(other)}."
1✔
533
            raise TypeError(msg)
1✔
534

535
        return InteractionValues(
1✔
536
            values=added_interactions,
537
            index=self.index,
538
            max_order=max_order,
539
            n_players=n_players,
540
            min_order=min_order,
541
            interaction_lookup=interaction_lookup,
542
            estimated=self.estimated,
543
            estimation_budget=self.estimation_budget,
544
            baseline_value=baseline_value,
545
        )
546

547
    def __radd__(self, other: InteractionValues | float) -> InteractionValues:
1✔
548
        """Adds two InteractionValues objects together or a scalar."""
549
        return self.__add__(other)
1✔
550

551
    def __neg__(self) -> InteractionValues:
        """Return a new object in which every score and the baseline value are negated."""
        negated_scores = -self.values
        negated_baseline = -self.baseline_value
        return InteractionValues(
            negated_scores,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=negated_baseline,
        )
564

565
    def __sub__(self, other: InteractionValues | float) -> InteractionValues:
1✔
566
        """Subtracts two InteractionValues objects or a scalar."""
567
        return self.__add__(-other)
1✔
568

569
    def __rsub__(self, other: InteractionValues | float) -> InteractionValues:
        """Compute ``other - self`` as ``(-self) + other``."""
        negated_self = -self
        return negated_self.__add__(other)
572

573
    def __mul__(self, other: float) -> InteractionValues:
1✔
574
        """Multiplies an InteractionValues object by a scalar."""
575
        interactions = {
1✔
576
            interaction: value * other for interaction, value in self.interactions.items()
577
        }
578
        return InteractionValues(
1✔
579
            values=interactions,
580
            index=self.index,
581
            max_order=self.max_order,
582
            n_players=self.n_players,
583
            min_order=self.min_order,
584
            interaction_lookup=self.interaction_lookup,
585
            estimated=self.estimated,
586
            estimation_budget=self.estimation_budget,
587
            baseline_value=self.baseline_value * other,
588
        )
589

590
    def __rmul__(self, other: float) -> InteractionValues:
1✔
591
        """Multiplies an InteractionValues object by a scalar."""
592
        return self.__mul__(other)
1✔
593

594
    def __abs__(self) -> InteractionValues:
        """Return a new object holding the absolute value of every stored score.

        Note:
            Only the stored scores are made absolute; ``baseline_value`` is passed through
            unchanged.
        """
        absolute_scores = {}
        for interaction, score in self.interactions.items():
            absolute_scores[interaction] = abs(score)
        return InteractionValues(
            values=absolute_scores,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
608

609
    def get_n_order_values(self, order: int) -> np.ndarray:
1✔
610
        """Returns the interaction values of a specific order as a numpy array.
611

612
        Note:
613
            Depending on the order and number of players the resulting array might be sparse and
614
            very large.
615

616
        Args:
617
            order: The order of the interactions to return.
618

619
        Returns:
620
            The interaction values of the specified order as a numpy array of shape ``(n_players,)``
621
            for order ``1`` and ``(n_players, n_players)`` for order ``2``, etc.
622

623
        Raises:
624
            ValueError: If the order is less than ``1``.
625

626
        """
627
        from itertools import permutations
1✔
628

629
        if order < 1:
1✔
630
            msg = "Order must be greater or equal to 1."
1✔
631
            raise ValueError(msg)
1✔
632
        values_shape = tuple([self.n_players] * order)
1✔
633
        values = np.zeros(values_shape, dtype=float)
1✔
634
        for interaction in self.interaction_lookup:
1✔
635
            if len(interaction) != order:
1✔
636
                continue
1✔
637
            # get all orderings of the interaction (e.g. (0, 1) and (1, 0) for interaction (0, 1))
638
            for perm in permutations(interaction):
1✔
639
                values[perm] = self[interaction]
1✔
640

641
        return values
1✔
642

643
    def get_n_order(
1✔
644
        self,
645
        order: int | None = None,
646
        min_order: int | None = None,
647
        max_order: int | None = None,
648
    ) -> InteractionValues:
649
        """Select particular order of interactions.
650

651
        Creates a new InteractionValues object containing only the interactions within the
652
        specified order range.
653

654
        You can specify:
655
            - `order`: to select interactions of a single specific order (e.g., all pairwise
656
                interactions).
657
            - `min_order` and/or `max_order`: to select a range of interaction orders.
658
            - If `order` and `min_order`/`max_order` are both set, `min_order` and `max_order` will
659
                override the `order` value.
660

661
        Example:
662
            >>> interaction_values = InteractionValues(
663
            ...     values=np.array([1, 2, 3, 4, 5, 6, 7]),
664
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5, (0, 1, 2): 6},
665
            ...     index="SII",
666
            ...     max_order=3,
667
            ...     n_players=3,
668
            ...     min_order=1,
669
            ...     baseline_value=0.0,
670
            ... )
671
            >>> interaction_values.get_n_order(order=1).dict_values
672
            {(0,): 1.0, (1,): 2.0, (2,): 3.0}
673
            >>> interaction_values.get_n_order(min_order=1, max_order=2).dict_values
674
            {(0,): 1.0, (1,): 2.0, (2,): 3.0, (0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0}
675
            >>> interaction_values.get_n_order(min_order=2).dict_values
676
            {(0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0, (0, 1, 2): 7.0}
677

678
        Args:
679
            order: The order of the interactions to return. Defaults to ``None`` which requires
680
                ``min_order`` or ``max_order`` to be set.
681
            min_order: The minimum order of the interactions to return. Defaults to ``None`` which
682
                sets it to the order.
683
            max_order: The maximum order of the interactions to return. Defaults to ``None`` which
684
                sets it to the order.
685

686
        Returns:
687
            The interaction values of the specified order.
688

689
        Raises:
690
            ValueError: If all three parameters are set to ``None``.
691
        """
692
        if order is None and min_order is None and max_order is None:
1✔
693
            msg = "Either order, min_order or max_order must be set."
1✔
694
            raise ValueError(msg)
1✔
695

696
        if order is not None:
1✔
697
            max_order = order if max_order is None else max_order
1✔
698
            min_order = order if min_order is None else min_order
1✔
699
        else:  # order is None
700
            min_order = self.min_order if min_order is None else min_order
1✔
701
            max_order = self.max_order if max_order is None else max_order
1✔
702

703
        if min_order > max_order:
1✔
704
            msg = f"min_order ({min_order}) must be less than or equal to max_order ({max_order})."
1✔
705
            raise ValueError(msg)
1✔
706

707
        new_values = []
1✔
708
        new_interaction_lookup = {}
1✔
709
        for interaction in self.interaction_lookup:
1✔
710
            if len(interaction) < min_order or len(interaction) > max_order:
1✔
711
                continue
1✔
712
            interaction_idx = len(new_interaction_lookup)
1✔
713
            new_values.append(self[interaction])
1✔
714
            new_interaction_lookup[interaction] = interaction_idx
1✔
715

716
        return InteractionValues(
1✔
717
            values=np.array(new_values),
718
            index=self.index,
719
            max_order=max_order,
720
            n_players=self.n_players,
721
            min_order=min_order,
722
            interaction_lookup=new_interaction_lookup,
723
            estimated=self.estimated,
724
            estimation_budget=self.estimation_budget,
725
            baseline_value=self.baseline_value,
726
        )
727

728
    def get_subset(self, players: list[int]) -> InteractionValues:
        """Selects a subset of players from the InteractionValues object.

        Only interactions whose players are all contained in ``players`` are kept. Player indices
        inside the kept interactions are not re-indexed.

        Args:
            players: List of players to select from the InteractionValues object.

        Returns:
            InteractionValues: Filtered InteractionValues object containing only values related to
            selected players.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
            ...     index="SII",
            ...     max_order=2,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_subset([0, 1]).dict_values
            {(0,): 0.1, (1,): 0.2, (0, 1): 0.4}
            >>> interaction_values.get_subset([0, 2]).dict_values
            {(0,): 0.1, (2,): 0.3, (0, 2): 0.5}
            >>> interaction_values.get_subset([1]).dict_values
            {(1,): 0.2}

        """
        selected = set(players)  # build once: O(1) membership tests per player
        new_values = []
        new_interaction_lookup = {}
        for interaction in self.interaction_lookup:
            if not selected.issuperset(interaction):
                continue
            new_interaction_lookup[interaction] = len(new_interaction_lookup)
            # Read the value through the lookup (as ``get_n_order`` does) so the result does not
            # depend on the lookup's insertion order matching positions in ``self.values``.
            new_values.append(self[interaction])
        # NOTE(review): this subtracts the number of selected players from ``n_players``. It is kept
        # for backward compatibility; confirm whether ``len(players)`` or ``self.n_players`` is meant.
        n_players = self.n_players - len(players)
        return InteractionValues(
            values=np.array(new_values),
            index=self.index,
            max_order=self.max_order,
            n_players=n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
776

777
    def save(self, path: Path | str, *, as_pickle: bool = False, as_npz: bool = False) -> None:
        """Save the InteractionValues object to a file.

        By default, the InteractionValues object is saved as a JSON file (see ``to_json_file``).
        Missing parent directories of ``path`` are created first.

        Args:
            path: The path to save the InteractionValues object to. Both ``str`` and
                :class:`pathlib.Path` are accepted.
            as_pickle: Whether to save the InteractionValues object as a pickle file (``True``).
                Deprecated. Takes precedence over ``as_npz``.
            as_npz: Whether to save the InteractionValues object as a ``npz`` file (``True``).
                Deprecated. ``np.savez`` appends ``.npz`` to the file name if ``path`` does not
                already end with it.

        Raises:
            DeprecationWarning: If `as_pickle` or `as_npz` is set to ``True``, a deprecation
                warning is issued through ``raise_deprecation_warning``.
        """
        path = Path(path)
        directory = path.parent
        if not directory.exists():
            # best effort: if the directory cannot be created, the write below reports the problem
            with contextlib.suppress(FileNotFoundError):
                directory.mkdir(parents=True, exist_ok=True)

        if as_pickle:
            raise_deprecation_warning(
                message=SAVE_JSON_DEPRECATION_MSG,
                deprecated_in="1.3.1",
                removed_in="1.4.0",
            )
            with path.open("wb") as file:
                pickle.dump(self, file)
        elif as_npz:
            raise_deprecation_warning(
                message=SAVE_JSON_DEPRECATION_MSG,
                deprecated_in="1.3.1",
                removed_in="1.4.0",
            )
            # tuple keys are stored as strings next to their indices (``load`` reverses this)
            interaction_keys = np.array(
                [safe_tuple_to_str(key) for key in self.interaction_lookup]
            )
            interaction_indices = np.array(list(self.interaction_lookup.values()))
            # -1 encodes a missing budget; ``load`` maps it back to ``None``
            estimation_budget = -1 if self.estimation_budget is None else self.estimation_budget

            np.savez(
                path,
                values=self.values,
                index=self.index,
                max_order=self.max_order,
                n_players=self.n_players,
                min_order=self.min_order,
                interaction_lookup_keys=interaction_keys,
                interaction_lookup_indices=interaction_indices,
                estimated=self.estimated,
                estimation_budget=estimation_budget,
                baseline_value=self.baseline_value,
            )
        else:
            self.to_json_file(path)
1✔
832

833
    @classmethod
    def load(cls, path: Path | str) -> InteractionValues:
        """Load an InteractionValues object from a file.

        ``.json`` files are read via ``from_json_file``. The deprecated ``.npz`` format is still
        readable. Both its current layout (separate key and index arrays) and the older layout
        (a stored ``interaction_lookup`` object) are supported.

        Args:
            path: The path to load the InteractionValues object from.

        Returns:
            The loaded InteractionValues object.

        Raises:
            ValueError: If ``path`` ends with neither ``.json`` nor ``.npz``.

        """
        path = Path(path)
        if path.name.endswith(".json"):
            return cls.from_json_file(path)

        raise_deprecation_warning(
            SAVE_JSON_DEPRECATION_MSG, deprecated_in="1.3.1", removed_in="1.4.0"
        )

        if path.name.endswith(".npz"):
            # NOTE(review): allow_pickle=True is needed for the legacy layout below, but unpickling
            # can execute arbitrary code. Only load .npz files from trusted sources.
            # The context manager closes the underlying zip archive once the arrays are read.
            with np.load(path, allow_pickle=True) as data:
                try:
                    # current npz layout: keys serialized as strings plus their indices
                    interaction_lookup = {
                        safe_str_to_tuple(key): int(value)
                        for key, value in zip(
                            data["interaction_lookup_keys"],
                            data["interaction_lookup_indices"],
                            strict=False,
                        )
                    }
                except KeyError:
                    # fallback to old format
                    interaction_lookup = data["interaction_lookup"].item()
                estimation_budget = data["estimation_budget"].item()
                if estimation_budget == -1:  # -1 is how ``save`` encodes a missing budget
                    estimation_budget = None
                return cls(
                    values=data["values"],
                    index=str(data["index"]),
                    max_order=int(data["max_order"]),
                    n_players=int(data["n_players"]),
                    min_order=int(data["min_order"]),
                    interaction_lookup=interaction_lookup,
                    estimated=bool(data["estimated"]),
                    estimation_budget=estimation_budget,
                    baseline_value=float(data["baseline_value"]),
                )
        msg = f"Path {path} does not end with .json or .npz. Cannot load InteractionValues."
        raise ValueError(msg)
×
885

886
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InteractionValues:
        """Build an InteractionValues object from a dictionary, such as the one ``to_dict`` returns.

        Args:
            data: Mapping that must provide the keys ``values``, ``index``, ``max_order``,
                ``n_players``, ``min_order``, ``interaction_lookup``, ``estimated``,
                ``estimation_budget`` and ``baseline_value``. Any other keys are ignored.

        Returns:
            A new object of type ``cls`` built from the entries of ``data``.

        Raises:
            KeyError: If one of the required keys is missing from ``data``.

        """
        field_names = (
            "values",
            "index",
            "max_order",
            "n_players",
            "min_order",
            "interaction_lookup",
            "estimated",
            "estimation_budget",
            "baseline_value",
        )
        constructor_kwargs = {name: data[name] for name in field_names}
        return cls(**constructor_kwargs)
908

909
    def to_dict(self) -> dict:
        """Collect the object's constructor arguments into a plain dictionary.

        The ``"values"`` entry holds ``self.interactions`` (not ``self.values``). The remaining
        entries mirror the attributes of the same name. The result can be passed to ``from_dict``.

        Returns:
            A dictionary with the keys ``values``, ``index``, ``max_order``, ``n_players``,
            ``min_order``, ``interaction_lookup``, ``estimated``, ``estimation_budget`` and
            ``baseline_value``.

        """
        return dict(
            values=self.interactions,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
927

928
    def aggregate(
        self,
        others: Sequence[InteractionValues],
        aggregation: str = "mean",
    ) -> InteractionValues:
        """Aggregate this object together with ``others`` into a single InteractionValues object.

        Args:
            others: Further InteractionValues objects to aggregate together with ``self``.
            aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
                ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

        Returns:
            The aggregated InteractionValues object.

        Note:
            ``self`` is placed first, so its index is used for the result. For details on the
            aggregation methods, see the ``aggregate_interaction_values()`` function.

        """
        combined = [self]
        combined.extend(others)
        return aggregate_interaction_values(combined, aggregation)
1✔
949

950
    def plot_network(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues on a network graph.

        Note:
            For the supported keyword arguments, see :func:`shapiq.plot.network.network_plot`.

        Args:
            show: Whether to show the plot. Defaults to ``True``.

            **kwargs: Additional keyword arguments to pass to the plotting function.

        Returns:
            If show is ``False``, the function returns a tuple with the figure and the axis of the
                plot.

        Raises:
            ValueError: If ``max_order`` is not greater than ``1``, i.e. no second-order
                interactions are available to draw as edges.
        """
        from shapiq.plot.network import network_plot

        if self.max_order <= 1:
            # the missing space after the comma previously produced "values,but"
            msg = (
                "InteractionValues contains only 1-order values, "
                "but requires also 2-order values for the network plot."
            )
            raise ValueError(msg)
        return network_plot(
            interaction_values=self,
            show=show,
            **kwargs,
        )
1✔
979

980
    def plot_si_graph(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues as an SI graph.

        Thin wrapper around :func:`shapiq.plot.si_graph.si_graph_plot`; see that function for the
        supported keyword arguments.

        Args:
            show: Whether to show the plot. Defaults to ``True``.
            **kwargs: Additional keyword arguments forwarded to ``si_graph_plot``.

        Returns:
            Whatever ``si_graph_plot`` returns. Per the annotation, this is either a
            ``(figure, axes)`` tuple or ``None``.

        """
        from shapiq.plot.si_graph import si_graph_plot

        plot_kwargs = {"show": show, **kwargs}
        return si_graph_plot(self, **plot_kwargs)
1✔
992

993
    def plot_stacked_bar(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues as a stacked bar plot.

        Thin wrapper around ``shapiq.stacked_bar_plot``; see that function for the supported
        keyword arguments.

        Args:
            show: Whether to show the plot. Defaults to ``True``.
            **kwargs: Additional keyword arguments forwarded to ``stacked_bar_plot``.

        Returns:
            Whatever ``stacked_bar_plot`` returns. Per the annotation, this is either a
            ``(figure, axes)`` tuple or ``None``.

        """
        from shapiq import stacked_bar_plot

        plot_kwargs = {"show": show, **kwargs}
        return stacked_bar_plot(self, **plot_kwargs)
1✔
1005

1006
    def plot_force(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        contribution_threshold: float = 0.05,
    ) -> Figure | None:
        """Visualize InteractionValues on a force plot.

        Thin wrapper around ``shapiq.plot.force_plot``; see that function for details.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided, the
                feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            contribution_threshold: The threshold for contributions to be displayed. Defaults to
                ``0.05``.

        Returns:
            Whatever ``force_plot`` returns. Per the annotation, this is a matplotlib figure or
            ``None``.

        """
        from .plot import force_plot

        plot_options = {
            "feature_names": feature_names,
            "show": show,
            "abbreviate": abbreviate,
            "contribution_threshold": contribution_threshold,
        }
        return force_plot(self, **plot_options)
1039

1040
    def plot_waterfall(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        max_display: int = 10,
    ) -> Axes | None:
        """Draw the interaction values on a waterfall plot.

        Note:
            Requires the ``shap`` Python package to be installed.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided, the
                feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            max_display: The maximum number of interactions to display. Defaults to ``10``.

        Returns:
            Whatever ``waterfall_plot`` returns. Per the annotation, these are the plot's axes or
            ``None``.
        """
        from shapiq import waterfall_plot

        plot_options = {
            "feature_names": feature_names,
            "show": show,
            "max_display": max_display,
            "abbreviate": abbreviate,
        }
        return waterfall_plot(self, **plot_options)
1069

1070
    def plot_sentence(
        self,
        words: list[str],
        *,
        show: bool = True,
        **kwargs: Any,
    ) -> tuple[Figure, Axes] | None:
        """Plot the first-order effects (attributions) of a sentence or paragraph.

        Thin wrapper around :func:`shapiq.plot.sentence.sentence_plot`; see that function for the
        supported keyword arguments.

        Args:
            words: The words of the sentence, passed on to ``sentence_plot``.
            show: Whether to show the plot. Defaults to ``True``.
            **kwargs: Additional keyword arguments forwarded to ``sentence_plot``.

        Returns:
            If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple
            with the figure and the axis of the plot.

        """
        from shapiq.plot.sentence import sentence_plot

        plot_kwargs = {"show": show, **kwargs}
        return sentence_plot(self, words, **plot_kwargs)
1✔
1089

1090
    def plot_upset(self, *, show: bool = True, **kwargs: Any) -> Figure | None:
        """Draw an upset plot of the interaction values.

        Thin wrapper around :func:`shapiq.plot.upset.upset_plot`; see that function for the
        supported keyword arguments.

        Args:
            show: Whether to show the plot. Defaults to ``True``.
            **kwargs: Additional keyword arguments forwarded to ``upset_plot``.

        Returns:
            The upset plot as a matplotlib figure (if show is ``False``).

        """
        from shapiq.plot.upset import upset_plot

        plot_kwargs = {"show": show, **kwargs}
        return upset_plot(self, **plot_kwargs)
1✔
1102

1103

1104
def aggregate_interaction_values(
    interaction_values: Sequence[InteractionValues],
    aggregation: str = "mean",
) -> InteractionValues:
    """Aggregates InteractionValues objects using a specific aggregation method.

    The result contains the union of all interactions of the inputs, in sorted order. For every
    interaction, the value ``iv[key]`` of each input object is collected and aggregated. The
    baseline values are aggregated in the same way.

    Args:
        interaction_values: A non-empty sequence of InteractionValues objects to aggregate.
        aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
            ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

    Returns:
        The aggregated InteractionValues object. It is always marked as ``estimated=True``.

    Example:
        >>> iv1 = InteractionValues(
        ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=0.0,
        ... )
        >>> iv2 = InteractionValues(
        ...     values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]),  # this iv is missing the (1, 2) value
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4},  # no (1, 2)
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=1.0,
        ... )
        >>> aggregated = aggregate_interaction_values([iv1, iv2], "mean")
        >>> round(aggregated[(0, 1)], 2)
        0.45
        >>> aggregated.baseline_value
        0.5

    Note:
        The index and the estimation budget of the aggregated InteractionValues object are taken
        from the first object in the sequence. ``max_order`` and ``n_players`` are the maxima and
        ``min_order`` is the minimum over all inputs.

    Raises:
        ValueError: If the aggregation method is not supported or ``interaction_values`` is empty.

    """
    aggregators = {
        "mean": np.mean,
        "median": np.median,
        "sum": np.sum,
        "max": np.max,
        "min": np.min,
    }
    # validate up front so the error does not depend on the contents of the inputs
    if aggregation not in aggregators:
        msg = f"Aggregation method {aggregation} is not supported."
        raise ValueError(msg)
    if len(interaction_values) == 0:
        msg = "Cannot aggregate an empty sequence of InteractionValues objects."
        raise ValueError(msg)
    aggregate_fn = aggregators[aggregation]

    def _aggregate(vals: list[float]) -> float:
        """Reduce ``vals`` with the selected method and return a plain Python float."""
        return float(aggregate_fn(vals))

    # union of the interactions of all objects, sorted for a deterministic lookup
    all_keys: set[tuple[int, ...]] = set()
    for iv in interaction_values:
        all_keys.update(iv.interaction_lookup.keys())
    sorted_keys = sorted(all_keys)

    new_values = np.zeros(len(sorted_keys), dtype=float)
    new_lookup: dict[tuple[int, ...], int] = {}
    for i, key in enumerate(sorted_keys):
        new_lookup[key] = i
        # NOTE(review): objects lacking ``key`` contribute whatever ``iv[key]`` returns for a
        # missing interaction (presumably 0.0). Confirm in ``InteractionValues.__getitem__``.
        new_values[i] = _aggregate([iv[key] for iv in interaction_values])

    max_order = max(iv.max_order for iv in interaction_values)
    min_order = min(iv.min_order for iv in interaction_values)
    n_players = max(iv.n_players for iv in interaction_values)
    baseline_value = _aggregate([float(iv.baseline_value) for iv in interaction_values])
    estimation_budget = interaction_values[0].estimation_budget

    return InteractionValues(
        values=new_values,
        index=interaction_values[0].index,
        max_order=max_order,
        n_players=n_players,
        min_order=min_order,
        interaction_lookup=new_lookup,
        estimated=True,
        estimation_budget=estimation_budget,
        baseline_value=baseline_value,
    )
1207

1208

1209
def _validate_and_return_interactions(
1✔
1210
    values: np.ndarray | dict[tuple[int, ...], float],
1211
    interaction_lookup: dict[tuple[int, ...], int] | None,
1212
    n_players: int,
1213
    min_order: int,
1214
    max_order: int,
1215
    baseline_value: float | np.number,
1216
) -> dict[tuple[int, ...], float]:
1217
    """Check the interactions for validity and consistency.
1218

1219
    Args:
1220
        values (np.ndarray | dict[tuple[int, ...], float]): The interaction values.
1221
        interaction_lookup (dict[tuple[int, ...], int]): A mapping from interactions to their indices.
1222
        n_players (int): The number of players.
1223
        min_order (int): The minimum order of interactions.
1224
        max_order (int): The maximum order of interactions.
1225
        baseline_value (float | np.number): The baseline value to use for empty interactions.
1226

1227
    Raises:
1228
        TypeError: If the values or interaction_lookup are not of the expected types.
1229
    """
1230
    interactions: dict[tuple[int, ...], float] = {}
1✔
1231
    if interaction_lookup is None:
1✔
1232
        interaction_lookup = generate_interaction_lookup(
1✔
1233
            players=n_players,
1234
            min_order=min_order,
1235
            max_order=max_order,
1236
        )
1237
    if interaction_lookup is not None and not isinstance(interaction_lookup, dict):
1✔
1238
        msg = f"Interaction lookup must be a dictionary. Got {type(interaction_lookup)}."
×
1239
        raise TypeError(msg)
×
1240

1241
    if isinstance(values, dict):
1✔
1242
        interactions = copy.deepcopy(values)
1✔
1243
    else:
1244
        interactions = {
1✔
1245
            interaction: values[index].item() for interaction, index in interaction_lookup.items()
1246
        }
1247

1248
    if min_order == 0 and () not in interactions:
1✔
1249
        interactions[()] = float(baseline_value)
1✔
1250
    return interactions
1✔
1251

1252

1253
def _update_interactions_for_index(
    interactions: dict[tuple[int, ...], float],
    index: str,
    target_index: str,
    max_order: int,
    min_order: int,
    baseline_value: float | np.number,
) -> tuple[dict[tuple[int, ...], float], str, int, float]:
    """Convert the interactions to ``target_index`` and reconcile the empty interaction with the baseline.

    If ``target_index`` is an aggregated index (per ``is_index_aggregated``) that differs from
    ``index``, the interactions are converted with ``aggregate_base_attributions``. That call also
    yields the new index and minimum order.

    The empty interaction ``()`` is then reconciled with the baseline:

    - If ``()`` is present and its value differs from ``baseline_value`` (this check is skipped
      for ``"SII"``), one of two things happens. When ``is_empty_value_the_baseline(index)``
      holds, the empty interaction is overwritten with the baseline. Otherwise the baseline is
      taken from the empty interaction.
    - If ``()`` is absent and ``min_order == 0``, it is inserted with the baseline value.

    The ``interactions`` dictionary passed in may be modified in place.

    Args:
        interactions: Mapping from interactions to their values.
        index: The index the interactions are currently given in.
        target_index: The index the interactions should be expressed in.
        max_order: The maximum interaction order (passed as ``order`` to the aggregation).
        min_order: The minimum interaction order.
        baseline_value: The baseline value of the explanation.

    Returns:
        A tuple ``(interactions, index, min_order, baseline_value)``, where the baseline is a
        Python ``float``.
    """
    from .game_theory.aggregation import aggregate_base_attributions

    needs_aggregation = is_index_aggregated(target_index) and target_index != index
    if needs_aggregation:
        interactions, index, min_order = aggregate_base_attributions(
            interactions=interactions,
            index=index,
            order=max_order,
            min_order=min_order,
            baseline_value=float(baseline_value),
        )

    if () not in interactions:
        if min_order == 0:
            # TODO(mmshlk): this might not be what we really want to do always: what if empty and baseline are different?
            # https://github.com/mmschlk/shapiq/issues/385
            interactions[()] = float(baseline_value)
        return interactions, index, min_order, float(baseline_value)

    empty_value = interactions[()]
    # nothing to reconcile when both agree, or for SII where they may legitimately differ
    if empty_value == baseline_value or index == "SII":
        return interactions, index, min_order, float(baseline_value)

    if is_empty_value_the_baseline(index):
        # insert the empty value given in baseline into the values
        interactions[()] = float(baseline_value)
    else:
        # manually set baseline to the empty value
        baseline_value = empty_value
    return interactions, index, min_order, float(baseline_value)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc