• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 16572631955

28 Jul 2025 03:03PM UTC coverage: 93.697% (-0.2%) from 93.901%
16572631955

Pull #398

github

web-flow
Merge 09af9d874 into c22f28f96
Pull Request #398: 365 potentially remove interactionvaluesvalues

156 of 168 new or added lines in 15 files covered. (92.86%)

20 existing lines in 3 files now uncovered.

4935 of 5267 relevant lines covered (93.7%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.96
/src/shapiq/interaction_values.py
1
"""InteractionValues data-class, which is used to store the interaction scores."""
2

3
from __future__ import annotations
1✔
4

5
import contextlib
1✔
6
import copy
1✔
7
import json
1✔
8
import pickle
1✔
9
from pathlib import Path
1✔
10
from typing import TYPE_CHECKING
1✔
11
from warnings import warn
1✔
12

13
import numpy as np
1✔
14

15
from .game_theory.indices import (
1✔
16
    ALL_AVAILABLE_INDICES,
17
    index_generalizes_bv,
18
    index_generalizes_sv,
19
    is_empty_value_the_baseline,
20
    is_index_aggregated,
21
)
22
from .utils.errors import raise_deprecation_warning
1✔
23
from .utils.sets import generate_interaction_lookup
1✔
24

25
if TYPE_CHECKING:
1✔
26
    from collections.abc import Sequence
×
27
    from typing import Any
×
28

29
    from matplotlib.axes import Axes
×
30
    from matplotlib.figure import Figure
×
31

32
    from shapiq.typing import JSONType
×
33

34

35
# Deprecation message for saving InteractionValues in a non-JSON format (pickle / npz).
# The message intentionally ends with a trailing space.
SAVE_JSON_DEPRECATION_MSG = (
    "Saving InteractionValues not as a JSON file is deprecated. The parameters "
    "`as_pickle` and `as_npz` will be removed in the future. "
)
39

40

41
class InteractionValues:
    """This class contains the interaction values as estimated by an approximator.

    Internally all scores live in the ``interactions`` dictionary, which maps each interaction
    (a tuple of player indices) to its score. The ``values``, ``dict_values`` and
    ``interaction_lookup`` properties are all derived from that dictionary.

    Attributes:
        interactions: Dictionary mapping each stored interaction to its score.
        values: The interaction values of the model in vectorized form, following the insertion
            order of ``interactions``.
        index: The interaction index estimated. All available indices are defined in
            ``ALL_AVAILABLE_INDICES``. For ``max_order == 1`` it may be replaced by ``"BV"`` or
            ``"SV"`` during construction.
        target_index: The index to which the values are finalized. Equals ``index`` unless a
            different ``target_index`` was passed to the constructor.
        max_order: The order of the approximation.
        n_players: The number of players.
        min_order: The minimum order of the approximation. Must be passed explicitly to the
            constructor (there is no default).
        interaction_lookup: A dictionary that maps interactions to their index in the values
            vector. If the values are passed as a numpy array without an ``interaction_lookup``,
            the lookup is generated from the ``n_players``, ``min_order``, and ``max_order``
            parameters (and a warning is issued). Defaults to ``None``.
        estimated: Whether the interaction values are estimated or not. Defaults to ``True``.
        estimation_budget: The budget used for the estimation. Defaults to ``None``.
        baseline_value: The value of the baseline interaction also known as 'empty prediction' or
            ``'empty value'`` since it denotes the value of the empty coalition (empty set).
            Defaults to ``0.0``. If ``min_order == 0`` and the empty interaction ``()`` is missing
            from the values, it is inserted with this value. If ``()`` is present with a different
            score (and the index is not ``"SII"``), either the stored score or the baseline is
            overwritten, depending on ``is_empty_value_the_baseline(index)``.

    Raises:
        TypeError: If the baseline value is not a number, the index is not a string, the values
            are neither a numpy array nor a dictionary, or the interaction lookup is not a
            dictionary.

    Warns:
        UserWarning: If the index is not a valid index as defined in ``ALL_AVAILABLE_INDICES``.
        UserWarning: If the values are given as a numpy array without an ``interaction_lookup``.

    """
66

67
    def __init__(
        self,
        values: np.ndarray | dict[tuple[int, ...], float],
        *,
        index: str,
        max_order: int,
        n_players: int,
        min_order: int,
        interaction_lookup: dict[tuple[int, ...], int] | None = None,  # type: ignore[assignment]
        estimated: bool = True,
        estimation_budget: int | None = None,  # type: ignore[assignment]
        baseline_value: float = 0.0,
        target_index: str | None = None,
    ) -> None:
        """Create an InteractionValues object from a vector or a dictionary of scores.

        Args:
            values: Either a numpy array of scores (positions resolved through
                ``interaction_lookup``) or a dictionary mapping interactions to scores.
            index: The interaction index of the scores. Should be one of the indices in
                ``ALL_AVAILABLE_INDICES``; unknown indices only trigger a warning.
            max_order: The maximum order of the interactions.
            n_players: The number of players in the game.
            min_order: The minimum order of the interactions.
            interaction_lookup: Optional mapping from interactions to positions in ``values``.
                Only used when ``values`` is a numpy array; if omitted it is generated from
                ``n_players``, ``min_order``, and ``max_order``. Defaults to ``None``.
            estimated: Whether the scores are estimates. Defaults to ``True``.
            estimation_budget: The budget used for the estimation. Defaults to ``None``.
            baseline_value: The value of the empty coalition ("empty prediction" / "empty
                value"). Defaults to ``0.0``.
            target_index: The index to which the InteractionValues should be finalized.
                Defaults to ``None``, meaning the (preprocessed) ``index`` itself.

        Raises:
            TypeError: If ``baseline_value``, ``index``, ``values`` or ``interaction_lookup``
                has an unsupported type.
        """
        # Plain metadata first; the index preprocessing below reads ``self.max_order``.
        self.n_players = n_players
        self.max_order = max_order
        self.min_order = min_order
        self.estimation_budget = estimation_budget
        self.estimated = estimated

        # The validation order decides which error / warning surfaces first.
        self.baseline_value = self._validate_baseline_value(baseline_value)
        self.index, self.target_index = self._index_preprocessing(index, target_index)

        self.interactions, lookup = self._validate_attributions(values, interaction_lookup)
        self.interactions = self._populate_interactions(self.interactions, lookup)

        # Aggregate towards the target index first, then reconcile the empty-set score
        # with the baseline value.
        for adjust in (self._index_aggregation, self._index_baseline_adjustment):
            self.interactions = adjust(self.interactions)
116

117
    def _index_baseline_adjustment(
1✔
118
        self, interactions: dict[tuple[int, ...], float]
119
    ) -> dict[tuple[int, ...], float]:
120
        """Adjust the baseline of interactions based on the target index.
121

122
        Args:
123
            interactions: The interactions as a dictionary mapping interactions to their values.
124

125
        Returns:
126
            dict[tuple[int, ...], float]: The adjusted interactions.
127

128
        """
129
        if () in interactions:
1✔
130
            empty_value = interactions[()]
1✔
131
            if empty_value != self.baseline_value and self.index != "SII":
1✔
132
                if is_empty_value_the_baseline(self.index):
1✔
133
                    # insert the empty value given in baseline into the values
134
                    interactions[()] = self.baseline_value
1✔
135
                else:  # manually set baseline to the empty value
136
                    self.baseline_value = interactions[()]
1✔
137
        elif self.min_order == 0:
1✔
138
            # TODO(mmshlk): this might not be what we really want to do always: what if empty and baseline are different?
139
            # https://github.com/mmschlk/shapiq/issues/385
NEW
140
            interactions[()] = self.baseline_value
×
141
        return interactions
1✔
142

143
    def _index_aggregation(
        self, interactions: dict[tuple[int, ...], float]
    ) -> dict[tuple[int, ...], float]:
        """Aggregate the scores towards ``self.target_index`` if that index is an aggregated one.

        Aggregation only happens when ``is_index_aggregated(self.target_index)`` holds and the
        target differs from ``self.index``. In that case ``self.index`` and ``self.min_order``
        are overwritten with the values returned by ``aggregate_base_attributions``.

        Args:
            interactions: Mapping from interactions to scores.

        Returns:
            dict[tuple[int, ...], float]: The aggregated scores, or ``interactions`` unchanged.
        """
        # Imported locally (presumably to avoid a circular import — confirm).
        from .game_theory.aggregation import aggregate_base_attributions

        needs_aggregation = (
            is_index_aggregated(self.target_index) and self.target_index != self.index
        )
        if not needs_aggregation:
            return interactions

        aggregated, self.index, self.min_order = aggregate_base_attributions(
            interactions=interactions,
            index=self.index,
            order=self.max_order,
            min_order=self.min_order,
            baseline_value=self.baseline_value,
        )
        return aggregated
166

167
    def _validate_attributions(
1✔
168
        self,
169
        values: np.ndarray | dict[tuple[int, ...], float],
170
        interaction_lookup: dict[tuple[int, ...], int],
171
    ) -> tuple[np.ndarray | dict[tuple[int, ...], float], dict[tuple[int, ...], int] | None]:
172
        """Validate the attributions provided to the InteractionValues object.
173

174
        Args:
175
            values: The interaction values as a numpy array.
176
            interaction_lookup: The interaction lookup as a dictionary mapping interactions to their
177
                index in the values vector.
178

179
        Raises:
180
            TypeError: If all three parameters are None or if both values and interaction_lookup are
181
                provided but not interactions.
182

183
        """
184
        if values is None:
1✔
NEW
UNCOV
185
            msg = "Values must be provided."
×
NEW
UNCOV
186
            raise TypeError(msg)
×
187
        if (
1✔
188
            values is not None
189
            and not isinstance(values, np.ndarray)
190
            and not isinstance(values, dict)
191
        ):
NEW
192
            msg = f"Values must be a numpy array or dictionary. Got {type(values)}."
×
NEW
UNCOV
193
            raise TypeError(msg)
×
194
        if interaction_lookup is not None and not isinstance(interaction_lookup, dict):
1✔
NEW
UNCOV
195
            msg = f"Interaction lookup must be a dictionary. Got {type(interaction_lookup)}."
×
NEW
UNCOV
196
            raise TypeError(msg)
×
197
        return values, interaction_lookup
1✔
198

199
    def _populate_interactions(
1✔
200
        self,
201
        values: np.ndarray | dict[tuple[int, ...], float],
202
        interaction_lookup: dict[tuple[int, ...], int] | None,
203
    ) -> dict[tuple[int, ...], float]:
204
        """Populate the attributions for the InteractionValues object.
205

206
        Args:
207
            values: The interaction values as a numpy array. If None, it will be generated from the
208
                interaction_lookup.
209
            interaction_lookup: The interaction lookup as a dictionary mapping interactions to their
210
                index in the values vector. If None, it will be generated from the n_players, min_order,
211
                and max_order parameters.
212

213
        Returns:
214
            dict[tuple[int, ...], float]: The populated interactions as a dictionary mapping interactions to their values.
215

216
        Note:
217
            If the interaction_lookup was provided by the user, the interactions dictionary will be populated accordingly.
218
            Therefore we only need to return the interactions dictionary, as it inherently "contains" the interaction_lookup.
219
        """
220
        if isinstance(values, dict):
1✔
221
            interactions = copy.deepcopy(values)
1✔
222
        elif isinstance(values, np.ndarray):
1✔
223
            interaction_lookup = self._populate_interaction_lookup(interaction_lookup)
1✔
224
            interactions = {
1✔
225
                interaction: values[index].item()
226
                for interaction, index in interaction_lookup.items()
227
            }
228
        else:
NEW
UNCOV
229
            msg = f"Values must be a numpy array or dictionary. Got {type(values)}."
×
NEW
UNCOV
230
            raise TypeError(msg)
×
231

232
        if self.min_order == 0 and () not in interactions:
1✔
233
            interactions[()] = self.baseline_value
1✔
234

235
        return interactions
1✔
236

237
    def _populate_interaction_lookup(
1✔
238
        self, interaction_lookup: dict[tuple[int, ...], int] | None
239
    ) -> dict[tuple[int, ...], int]:
240
        """Populate the interaction lookup if it is not already set.
241

242
        Args:
243
            interaction_lookup: The interaction lookup to populate. If it is None, it will be generated
244
                from the n_players, min_order, and max_order parameters.
245

246
        Returns:
247
            dict[tuple[int, ...], int]: The populated interaction lookup.
248

249
        Warnings:
250
            UserWarning: If the interaction_lookup is None, a warning is raised to inform the user that
251
            using a numpy array for values can be dangerous and that a dictionary should be used instead.
252
        """
253
        if interaction_lookup is None:
1✔
254
            warn(
1✔
255
                "The usage of a numpy array for values can be dangerous. To make sure that each value is associated to the correct coaltion, consider using a dictionary instead. Ofcourse setting the interaction_lookup would also be possible. ",
256
                UserWarning,
257
                stacklevel=2,
258
            )
259
            interaction_lookup = generate_interaction_lookup(
1✔
260
                self.n_players,
261
                self.min_order,
262
                self.max_order,
263
            )
264
        return interaction_lookup
1✔
265

266
    def _populate_values(self, interactions: dict[tuple[int, ...], float]) -> np.ndarray:
1✔
267
        """Populate the values from the interactions.
268

269
        Args:
270
            interactions: The interactions as a dictionary mapping interactions to their index in the
271
                values vector.
272

273
        Returns:
274
            np.ndarray: The populated values as a numpy array.
275

276
        """
NEW
UNCOV
277
        return np.array(list(interactions.values()))
×
278

279
    def _index_preprocessing(self, index: str, target_index: str | None) -> tuple[str, str]:
        """Validate and normalize ``index`` and resolve the target index.

        Args:
            index: The index to preprocess. It is type-checked, checked against
                ``ALL_AVAILABLE_INDICES`` (warning only) and, for ``max_order == 1``, possibly
                replaced by ``"BV"`` or ``"SV"``.
            target_index: The index to which the InteractionValues should be finalized. ``None``
                means "the same as the preprocessed ``index``".

        Returns:
            tuple[str, str]: The preprocessed ``index`` and the resolved ``target_index``.

        Raises:
            TypeError: If ``index`` is not a string.
        """
        index = self._adjust_index(self._validate_index(index))
        # ``target_index`` is passed through untouched (no validation / adjustment).
        resolved_target = index if target_index is None else target_index
        return index, resolved_target
296

297
    def _validate_index(self, index: str) -> str:
1✔
298
        """Validate the index and check if it is a valid index as defined in ALL_AVAILABLE_INDICES.
299

300
        Args:
301
            index: The index to validate.
302

303
        Raises:
304
            UserWarning: If the index is not a valid index as defined in ALL_AVAILABLE_INDICES.
305

306
        TypeError: If the index is not a string.
307

308
        Returns:
309
            str: The validated index.
310

311
        """
312
        if not isinstance(index, str):
1✔
NEW
UNCOV
313
            msg = f"Index must be a string. Got {type(index)}."
×
NEW
UNCOV
314
            raise TypeError(msg)
×
315

316
        return self._check_index_valid(index)
1✔
317

318
    def _adjust_index(self, index: str) -> str:
1✔
319
        """Adjust the index to be either "BV" or "SV" if max_order is 1.
320

321
        Returns:
322
            None: The index is set as an attribute of the InteractionValues object.
323
        """
324
        # set BV or SV if max_order is 1
325
        if self.max_order == 1:
1✔
326
            if index_generalizes_bv(index):
1✔
327
                index = "BV"
1✔
328
            if index_generalizes_sv(index):
1✔
329
                index = "SV"
1✔
330
        return index
1✔
331

332
    def _validate_baseline_value(self, baseline_value: float) -> int | float:
1✔
333
        """Validate the baseline value.
334

335
        Raises:
336
            TypeError: If the baseline value is not a number (int or float).
337

338
        Returns:
339
            int | float: The validated baseline value.
340
        """
341
        if not isinstance(baseline_value, int | float | np.number):
1✔
342
            msg = f"Baseline value must be provided as a number. Got {type(baseline_value)}."
1✔
343
            raise TypeError(msg)
1✔
344
        return baseline_value
1✔
345

346
    def _check_index_valid(self, index: str) -> str:
        """Warn if ``index`` is not listed in ``ALL_AVAILABLE_INDICES``; it is never rejected.

        Args:
            index: The index to check.

        Returns:
            str: The index, unchanged.

        Warns:
            UserWarning: If the index is not a valid index as defined in ``ALL_AVAILABLE_INDICES``.
        """
        if index in ALL_AVAILABLE_INDICES:
            return index
        warn(
            UserWarning(
                f"Index {index} is not a valid index as defined in {ALL_AVAILABLE_INDICES}. "
                "This might lead to unexpected behavior.",
            ),
            stacklevel=2,
        )
        return index
365

366
    @property
1✔
367
    def dict_values(self) -> dict[tuple[int, ...], float]:
1✔
368
        """Getter for the dict directly mapping from all interactions to scores."""
369
        return self.interactions
1✔
370

371
    @property
1✔
372
    def values(self) -> np.ndarray:
1✔
373
        """Getter for the values of the InteractionValues object.
374

375
        Returns:
376
            The values of the InteractionValues object as a numpy array.
377

378
        """
379
        return np.array(list(self.interactions.values()))
1✔
380

381
    @property
1✔
382
    def interaction_lookup(self) -> dict[tuple[int, ...], int]:
1✔
383
        """Getter for the interaction lookup of the InteractionValues object.
384

385
        Returns:
386
            The interaction lookup of the InteractionValues object as a dictionary mapping interactions
387
            to their index in the values vector.
388

389
        """
390
        return {
1✔
391
            interaction: index for index, (interaction, _) in enumerate(self.interactions.items())
392
        }
393

394
    def to_json_file(
        self,
        path: Path,
        *,
        desc: str | None = None,
        created_from: object | None = None,
        **kwargs: JSONType,
    ) -> None:
        """Write the InteractionValues object to ``path`` as JSON.

        The written document consists of the file metadata produced by ``make_file_metadata``,
        a ``"metadata"`` section with the object's attributes, and a ``"data"`` section holding
        the scores as converted by ``interactions_to_dict``.

        Args:
            path: Destination of the JSON file.
            desc: Optional description stored in the file metadata. Defaults to ``None``.
            created_from: Optional object from which the InteractionValues object was created.
                Defaults to ``None``.
            **kwargs: Additional parameters to store in the metadata of the JSON file.
        """
        from shapiq.utils.saving import interactions_to_dict, make_file_metadata, save_json

        header = make_file_metadata(
            object_to_store=self,
            data_type="interaction_values",
            desc=desc,
            created_from=created_from,
            parameters=kwargs,
        )
        # NOTE: ``target_index`` is not part of the stored attributes.
        attributes = {
            "n_players": self.n_players,
            "index": self.index,
            "max_order": self.max_order,
            "min_order": self.min_order,
            "estimated": self.estimated,
            "estimation_budget": self.estimation_budget,
            "baseline_value": self.baseline_value,
        }
        payload = interactions_to_dict(interactions=self.dict_values)
        save_json({**header, "metadata": attributes, "data": payload}, path)
434

435
    @classmethod
    def from_json_file(cls, path: Path | str) -> InteractionValues:
        """Loads an InteractionValues object from a JSON file.

        The file must contain a ``"metadata"`` section (``n_players``, ``index``, ``max_order``,
        ``min_order``, ``estimated``, ``estimation_budget``, ``baseline_value``) and a ``"data"``
        section, which ``dict_to_lookup_and_values`` converts into an interaction lookup and a
        value vector.

        Args:
            path: The path to the JSON file, given as a ``Path`` or a string. The file name must
                end with ``'.json'``.

        Returns:
            The InteractionValues object loaded from the JSON file.

        Raises:
            ValueError: If the path does not end with ``'.json'``.
        """
        from shapiq.utils.saving import dict_to_lookup_and_values

        # Accept plain strings as well; ``.name`` and ``.open`` below require a Path.
        path = Path(path)
        if not path.name.endswith(".json"):
            msg = f"Path {path} does not end with .json. Cannot load InteractionValues."
            raise ValueError(msg)

        with path.open("r", encoding="utf-8") as file:
            json_data = json.load(file)

        metadata = json_data["metadata"]
        interaction_lookup, values = dict_to_lookup_and_values(json_data["data"])

        # NOTE: ``target_index`` is not stored by ``to_json_file``, so it defaults to the index.
        return cls(
            values=values,
            index=metadata["index"],
            max_order=metadata["max_order"],
            n_players=metadata["n_players"],
            min_order=metadata["min_order"],
            interaction_lookup=interaction_lookup,
            estimated=metadata["estimated"],
            estimation_budget=metadata["estimation_budget"],
            baseline_value=metadata["baseline_value"],
        )
472

473
    def sparsify(self, threshold: float = 1e-3) -> None:
1✔
474
        """Manually sets values close to zero actually to zero (removing values).
475

476
        Args:
477
            threshold: The threshold value below which interactions are zeroed out. Defaults to
478
                1e-3.
479

480
        """
481
        # find interactions to remove in self.interactions
482
        sparse_interactions = copy.deepcopy(self.interactions)
1✔
483
        for interaction, value in self.interactions.items():
1✔
484
            if np.abs(value) < threshold:
1✔
485
                del sparse_interactions[interaction]
1✔
486
        self.interactions = sparse_interactions
1✔
487

488
    def get_top_k_interactions(self, k: int) -> InteractionValues:
        """Returns the top k interactions as a new InteractionValues object.

        Interactions are ranked by the absolute value of their score. The selected interactions
        keep their original relative (insertion) order in the result. All metadata (index,
        orders, number of players, estimation info, baseline) is copied over. Note that if
        ``min_order`` is 0 and ``()`` is not among the top ``k``, the constructor re-inserts it
        with the baseline value.

        Args:
            k: The number of top interactions to return.

        Returns:
            The top k interactions as an InteractionValues object.

        """
        # Build the value vector once; ``self.values`` / ``self[i]`` rebuild it on every access,
        # which made the previous loop quadratic in the number of interactions.
        all_values = self.values
        top_k_positions = set(np.argsort(np.abs(all_values))[::-1][:k].tolist())
        new_values = np.zeros(k, dtype=float)
        new_interaction_lookup: dict[tuple[int, ...], int] = {}
        for interaction_pos, interaction in enumerate(self.interactions):
            if interaction_pos in top_k_positions:
                new_position = len(new_interaction_lookup)
                new_values[new_position] = float(all_values[interaction_pos])
                new_interaction_lookup[interaction] = new_position
        return InteractionValues(
            values=new_values,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
517

518
    def get_top_k(
1✔
519
        self, k: int, *, as_interaction_values: bool = True
520
    ) -> InteractionValues | tuple[dict, list[tuple]]:
521
        """Returns the top k interactions.
522

523
        Args:
524
            k: The number of top interactions to return.
525
            as_interaction_values: Whether to return the top `k` interactions as an InteractionValues
526
                object. Defaults to ``False``.
527

528
        Returns:
529
            The top k interactions as a dictionary and a sorted list of tuples.
530

531
        Examples:
532
            >>> interaction_values = InteractionValues(
533
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
534
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
535
            ...     index="SII",
536
            ...     max_order=2,
537
            ...     n_players=3,
538
            ...     min_order=1,
539
            ...     baseline_value=0.0,
540
            ... )
541
            >>> top_k_interactions, sorted_top_k_interactions = interaction_values.get_top_k(2, False)
542
            >>> top_k_interactions
543
            {(0, 2): 0.5, (1, 0): 0.6}
544
            >>> sorted_top_k_interactions
545
            [((1, 0), 0.6), ((0, 2), 0.5)]
546

547
        """
548
        if as_interaction_values:
1✔
549
            return self.get_top_k_interactions(k)
1✔
550
        top_k_indices = np.argsort(np.abs(self.values))[::-1][:k]
1✔
551
        top_k_interactions = {}
1✔
552
        for interaction, index in self.interaction_lookup.items():
1✔
553
            if index in top_k_indices:
1✔
554
                top_k_interactions[interaction] = self.values[index]
1✔
555
        sorted_top_k_interactions = [
1✔
556
            (interaction, top_k_interactions[interaction])
557
            for interaction in sorted(top_k_interactions, key=top_k_interactions.get, reverse=True)
558
        ]
559
        return top_k_interactions, sorted_top_k_interactions
1✔
560

561
    def __repr__(self) -> str:
1✔
562
        """Returns the representation of the InteractionValues object."""
563
        representation = "InteractionValues(\n"
1✔
564
        representation += (
1✔
565
            f"    index={self.index}, max_order={self.max_order}, min_order={self.min_order}"
566
            f", estimated={self.estimated}, estimation_budget={self.estimation_budget},\n"
567
            f"    n_players={self.n_players}, baseline_value={self.baseline_value}\n)"
568
        )
569
        return representation
1✔
570

571
    def __str__(self) -> str:
1✔
572
        """Returns the string representation of the InteractionValues object."""
573
        representation = self.__repr__()
1✔
574
        representation = representation[:-2]  # remove the last "\n)" and add values
1✔
575
        _, sorted_top_10_interactions = self.get_top_k(
1✔
576
            10, as_interaction_values=False
577
        )  # get top 10 interactions
578
        # add values to string representation
579
        representation += ",\n    Top 10 interactions:\n"
1✔
580
        for interaction, value in sorted_top_10_interactions:
1✔
581
            representation += f"        {interaction}: {value}\n"
1✔
582
        representation += ")"
1✔
583
        return representation
1✔
584

585
    def __len__(self) -> int:
1✔
586
        """Returns the length of the InteractionValues object."""
587
        return len(self.values)  # might better to return the theoretical no. of interactions
1✔
588

589
    def __iter__(self) -> np.nditer:
1✔
590
        """Returns an iterator over the values of the InteractionValues object."""
591
        return np.nditer(self.values)
1✔
592

593
    def __getitem__(self, item: int | tuple[int, ...]) -> float:
1✔
594
        """Returns the score for the given interaction.
595

596
        Args:
597
            item: The interaction as a tuple of integers for which to return the score. If ``item`` is
598
                an integer it serves as the index to the values vector.
599

600
        Returns:
601
            The interaction value. If the interaction is not present zero is returned.
602

603
        """
604
        if isinstance(item, int):
1✔
605
            return float(self.values[item])
1✔
606
        item = tuple(sorted(item))
1✔
607
        try:
1✔
608
            return float(self.interactions[item])
1✔
609
        except KeyError:
1✔
610
            return 0.0
1✔
611

612
    def __setitem__(self, item: int | tuple[int, ...], value: float) -> None:
1✔
613
        """Sets the score for the given interaction.
614

615
        Args:
616
            item: The interaction as a tuple of integers for which to set the score. If ``item`` is an
617
                integer it serves as the index to the values vector.
618
            value: The value to set for the interaction.
619

620
        Raises:
621
            KeyError: If the interaction is not found in the InteractionValues object.
622

623
        """
624
        try:
1✔
625
            if isinstance(item, int):
1✔
626
                # dict.items() preserves the order of insertion, so we can use it to set the value
627
                for i, (interaction, _) in enumerate(self.interactions.items()):
1✔
628
                    if i == item:
1✔
629
                        self.interactions[interaction] = value
1✔
630
                        break
1✔
631
            else:
632
                item = tuple(sorted(item))
1✔
633
                if self.interactions[item]:
1✔
634
                    # if the interaction is already present, update its value. Otherwise KeyError is raised
635
                    self.interactions[item] = value
1✔
636
        except Exception as e:
1✔
637
            msg = f"Interaction {item} not found in the InteractionValues. Unable to set a value."
1✔
638
            raise KeyError(msg) from e
1✔
639

640
    def __eq__(self, other: object) -> bool:
1✔
641
        """Checks if two InteractionValues objects are equal.
642

643
        Args:
644
            other: The other InteractionValues object.
645

646
        Returns:
647
            True if the two objects are equal, False otherwise.
648

649
        """
650
        if not isinstance(other, InteractionValues):
1✔
651
            msg = "Cannot compare InteractionValues with other types."
1✔
652
            raise TypeError(msg)
1✔
653
        if (
1✔
654
            self.index != other.index
655
            or self.max_order != other.max_order
656
            or self.min_order != other.min_order
657
            or self.n_players != other.n_players
658
            or self.baseline_value != other.baseline_value
659
        ):
660
            return False
1✔
661
        if not np.allclose(self.values, other.values):
1✔
662
            return False
1✔
663
        return self.interaction_lookup == other.interaction_lookup
1✔
664

665
    def __ne__(self, other: object) -> bool:
1✔
666
        """Checks if two InteractionValues objects are not equal.
667

668
        Args:
669
            other: The other InteractionValues object.
670

671
        Returns:
672
            True if the two objects are not equal, False otherwise.
673

674
        """
675
        return not self.__eq__(other)
1✔
676

677
    def __hash__(self) -> int:
1✔
678
        """Returns the hash of the InteractionValues object."""
679
        return hash(
1✔
680
            (
681
                self.index,
682
                self.max_order,
683
                self.min_order,
684
                self.n_players,
685
                tuple(self.values.flatten()),
686
            ),
687
        )
688

689
    def __copy__(self) -> InteractionValues:
        """Return an independent copy rebuilt through the constructor.

        The scores are passed as a fresh array together with an explicit interaction lookup,
        so the constructor does not emit the numpy-array warning. ``target_index`` is not
        forwarded, so the copy's target index equals its index.
        """
        return InteractionValues(
            values=copy.deepcopy(self.values),
            interaction_lookup=copy.deepcopy(self.interaction_lookup),
            index=self.index,
            max_order=self.max_order,
            min_order=self.min_order,
            n_players=self.n_players,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
702

703
    def __add__(self, other: InteractionValues | float) -> InteractionValues:
1✔
704
        """Adds two InteractionValues objects together or a scalar."""
705
        n_players, min_order, max_order = self.n_players, self.min_order, self.max_order
1✔
706
        if isinstance(other, InteractionValues):
1✔
707
            if self.index != other.index:  # different indices
1✔
708
                msg = (
1✔
709
                    f"Cannot add InteractionValues with different indices {self.index} and "
710
                    f"{other.index}."
711
                )
712
                raise ValueError(msg)
1✔
713
            if (
1✔
714
                self.interaction_lookup != other.interaction_lookup
715
                or self.n_players != other.n_players
716
                or self.min_order != other.min_order
717
                or self.max_order != other.max_order
718
            ):  # different interactions but addable
719
                added_interactions = self.interactions.copy()
1✔
720
                for interaction in other.interactions:
1✔
721
                    if interaction not in added_interactions:
1✔
722
                        added_interactions[interaction] = other.interactions[interaction]
1✔
723
                    else:
724
                        added_interactions[interaction] += other.interactions[interaction]
1✔
725
                interaction_lookup = {
1✔
726
                    interaction: i for i, interaction in enumerate(added_interactions)
727
                }
728
                # adjust n_players, min_order, and max_order
729
                n_players = max(self.n_players, other.n_players)
1✔
730
                min_order = min(self.min_order, other.min_order)
1✔
731
                max_order = max(self.max_order, other.max_order)
1✔
732
                baseline_value = self.baseline_value + other.baseline_value
1✔
733
            else:  # basic case with same interactions
734
                added_interactions = {
1✔
735
                    interaction: self.interactions[interaction] + other.interactions[interaction]
736
                    for interaction in self.interactions
737
                }
738
                interaction_lookup = self.interaction_lookup
1✔
739
                baseline_value = self.baseline_value + other.baseline_value
1✔
740
        elif isinstance(other, int | float):
1✔
741
            added_interactions = {
1✔
742
                interaction: self.interactions[interaction] + other
743
                for interaction in self.interactions
744
            }
745
            interaction_lookup = self.interaction_lookup.copy()
1✔
746
            baseline_value = self.baseline_value + other
1✔
747
        else:
748
            msg = f"Cannot add InteractionValues with object of type {type(other)}."
1✔
749
            raise TypeError(msg)
1✔
750

751
        return InteractionValues(
1✔
752
            values=added_interactions,
753
            index=self.index,
754
            max_order=max_order,
755
            n_players=n_players,
756
            min_order=min_order,
757
            interaction_lookup=interaction_lookup,
758
            estimated=self.estimated,
759
            estimation_budget=self.estimation_budget,
760
            baseline_value=baseline_value,
761
        )
762

763
    def __radd__(self, other: InteractionValues | float) -> InteractionValues:
1✔
764
        """Adds two InteractionValues objects together or a scalar."""
765
        return self.__add__(other)
1✔
766

767
    def __neg__(self) -> InteractionValues:
        """Return a new object in which every interaction value and the baseline are negated."""
        negated_values = -self.values
        negated_baseline = -self.baseline_value
        return InteractionValues(
            values=negated_values,
            index=self.index,
            max_order=self.max_order,
            min_order=self.min_order,
            n_players=self.n_players,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=negated_baseline,
        )
780

781
    def __sub__(self, other: InteractionValues | float) -> InteractionValues:
1✔
782
        """Subtracts two InteractionValues objects or a scalar."""
783
        return self.__add__(-other)
1✔
784

785
    def __rsub__(self, other: InteractionValues | float) -> InteractionValues:
1✔
786
        """Subtracts two InteractionValues objects or a scalar."""
787
        return (-self).__add__(other)
1✔
788

789
    def __mul__(self, other: float) -> InteractionValues:
        """Scale every interaction value and the baseline value by ``other``.

        Args:
            other: The scalar factor.

        Returns:
            A new InteractionValues object with the scaled values.

        """
        scaled = {key: score * other for key, score in self.interactions.items()}
        scaled_baseline = self.baseline_value * other
        return InteractionValues(
            values=scaled,
            index=self.index,
            max_order=self.max_order,
            min_order=self.min_order,
            n_players=self.n_players,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=scaled_baseline,
        )
805

806
    def __rmul__(self, other: float) -> InteractionValues:
1✔
807
        """Multiplies an InteractionValues object by a scalar."""
808
        return self.__mul__(other)
1✔
809

810
    def __abs__(self) -> InteractionValues:
        """Return a new object whose interaction values are replaced by their absolute values.

        The baseline value is carried over unchanged; it is not made absolute.
        """
        absolute_scores = {}
        for interaction, score in self.interactions.items():
            absolute_scores[interaction] = abs(score)
        return InteractionValues(
            values=absolute_scores,
            index=self.index,
            max_order=self.max_order,
            min_order=self.min_order,
            n_players=self.n_players,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
824

825
    def get_n_order_values(self, order: int) -> np.ndarray:
1✔
826
        """Returns the interaction values of a specific order as a numpy array.
827

828
        Note:
829
            Depending on the order and number of players the resulting array might be sparse and
830
            very large.
831

832
        Args:
833
            order: The order of the interactions to return.
834

835
        Returns:
836
            The interaction values of the specified order as a numpy array of shape ``(n_players,)``
837
            for order ``1`` and ``(n_players, n_players)`` for order ``2``, etc.
838

839
        Raises:
840
            ValueError: If the order is less than ``1``.
841

842
        """
843
        from itertools import permutations
1✔
844

845
        if order < 1:
1✔
846
            msg = "Order must be greater or equal to 1."
1✔
847
            raise ValueError(msg)
1✔
848
        values_shape = tuple([self.n_players] * order)
1✔
849
        values = np.zeros(values_shape, dtype=float)
1✔
850
        for interaction in self.interaction_lookup:
1✔
851
            if len(interaction) != order:
1✔
852
                continue
1✔
853
            # get all orderings of the interaction (e.g. (0, 1) and (1, 0) for interaction (0, 1))
854
            for perm in permutations(interaction):
1✔
855
                values[perm] = self[interaction]
1✔
856

857
        return values
1✔
858

859
    def get_n_order(
1✔
860
        self,
861
        order: int | None = None,
862
        min_order: int | None = None,
863
        max_order: int | None = None,
864
    ) -> InteractionValues:
865
        """Select particular order of interactions.
866

867
        Creates a new InteractionValues object containing only the interactions within the
868
        specified order range.
869

870
        You can specify:
871
            - `order`: to select interactions of a single specific order (e.g., all pairwise
872
                interactions).
873
            - `min_order` and/or `max_order`: to select a range of interaction orders.
874
            - If `order` and `min_order`/`max_order` are both set, `min_order` and `max_order` will
875
                override the `order` value.
876

877
        Example:
878
            >>> interaction_values = InteractionValues(
879
            ...     values=np.array([1, 2, 3, 4, 5, 6, 7]),
880
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5, (0, 1, 2): 6},
881
            ...     index="SII",
882
            ...     max_order=3,
883
            ...     n_players=3,
884
            ...     min_order=1,
885
            ...     baseline_value=0.0,
886
            ... )
887
            >>> interaction_values.get_n_order(order=1).dict_values
888
            {(0,): 1.0, (1,): 2.0, (2,): 3.0}
889
            >>> interaction_values.get_n_order(min_order=1, max_order=2).dict_values
890
            {(0,): 1.0, (1,): 2.0, (2,): 3.0, (0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0}
891
            >>> interaction_values.get_n_order(min_order=2).dict_values
892
            {(0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0, (0, 1, 2): 7.0}
893

894
        Args:
895
            order: The order of the interactions to return. Defaults to ``None`` which requires
896
                ``min_order`` or ``max_order`` to be set.
897
            min_order: The minimum order of the interactions to return. Defaults to ``None`` which
898
                sets it to the order.
899
            max_order: The maximum order of the interactions to return. Defaults to ``None`` which
900
                sets it to the order.
901

902
        Returns:
903
            The interaction values of the specified order.
904

905
        Raises:
906
            ValueError: If all three parameters are set to ``None``.
907
        """
908
        if order is None and min_order is None and max_order is None:
1✔
909
            msg = "Either order, min_order or max_order must be set."
1✔
910
            raise ValueError(msg)
1✔
911

912
        if order is not None:
1✔
913
            max_order = order if max_order is None else max_order
1✔
914
            min_order = order if min_order is None else min_order
1✔
915
        else:  # order is None
916
            min_order = self.min_order if min_order is None else min_order
1✔
917
            max_order = self.max_order if max_order is None else max_order
1✔
918

919
        if min_order > max_order:
1✔
920
            msg = f"min_order ({min_order}) must be less than or equal to max_order ({max_order})."
1✔
921
            raise ValueError(msg)
1✔
922

923
        new_values = []
1✔
924
        new_interaction_lookup = {}
1✔
925
        for interaction in self.interaction_lookup:
1✔
926
            if len(interaction) < min_order or len(interaction) > max_order:
1✔
927
                continue
1✔
928
            interaction_idx = len(new_interaction_lookup)
1✔
929
            new_values.append(self[interaction])
1✔
930
            new_interaction_lookup[interaction] = interaction_idx
1✔
931

932
        return InteractionValues(
1✔
933
            values=np.array(new_values),
934
            index=self.index,
935
            max_order=max_order,
936
            n_players=self.n_players,
937
            min_order=min_order,
938
            interaction_lookup=new_interaction_lookup,
939
            estimated=self.estimated,
940
            estimation_budget=self.estimation_budget,
941
            baseline_value=self.baseline_value,
942
        )
943

944
    def get_subset(self, players: list[int]) -> InteractionValues:
        """Selects a subset of players from the InteractionValues object.

        An interaction is kept if every player in it is contained in ``players``. The empty
        interaction ``()``, if present, is therefore always kept. Player indices are not
        renumbered: interaction ``(2,)`` stays ``(2,)`` in the result.

        Args:
            players: List of players to select from the InteractionValues object.

        Returns:
            InteractionValues: Filtered InteractionValues object containing only values related to
            selected players. Its ``n_players`` is the number of distinct selected players.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
            ...     index="SII",
            ...     max_order=2,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_subset([0, 1]).dict_values
            {(0,): 0.1, (1,): 0.2, (0, 1): 0.4}
            >>> interaction_values.get_subset([0, 2]).dict_values
            {(0,): 0.1, (2,): 0.3, (0, 2): 0.5}
            >>> interaction_values.get_subset([1]).dict_values
            {(1,): 0.2}

        """
        # set membership keeps the per-interaction check O(len(interaction))
        selected_players = set(players)
        kept_positions: list[int] = []
        new_interaction_lookup: dict[tuple[int, ...], int] = {}
        for interaction, position in self.interaction_lookup.items():
            if selected_players.issuperset(interaction):
                new_interaction_lookup[interaction] = len(kept_positions)
                # use the index stored in the lookup rather than the dict's iteration position
                kept_positions.append(position)
        new_values = self.values[kept_positions]
        # the subset game consists of exactly the selected players
        n_players = len(selected_players)
        return InteractionValues(
            values=new_values,
            index=self.index,
            max_order=self.max_order,
            n_players=n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
992

993
    def save(self, path: Path, *, as_pickle: bool = False, as_npz: bool = False) -> None:
        """Write the InteractionValues object to ``path``.

        JSON (written via :meth:`to_json_file`) is the default format. The pickle and ``npz``
        formats are deprecated; choosing either first calls ``raise_deprecation_warning``. If both
        flags are set, ``as_pickle`` takes precedence. A missing parent directory is created first.

        Args:
            path: The destination file path.
            as_pickle: Save with :mod:`pickle` instead of JSON. Deprecated.
            as_npz: Save with :func:`numpy.savez` instead of JSON. Deprecated.

        Raises:
            DeprecationWarning: Via ``raise_deprecation_warning`` when ``as_pickle`` or ``as_npz``
                is ``True``.

        """
        parent_dir = Path(path).parent
        if not parent_dir.exists():
            with contextlib.suppress(FileNotFoundError):
                parent_dir.mkdir(parents=True, exist_ok=True)

        if not as_pickle and not as_npz:
            self.to_json_file(path)
            return

        raise_deprecation_warning(
            message=SAVE_JSON_DEPRECATION_MSG, deprecated_in="1.3.1", removed_in="1.4.0"
        )
        if as_pickle:
            with Path(path).open("wb") as file:
                pickle.dump(self, file)
        else:
            # store every constructor argument so `load` can rebuild the object
            np.savez(
                path,
                values=self.values,
                index=self.index,
                max_order=self.max_order,
                n_players=self.n_players,
                min_order=self.min_order,
                interaction_lookup=self.interaction_lookup,
                estimated=self.estimated,
                estimation_budget=self.estimation_budget,
                baseline_value=self.baseline_value,
            )
1037

1038
    @classmethod
    def load(cls, path: Path | str) -> InteractionValues:
        """Load an InteractionValues object from a file.

        Files ending in ``.json`` are read with :meth:`from_json_file`. Any other path first
        triggers the deprecation warning for non-JSON formats. Then:

        - ``.npz`` files (written by ``save(..., as_npz=True)``) are loaded.
        - Anything else raises ``ValueError``.

        Args:
            path: The path to load the InteractionValues object from.

        Returns:
            The loaded InteractionValues object.

        Raises:
            ValueError: If the path ends with neither ``.json`` nor ``.npz``.

        """
        path = Path(path)
        # check if path ends with .json
        if path.name.endswith(".json"):
            return cls.from_json_file(path)

        raise_deprecation_warning(
            SAVE_JSON_DEPRECATION_MSG, deprecated_in="1.3.1", removed_in="1.4.0"
        )

        # try loading as npz file
        if path.name.endswith(".npz"):
            # NOTE(review): allow_pickle=True is required because `interaction_lookup` was stored
            # as a dict (an object array) and is restored with `.item()`. Unpickling can execute
            # arbitrary code, so only load .npz files from trusted sources.
            # The NpzFile is used as a context manager so its underlying file handle is closed.
            with np.load(path, allow_pickle=True) as data:
                return cls(
                    values=data["values"],
                    index=str(data["index"]),
                    max_order=int(data["max_order"]),
                    n_players=int(data["n_players"]),
                    min_order=int(data["min_order"]),
                    interaction_lookup=data["interaction_lookup"].item(),
                    estimated=bool(data["estimated"]),
                    estimation_budget=data["estimation_budget"].item(),
                    baseline_value=float(data["baseline_value"]),
                )
        msg = f"Path {path} does not end with .json or .npz. Cannot load InteractionValues."
        raise ValueError(msg)
1074

1075
    @classmethod
1✔
1076
    def from_dict(cls, data: dict[str, Any]) -> InteractionValues:
1✔
1077
        """Create an InteractionValues object from a dictionary.
1078

1079
        Args:
1080
            data: The dictionary containing the data to create the InteractionValues object from.
1081

1082
        Returns:
1083
            The InteractionValues object created from the dictionary.
1084

1085
        """
1086
        return cls(
1✔
1087
            values=data["values"],
1088
            index=data["index"],
1089
            max_order=data["max_order"],
1090
            n_players=data["n_players"],
1091
            min_order=data["min_order"],
1092
            interaction_lookup=data["interaction_lookup"],
1093
            estimated=data["estimated"],
1094
            estimation_budget=data["estimation_budget"],
1095
            baseline_value=data["baseline_value"],
1096
        )
1097

1098
    def to_dict(self) -> dict:
1✔
1099
        """Convert the InteractionValues object to a dictionary.
1100

1101
        Returns:
1102
            The InteractionValues object as a dictionary.
1103

1104
        """
1105
        return {
1✔
1106
            "values": self.values,
1107
            "index": self.index,
1108
            "max_order": self.max_order,
1109
            "n_players": self.n_players,
1110
            "min_order": self.min_order,
1111
            "interaction_lookup": self.interaction_lookup,
1112
            "estimated": self.estimated,
1113
            "estimation_budget": self.estimation_budget,
1114
            "baseline_value": self.baseline_value,
1115
        }
1116

1117
    def aggregate(
        self,
        others: Sequence[InteractionValues],
        aggregation: str = "mean",
    ) -> InteractionValues:
        """Combine this object with ``others`` into one aggregated InteractionValues object.

        This object is placed first, so the result takes its ``index``.

        Args:
            others: The additional InteractionValues objects to aggregate with this one.
            aggregation: The aggregation method. Defaults to ``"mean"``; the other options are
                ``"median"``, ``"sum"``, ``"max"`` and ``"min"``.

        Returns:
            The aggregated InteractionValues object.

        Note:
            The aggregation itself is performed by ``aggregate_interaction_values()``; see its
            documentation for details.

        """
        combined = [self]
        combined.extend(others)
        return aggregate_interaction_values(combined, aggregation)
1138

1139
    def plot_network(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues on a graph.

        Note:
            For arguments, see :func:`shapiq.plot.network.network_plot` and
                :func:`shapiq.plot.si_graph.si_graph_plot`.

        Args:
            show: Whether to show the plot. Defaults to ``True``.

            **kwargs: Additional keyword arguments to pass to the plotting function.

        Returns:
            If show is ``False``, the function returns a tuple with the figure and the axis of the
                plot.

        Raises:
            ValueError: If ``max_order`` is ``1`` or less, i.e. there are no second-order values
                to draw as edges.

        """
        from shapiq.plot.network import network_plot

        if self.max_order <= 1:
            # the trailing space after the comma is needed: the two literals are concatenated
            msg = (
                "InteractionValues contains only 1-order values, "
                "but requires also 2-order values for the network plot."
            )
            raise ValueError(msg)
        return network_plot(
            interaction_values=self,
            show=show,
            **kwargs,
        )
1168

1169
    def plot_si_graph(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Draw the interaction values as an SI graph.

        This is a thin wrapper around :func:`shapiq.plot.si_graph.si_graph_plot`; see that
        function for the supported keyword arguments.

        Args:
            show: Forwarded to ``si_graph_plot``. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded to ``si_graph_plot``.

        Returns:
            The SI graph as a tuple containing the figure and the axes.

        """
        from shapiq.plot.si_graph import si_graph_plot

        plot_kwargs = {"show": show, **kwargs}
        return si_graph_plot(self, **plot_kwargs)
1181

1182
    def plot_stacked_bar(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Draw the interaction values as a stacked bar plot.

        This is a thin wrapper around ``shapiq.stacked_bar_plot``; see that function for the
        supported keyword arguments.

        Args:
            show: Forwarded to ``stacked_bar_plot``. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded to ``stacked_bar_plot``.

        Returns:
            The stacked bar plot as a tuple containing the figure and the axes.

        """
        from shapiq import stacked_bar_plot

        plot_kwargs = {"show": show, **kwargs}
        return stacked_bar_plot(self, **plot_kwargs)
1194

1195
    def plot_force(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        contribution_threshold: float = 0.05,
    ) -> Figure | None:
        """Visualize InteractionValues on a force plot.

        For further details, see :func:`shapiq.plot.force_plot`.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided, the
                feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            contribution_threshold: The threshold for contributions to be displayed in percent.
                Defaults to ``0.05``.

        Returns:
            The force plot as a matplotlib figure (if ``show`` is ``False``).

        """
        from .plot import force_plot

        return force_plot(
            self,
            feature_names=feature_names,
            show=show,
            abbreviate=abbreviate,
            contribution_threshold=contribution_threshold,
        )
1228

1229
    def plot_waterfall(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        max_display: int = 10,
    ) -> Axes | None:
        """Draws interaction values on a waterfall plot.

        Note:
            Requires the ``shap`` Python package to be installed.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided, the
                feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            max_display: The maximum number of interactions to display. Defaults to ``10``.

        Returns:
            Whatever ``shapiq.waterfall_plot`` returns. It is annotated as the plot's ``Axes`` or
            ``None`` (presumably ``None`` when ``show`` is ``True`` — confirm against
            ``waterfall_plot``).

        """
        from shapiq import waterfall_plot

        return waterfall_plot(
            self,
            feature_names=feature_names,
            show=show,
            max_display=max_display,
            abbreviate=abbreviate,
        )
1258

1259
    def plot_sentence(
        self,
        words: list[str],
        *,
        show: bool = True,
        **kwargs: Any,
    ) -> tuple[Figure, Axes] | None:
        """Plot the first-order attributions of a sentence or paragraph.

        This is a thin wrapper around :func:`shapiq.plot.sentence.sentence_plot`; see that function
        for the supported keyword arguments.

        Args:
            words: The words (tokens) of the sentence, forwarded to ``sentence_plot``.
            show: Forwarded to ``sentence_plot``. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded to ``sentence_plot``.

        Returns:
            If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple
            with the figure and the axis of the plot.

        """
        from shapiq.plot.sentence import sentence_plot

        plot_kwargs = {"show": show, **kwargs}
        return sentence_plot(self, words, **plot_kwargs)
1278

1279
    def plot_upset(self, *, show: bool = True, **kwargs: Any) -> Figure | None:
        """Draw the interaction values as an upset plot.

        This is a thin wrapper around :func:`shapiq.plot.upset.upset_plot`; see that function for
        the supported keyword arguments.

        Args:
            show: Forwarded to ``upset_plot``. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded to ``upset_plot``.

        Returns:
            The upset plot as a matplotlib figure (if show is ``False``).

        """
        from shapiq.plot.upset import upset_plot

        plot_kwargs = {"show": show, **kwargs}
        return upset_plot(self, **plot_kwargs)
1291

1292

1293
def aggregate_interaction_values(
    interaction_values: Sequence[InteractionValues],
    aggregation: str = "mean",
) -> InteractionValues:
    """Aggregates InteractionValues objects using a specific aggregation method.

    The result has:

    - The union of all interaction keys, in sorted order.
    - The widest ``max_order`` / ``n_players`` and the narrowest ``min_order`` of the inputs.
    - The aggregated baseline value.
    - ``estimated=True``.
    - The first object's ``estimation_budget``.

    Args:
        interaction_values: The InteractionValues objects to aggregate. Any iterable is
            accepted; it is materialized into a list once.
        aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
            ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

    Returns:
        The aggregated InteractionValues object. All aggregated numbers are Python/NumPy
        ``float`` regardless of the aggregation method.

    Example:
        >>> iv1 = InteractionValues(
        ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=0.0,
        ... )
        >>> iv2 = InteractionValues(
        ...     values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]),  # this iv is missing the (1, 2) value
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4},  # no (1, 2)
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=1.0,
        ... )
        >>> aggregated = aggregate_interaction_values([iv1, iv2], "mean")
        >>> aggregated.baseline_value
        0.5
        >>> round(aggregated[(0, 1)], 2)  # mean of 0.4 and 0.5
        0.45
        >>> round(aggregated[(1, 2)], 2)  # mean of 0.6 and iv2[(1, 2)] for the missing key
        0.3

    Note:
        The index of the aggregated InteractionValues object is set to the index of the first
        InteractionValues object in the list. Every key is looked up in every object via
        ``iv[key]``, so an object lacking a key contributes whatever ``__getitem__`` returns for
        a missing key (``0.0`` in the example above).

    Raises:
        ValueError: If the aggregation method is not supported or no InteractionValues objects
            are given.

    """
    aggregators = {
        "mean": np.mean,
        "median": np.median,
        "sum": np.sum,
        "max": np.max,
        "min": np.min,
    }
    aggregate_fn = aggregators.get(aggregation)
    if aggregate_fn is None:
        msg = f"Aggregation method {aggregation} is not supported."
        raise ValueError(msg)

    # materialize once: the inputs are iterated several times and indexed below
    interaction_values = list(interaction_values)
    if not interaction_values:
        msg = "Cannot aggregate an empty sequence of InteractionValues."
        raise ValueError(msg)

    def _aggregate(vals: list[float]) -> float:
        """Apply the selected aggregation and return a plain float for every method."""
        return float(aggregate_fn(vals))

    # union of all interaction keys from all InteractionValues objects, in sorted order
    all_keys = sorted({key for iv in interaction_values for key in iv.interaction_lookup})
    new_lookup = {key: i for i, key in enumerate(all_keys)}
    new_values = np.array(
        [_aggregate([iv[key] for iv in interaction_values]) for key in all_keys],
        dtype=float,
    )

    max_order = max(iv.max_order for iv in interaction_values)
    min_order = min(iv.min_order for iv in interaction_values)
    n_players = max(iv.n_players for iv in interaction_values)
    baseline_value = _aggregate([iv.baseline_value for iv in interaction_values])
    estimation_budget = interaction_values[0].estimation_budget

    return InteractionValues(
        values=new_values,
        index=interaction_values[0].index,
        max_order=max_order,
        n_players=n_players,
        min_order=min_order,
        interaction_lookup=new_lookup,
        estimated=True,
        estimation_budget=estimation_budget,
        baseline_value=baseline_value,
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc