• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 16345920221

17 Jul 2025 01:05PM UTC coverage: 77.589% (-13.5%) from 91.075%
16345920221

push

github

web-flow
🔨 Refactors library into a src structure. (#415)

* moves shapiq into a src folder

* moves shapiq tests into tests_shapiq subfolder in tests

* refactors tests to work properly

* removes pickle support and closes #413

* changes unit tests to only run the unit tests

* adds workflow for running shapiq_games

* updates coverage to only run for shapiq

* update workflow to check for shapiq_games import

* update CHANGELOG.md

* fixes install-import.yml

* fixes version in docs

* moved deprecated tests out of the main test suite

* moves fixtures into the correct test suite

* installs libomp on macos runner (try bugfix)

* correct spelling

* removes libomp again

* moves os runs into individual workflows for easier debugging

* runs macOS on py3.13

* renames workflows

* installs libomp again on macOS

* downgrades to Python 3.11 and reinstalls Python

* try different uv version

* adds libomp

* changes skip to xfail in integration tests with wrong index/order combinations

* moves test out for debugging CI

* removes outdated test

* adds concurrency for quicker testing

* re-adds randomly

* don't reset seed

* removed pytest-randomly again

* adds the tests back in

3 of 21 new or added lines in 19 files covered. (14.29%)

5536 of 7135 relevant lines covered (77.59%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.76
/src/shapiq/interaction_values.py
1
"""InteractionValues data-class, which is used to store the interaction scores."""
2

3
from __future__ import annotations
1✔
4

5
import contextlib
1✔
6
import copy
1✔
7
import json
1✔
8
import pickle
1✔
9
from dataclasses import dataclass
1✔
10
from pathlib import Path
1✔
11
from typing import TYPE_CHECKING
1✔
12
from warnings import warn
1✔
13

14
import numpy as np
1✔
15

16
from .game_theory.indices import (
1✔
17
    ALL_AVAILABLE_INDICES,
18
    index_generalizes_bv,
19
    index_generalizes_sv,
20
    is_empty_value_the_baseline,
21
    is_index_aggregated,
22
)
23
from .utils.errors import raise_deprecation_warning
1✔
24
from .utils.sets import generate_interaction_lookup
1✔
25

26
if TYPE_CHECKING:
1✔
27
    from collections.abc import Sequence
×
28
    from typing import Any
×
29

30
    from matplotlib.axes import Axes
×
31
    from matplotlib.figure import Figure
×
32

33
    from shapiq.typing import JSONType
×
34

35

36
# Warning text for the deprecated non-JSON save options (``as_pickle`` / ``as_npz``).
SAVE_JSON_DEPRECATION_MSG = "".join(
    (
        "Saving InteractionValues not as a JSON file is deprecated. ",
        "The parameters `as_pickle` and `as_npz` will be removed in the future. ",
    )
)
40

41

42
@dataclass
1✔
43
class InteractionValues:
1✔
44
    """This class contains the interaction values as estimated by an approximator.
45

46
    Attributes:
47
        values: The interaction values of the model in vectorized form.
48
        index: The interaction index estimated. All available indices are defined in
49
            ``ALL_AVAILABLE_INDICES``.
50
        max_order: The order of the approximation.
51
        n_players: The number of players.
52
        min_order: The minimum order of the approximation. Defaults to ``0``.
53
        interaction_lookup: A dictionary that maps interactions to their index in the values
54
            vector. If ``interaction_lookup`` is not provided, it is computed from the ``n_players``,
55
            ``min_order``, and `max_order` parameters. Defaults to ``None``.
56
        estimated: Whether the interaction values are estimated or not. Defaults to ``True``.
57
        estimation_budget: The budget used for the estimation. Defaults to ``None``.
58
        baseline_value: The value of the baseline interaction also known as 'empty prediction' or
59
            ``'empty value'`` since it denotes the value of the empty coalition (empty set). If not
60
            provided it is searched for in the values vector (raising an Error if not found).
61
            Defaults to ``None``.
62

63
    Raises:
64
        UserWarning: If the index is not a valid index as defined in ``ALL_AVAILABLE_INDICES``.
65
        TypeError: If the baseline value is not a number.
66

67
    """
68

69
    values: np.ndarray
1✔
70
    index: str
1✔
71
    max_order: int
1✔
72
    n_players: int
1✔
73
    min_order: int
1✔
74
    baseline_value: float
1✔
75
    interaction_lookup: dict[tuple[int, ...], int] = None  # type: ignore[assignment]
1✔
76
    estimated: bool = True
1✔
77
    estimation_budget: int = None  # type: ignore[assignment]
1✔
78

79
    def __post_init__(self) -> None:
        """Validate and normalize the fields after dataclass initialization.

        The following steps run in order:

        1. Warn if ``index`` is not contained in ``ALL_AVAILABLE_INDICES``.
        2. If ``max_order == 1``, rename the index to ``"BV"`` when ``index_generalizes_bv`` is
           true for it and to ``"SV"`` when ``index_generalizes_sv`` is true for it.
        3. Build ``interaction_lookup`` from ``n_players``, ``min_order`` and ``max_order`` if it
           was not provided.
        4. Check that ``baseline_value`` is an ``int`` or ``float``.
        5. If ``min_order == 0`` but the empty interaction ``()`` is missing from the lookup,
           register ``()`` at the end of the lookup and append ``baseline_value`` to ``values``.

        Raises:
            TypeError: If ``baseline_value`` is not an ``int`` or ``float``.
        """
        if self.index not in ALL_AVAILABLE_INDICES:
            warn(
                UserWarning(
                    f"Index {self.index} is not a valid index as defined in "
                    f"{ALL_AVAILABLE_INDICES}. This might lead to unexpected behavior.",
                ),
                stacklevel=2,
            )

        # set BV or SV if max_order is 1
        if self.max_order == 1:
            if index_generalizes_bv(self.index):
                self.index = "BV"
            if index_generalizes_sv(self.index):
                self.index = "SV"

        # populate interaction_lookup if it was not provided
        if self.interaction_lookup is None:
            self.interaction_lookup = generate_interaction_lookup(
                self.n_players,
                self.min_order,
                self.max_order,
            )

        if not isinstance(self.baseline_value, int | float):
            msg = f"Baseline value must be provided as a number. Got {self.baseline_value}."
            raise TypeError(msg)

        # check if () is in the interaction_lookup if min_order is 0 -> add it to the end
        if self.min_order == 0 and () not in self.interaction_lookup:
            # Build a new dict rather than inserting in place: the lookup may belong to the
            # caller (or be shared with another InteractionValues object), and an in-place
            # insertion would silently add ``()`` there as well.
            self.interaction_lookup = {
                **self.interaction_lookup,
                (): len(self.interaction_lookup),
            }
            self.values = np.concatenate((self.values, np.array([self.baseline_value])))
1✔
113

114
    @property
1✔
115
    def dict_values(self) -> dict[tuple[int, ...], float]:
1✔
116
        """Getter for the dict directly mapping from all interactions to scores."""
117
        return {
1✔
118
            interaction: float(self.values[self.interaction_lookup[interaction]])
119
            for interaction in self.interaction_lookup
120
        }
121

122
    def to_json_file(
        self,
        path: Path,
        *,
        desc: str | None = None,
        created_from: object | None = None,
        **kwargs: JSONType,
    ) -> None:
        """Write the InteractionValues object to ``path`` as a JSON document.

        The document consists of the file metadata produced by ``make_file_metadata``, a
        ``"metadata"`` section with the object's scalar fields and a ``"data"`` section with the
        interaction scores (converted by ``interactions_to_dict``).

        Args:
            path: The path to the JSON file.
            desc: A description of the InteractionValues object. Defaults to ``None``.
            created_from: An object from which the InteractionValues object was created. Defaults
                to ``None``.
            **kwargs: Additional parameters to store in the metadata of the JSON file.
        """
        from shapiq.utils.saving import interactions_to_dict, make_file_metadata, save_json

        header = make_file_metadata(
            object_to_store=self,
            data_type="interaction_values",
            desc=desc,
            created_from=created_from,
            parameters=kwargs,
        )
        scalar_fields = {
            "n_players": self.n_players,
            "index": self.index,
            "max_order": self.max_order,
            "min_order": self.min_order,
            "estimated": self.estimated,
            "estimation_budget": self.estimation_budget,
            "baseline_value": self.baseline_value,
        }
        document = {
            **header,
            "metadata": scalar_fields,
            "data": interactions_to_dict(interactions=self.dict_values),
        }
        save_json(document, path)
1✔
162

163
    @classmethod
1✔
164
    def from_json_file(cls, path: Path) -> InteractionValues:
1✔
165
        """Loads an InteractionValues object from a JSON file.
166

167
        Args:
168
            path: The path to the JSON file. Note that the path must end with `'.json'`.
169

170
        Returns:
171
            The InteractionValues object loaded from the JSON file.
172

173
        Raises:
174
            ValueError: If the path does not end with `'.json'`.
175
        """
176
        from shapiq.utils.saving import dict_to_lookup_and_values
1✔
177

178
        if not path.name.endswith(".json"):
1✔
179
            msg = f"Path {path} does not end with .json. Cannot load InteractionValues."
×
180
            raise ValueError(msg)
×
181

182
        with path.open("r", encoding="utf-8") as file:
1✔
183
            json_data = json.load(file)
1✔
184

185
        metadata = json_data["metadata"]
1✔
186
        interaction_dict = json_data["data"]
1✔
187
        interaction_lookup, values = dict_to_lookup_and_values(interaction_dict)
1✔
188

189
        return cls(
1✔
190
            values=values,
191
            index=metadata["index"],
192
            max_order=metadata["max_order"],
193
            n_players=metadata["n_players"],
194
            min_order=metadata["min_order"],
195
            interaction_lookup=interaction_lookup,
196
            estimated=metadata["estimated"],
197
            estimation_budget=metadata["estimation_budget"],
198
            baseline_value=metadata["baseline_value"],
199
        )
200

201
    def sparsify(self, threshold: float = 1e-3) -> None:
1✔
202
        """Manually sets values close to zero actually to zero (removing values).
203

204
        Args:
205
            threshold: The threshold value below which interactions are zeroed out. Defaults to
206
                1e-3.
207

208
        """
209
        # find interactions to remove in self.values
210
        interactions_to_remove: set[int] = set(np.where(np.abs(self.values) < threshold)[0])
1✔
211
        new_values = np.delete(self.values, list(interactions_to_remove))
1✔
212
        new_interaction_lookup = {}
1✔
213
        for index, _interaction in enumerate(self.interaction_lookup):
1✔
214
            if index not in interactions_to_remove:
1✔
215
                interaction = tuple(sorted(_interaction))
1✔
216
                new_interaction_lookup[interaction] = len(new_interaction_lookup)
1✔
217
        self.values = new_values
1✔
218
        self.interaction_lookup = new_interaction_lookup
1✔
219

220
    def get_top_k_interactions(self, k: int) -> InteractionValues:
        """Returns the top k interactions.

        Interactions are ranked by the absolute value of their score (via ``np.argsort``). The
        retained interactions keep the iteration order of ``interaction_lookup`` and are
        renumbered consecutively; index, orders, number of players, estimation info and the
        baseline value are copied from ``self``.

        Args:
            k: The number of top interactions to return.

        Returns:
            The top k interactions as an InteractionValues object. Its values vector holds exactly
            one entry per retained interaction, also when ``k`` exceeds the number of
            interactions.
        """
        # A set of positions gives O(1) membership tests instead of scanning an array per entry.
        top_k_positions = set(np.argsort(np.abs(self.values))[::-1][:k].tolist())
        new_values: list[float] = []
        new_interaction_lookup: dict[tuple[int, ...], int] = {}
        # Use the positions stored in the lookup rather than its enumeration index, which only
        # coincide when the lookup was built in the same order as the values vector.
        for interaction, position in self.interaction_lookup.items():
            if position in top_k_positions:
                new_interaction_lookup[interaction] = len(new_values)
                new_values.append(float(self.values[position]))
        return InteractionValues(
            values=np.asarray(new_values, dtype=float),
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
249

250
    def get_top_k(
1✔
251
        self, k: int, *, as_interaction_values: bool = True
252
    ) -> InteractionValues | tuple[dict, list[tuple]]:
253
        """Returns the top k interactions.
254

255
        Args:
256
            k: The number of top interactions to return.
257
            as_interaction_values: Whether to return the top `k` interactions as an InteractionValues
258
                object. Defaults to ``False``.
259

260
        Returns:
261
            The top k interactions as a dictionary and a sorted list of tuples.
262

263
        Examples:
264
            >>> interaction_values = InteractionValues(
265
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
266
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
267
            ...     index="SII",
268
            ...     max_order=2,
269
            ...     n_players=3,
270
            ...     min_order=1,
271
            ...     baseline_value=0.0,
272
            ... )
273
            >>> top_k_interactions, sorted_top_k_interactions = interaction_values.get_top_k(2, False)
274
            >>> top_k_interactions
275
            {(0, 2): 0.5, (1, 0): 0.6}
276
            >>> sorted_top_k_interactions
277
            [((1, 0), 0.6), ((0, 2), 0.5)]
278

279
        """
280
        if as_interaction_values:
1✔
281
            return self.get_top_k_interactions(k)
1✔
282
        top_k_indices = np.argsort(np.abs(self.values))[::-1][:k]
1✔
283
        top_k_interactions = {}
1✔
284
        for interaction, index in self.interaction_lookup.items():
1✔
285
            if index in top_k_indices:
1✔
286
                top_k_interactions[interaction] = self.values[index]
1✔
287
        sorted_top_k_interactions = [
1✔
288
            (interaction, top_k_interactions[interaction])
289
            for interaction in sorted(top_k_interactions, key=top_k_interactions.get, reverse=True)
290
        ]
291
        return top_k_interactions, sorted_top_k_interactions
1✔
292

293
    def __repr__(self) -> str:
1✔
294
        """Returns the representation of the InteractionValues object."""
295
        representation = "InteractionValues(\n"
1✔
296
        representation += (
1✔
297
            f"    index={self.index}, max_order={self.max_order}, min_order={self.min_order}"
298
            f", estimated={self.estimated}, estimation_budget={self.estimation_budget},\n"
299
            f"    n_players={self.n_players}, baseline_value={self.baseline_value}\n)"
300
        )
301
        return representation
1✔
302

303
    def __str__(self) -> str:
1✔
304
        """Returns the string representation of the InteractionValues object."""
305
        representation = self.__repr__()
1✔
306
        representation = representation[:-2]  # remove the last "\n)" and add values
1✔
307
        _, sorted_top_10_interactions = self.get_top_k(
1✔
308
            10, as_interaction_values=False
309
        )  # get top 10 interactions
310
        # add values to string representation
311
        representation += ",\n    Top 10 interactions:\n"
1✔
312
        for interaction, value in sorted_top_10_interactions:
1✔
313
            representation += f"        {interaction}: {value}\n"
1✔
314
        representation += ")"
1✔
315
        return representation
1✔
316

317
    def __len__(self) -> int:
1✔
318
        """Returns the length of the InteractionValues object."""
319
        return len(self.values)  # might better to return the theoretical no. of interactions
1✔
320

321
    def __iter__(self) -> np.nditer:
1✔
322
        """Returns an iterator over the values of the InteractionValues object."""
323
        return np.nditer(self.values)
1✔
324

325
    def __getitem__(self, item: int | tuple[int, ...]) -> float:
1✔
326
        """Returns the score for the given interaction.
327

328
        Args:
329
            item: The interaction as a tuple of integers for which to return the score. If ``item`` is
330
                an integer it serves as the index to the values vector.
331

332
        Returns:
333
            The interaction value. If the interaction is not present zero is returned.
334

335
        """
336
        if isinstance(item, int):
1✔
337
            return float(self.values[item])
1✔
338
        item = tuple(sorted(item))
1✔
339
        try:
1✔
340
            return float(self.values[self.interaction_lookup[item]])
1✔
341
        except KeyError:
1✔
342
            return 0.0
1✔
343

344
    def __setitem__(self, item: int | tuple[int, ...], value: float) -> None:
1✔
345
        """Sets the score for the given interaction.
346

347
        Args:
348
            item: The interaction as a tuple of integers for which to set the score. If ``item`` is an
349
                integer it serves as the index to the values vector.
350
            value: The value to set for the interaction.
351

352
        Raises:
353
            KeyError: If the interaction is not found in the InteractionValues object.
354

355
        """
356
        try:
1✔
357
            if isinstance(item, int):
1✔
358
                self.values[item] = value
1✔
359
            else:
360
                item = tuple(sorted(item))
1✔
361
                self.values[self.interaction_lookup[item]] = value
1✔
362
        except Exception as e:
1✔
363
            msg = f"Interaction {item} not found in the InteractionValues. Unable to set a value."
1✔
364
            raise KeyError(msg) from e
1✔
365

366
    def __eq__(self, other: object) -> bool:
1✔
367
        """Checks if two InteractionValues objects are equal.
368

369
        Args:
370
            other: The other InteractionValues object.
371

372
        Returns:
373
            True if the two objects are equal, False otherwise.
374

375
        """
376
        if not isinstance(other, InteractionValues):
1✔
377
            msg = "Cannot compare InteractionValues with other types."
1✔
378
            raise TypeError(msg)
1✔
379
        if (
1✔
380
            self.index != other.index
381
            or self.max_order != other.max_order
382
            or self.min_order != other.min_order
383
            or self.n_players != other.n_players
384
            or self.baseline_value != other.baseline_value
385
        ):
386
            return False
1✔
387
        if not np.allclose(self.values, other.values):
1✔
388
            return False
1✔
389
        return self.interaction_lookup == other.interaction_lookup
1✔
390

391
    def __ne__(self, other: object) -> bool:
1✔
392
        """Checks if two InteractionValues objects are not equal.
393

394
        Args:
395
            other: The other InteractionValues object.
396

397
        Returns:
398
            True if the two objects are not equal, False otherwise.
399

400
        """
401
        return not self.__eq__(other)
1✔
402

403
    def __hash__(self) -> int:
1✔
404
        """Returns the hash of the InteractionValues object."""
405
        return hash(
1✔
406
            (
407
                self.index,
408
                self.max_order,
409
                self.min_order,
410
                self.n_players,
411
                tuple(self.values.flatten()),
412
            ),
413
        )
414

415
    def __copy__(self) -> InteractionValues:
        """Return a new InteractionValues object with deep copies of the mutable fields.

        ``values`` and ``interaction_lookup`` are deep-copied; all other fields are passed to the
        new object as they are.
        """
        values_copy = copy.deepcopy(self.values)
        lookup_copy = copy.deepcopy(self.interaction_lookup)
        return InteractionValues(
            values=values_copy,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=lookup_copy,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
428

429
    def __add__(self, other: InteractionValues | float) -> InteractionValues:
        """Adds two InteractionValues objects together or a scalar.

        Three cases are handled:

        * ``other`` is an InteractionValues object with the same index, interaction lookup,
          number of players and order range: the values are added element-wise.
        * ``other`` is an InteractionValues object with the same index but a different lookup,
          number of players or order range: scores of shared interactions are summed,
          interactions only present in ``other`` are appended after those of ``self``, and
          ``n_players``/``max_order`` become the maximum and ``min_order`` the minimum of both.
        * ``other`` is an ``int`` or ``float``: it is added to every value.

        The baseline values are added in every case.

        Args:
            other: The InteractionValues object or scalar to add.

        Returns:
            A new InteractionValues object holding the sum. ``estimated`` and
            ``estimation_budget`` are taken from ``self``.

        Raises:
            ValueError: If ``other`` is an InteractionValues object with a different index.
            TypeError: If ``other`` is neither an InteractionValues object nor an ``int``/``float``.
        """
        n_players, min_order, max_order = self.n_players, self.min_order, self.max_order
        if isinstance(other, InteractionValues):
            if self.index != other.index:  # different indices
                msg = (
                    f"Cannot add InteractionValues with different indices {self.index} and "
                    f"{other.index}."
                )
                raise ValueError(msg)
            if (
                self.interaction_lookup != other.interaction_lookup
                or self.n_players != other.n_players
                or self.min_order != other.min_order
                or self.max_order != other.max_order
            ):  # different interactions but addable
                interaction_lookup = {**self.interaction_lookup}
                position = len(self.interaction_lookup)
                values_to_add = []
                # Accumulate in a floating dtype: ``added_values[i] += x`` on an integer (or
                # boolean) array would silently truncate fractional scores coming from ``other``.
                # The concatenation below yields a floating result in any case.
                added_values = np.array(self.values, copy=True)
                if not np.issubdtype(added_values.dtype, np.inexact):
                    added_values = added_values.astype(float)
                for interaction in other.interaction_lookup:
                    if interaction not in interaction_lookup:
                        interaction_lookup[interaction] = position
                        position += 1
                        values_to_add.append(other[interaction])
                    else:
                        added_values[interaction_lookup[interaction]] += other[interaction]
                added_values = np.concatenate((added_values, np.asarray(values_to_add)))
                # adjust n_players, min_order, and max_order
                n_players = max(self.n_players, other.n_players)
                min_order = min(self.min_order, other.min_order)
                max_order = max(self.max_order, other.max_order)
                baseline_value = self.baseline_value + other.baseline_value
            else:  # basic case with same interactions
                added_values = self.values + other.values
                # copy so that the result does not share its lookup dict with ``self``
                interaction_lookup = self.interaction_lookup.copy()
                baseline_value = self.baseline_value + other.baseline_value
        elif isinstance(other, int | float):
            added_values = self.values.copy() + other
            interaction_lookup = self.interaction_lookup.copy()
            baseline_value = self.baseline_value + other
        else:
            msg = f"Cannot add InteractionValues with object of type {type(other)}."
            raise TypeError(msg)
        return InteractionValues(
            values=added_values,
            index=self.index,
            max_order=max_order,
            n_players=n_players,
            min_order=min_order,
            interaction_lookup=interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=baseline_value,
        )
        )
484

485
    def __radd__(self, other: InteractionValues | float) -> InteractionValues:
        """Handle ``other + self`` (e.g. a number on the left) by delegating to ``__add__``."""
        total = self.__add__(other)
        return total
1✔
488

489
    def __neg__(self) -> InteractionValues:
        """Return a new object with the values and the baseline value negated."""
        negated_values = -self.values
        negated_baseline = -self.baseline_value
        return InteractionValues(
            values=negated_values,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=negated_baseline,
        )
502

503
    def __sub__(self, other: InteractionValues | float) -> InteractionValues:
        """Compute ``self - other`` as ``self + (-other)``."""
        negated_other = -other
        return self.__add__(negated_other)
1✔
506

507
    def __rsub__(self, other: InteractionValues | float) -> InteractionValues:
        """Compute ``other - self`` as ``(-self) + other``."""
        negated_self = -self
        return negated_self.__add__(other)
1✔
510

511
    def __mul__(self, other: float) -> InteractionValues:
        """Return a new object with the values and the baseline value scaled by ``other``."""
        scaled_values = self.values * other
        scaled_baseline = self.baseline_value * other
        return InteractionValues(
            values=scaled_values,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=scaled_baseline,
        )
524

525
    def __rmul__(self, other: float) -> InteractionValues:
        """Handle ``other * self`` by delegating to ``__mul__`` (scalar multiplication commutes)."""
        product = self.__mul__(other)
        return product
1✔
528

529
    def __abs__(self) -> InteractionValues:
        """Returns the absolute values of the InteractionValues object.

        Both the values vector and ``baseline_value`` are replaced by their absolute values, so
        the baseline value is transformed the same way as the values vector, as ``__neg__`` and
        ``__mul__`` already do.

        Returns:
            A new InteractionValues object; all other fields are taken from ``self``.
        """
        return InteractionValues(
            values=np.abs(self.values),
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=abs(self.baseline_value),
        )
542

543
    def get_n_order_values(self, order: int) -> np.ndarray:
1✔
544
        """Returns the interaction values of a specific order as a numpy array.
545

546
        Note:
547
            Depending on the order and number of players the resulting array might be sparse and
548
            very large.
549

550
        Args:
551
            order: The order of the interactions to return.
552

553
        Returns:
554
            The interaction values of the specified order as a numpy array of shape ``(n_players,)``
555
            for order ``1`` and ``(n_players, n_players)`` for order ``2``, etc.
556

557
        Raises:
558
            ValueError: If the order is less than ``1``.
559

560
        """
561
        from itertools import permutations
1✔
562

563
        if order < 1:
1✔
564
            msg = "Order must be greater or equal to 1."
1✔
565
            raise ValueError(msg)
1✔
566
        values_shape = tuple([self.n_players] * order)
1✔
567
        values = np.zeros(values_shape, dtype=float)
1✔
568
        for interaction in self.interaction_lookup:
1✔
569
            if len(interaction) != order:
1✔
570
                continue
1✔
571
            # get all orderings of the interaction (e.g. (0, 1) and (1, 0) for interaction (0, 1))
572
            for perm in permutations(interaction):
1✔
573
                values[perm] = self[interaction]
1✔
574

575
        return values
1✔
576

577
    def get_n_order(
        self,
        order: int | None = None,
        min_order: int | None = None,
        max_order: int | None = None,
    ) -> InteractionValues:
        """Select particular order of interactions.

        Creates a new InteractionValues object containing only the interactions whose size lies
        within the requested order range (bounds inclusive).

        You can specify:
            - `order`: to select interactions of a single specific order (e.g., all pairwise
                interactions).
            - `min_order` and/or `max_order`: to select a range of interaction orders.
            - If `order` and `min_order`/`max_order` are both set, `min_order` and `max_order` will
                override the `order` value.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([1, 2, 3, 4, 5, 6, 7]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5, (0, 1, 2): 6},
            ...     index="SII",
            ...     max_order=3,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_n_order(order=1).dict_values
            {(0,): 1.0, (1,): 2.0, (2,): 3.0}
            >>> interaction_values.get_n_order(min_order=1, max_order=2).dict_values
            {(0,): 1.0, (1,): 2.0, (2,): 3.0, (0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0}
            >>> interaction_values.get_n_order(min_order=2).dict_values
            {(0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0, (0, 1, 2): 7.0}

        Args:
            order: The order of the interactions to return. Defaults to ``None`` which requires
                ``min_order`` or ``max_order`` to be set.
            min_order: The minimum order of the interactions to return. Defaults to ``None`` which
                sets it to ``order`` (or to ``self.min_order`` if ``order`` is ``None``).
            max_order: The maximum order of the interactions to return. Defaults to ``None`` which
                sets it to ``order`` (or to ``self.max_order`` if ``order`` is ``None``).

        Returns:
            The interaction values of the specified order range. The new object's ``min_order``
            and ``max_order`` are set to the resolved bounds.

        Raises:
            ValueError: If all three parameters are set to ``None`` or if the resolved minimum
                order exceeds the resolved maximum order.
        """
        if order is None and min_order is None and max_order is None:
            msg = "Either order, min_order or max_order must be set."
            raise ValueError(msg)

        # Explicit bounds win; otherwise fall back to ``order`` or, without it, to own bounds.
        fallback_low = self.min_order if order is None else order
        fallback_high = self.max_order if order is None else order
        lower = fallback_low if min_order is None else min_order
        upper = fallback_high if max_order is None else max_order

        if lower > upper:
            msg = f"min_order ({lower}) must be less than or equal to max_order ({upper})."
            raise ValueError(msg)

        selected = [
            interaction
            for interaction in self.interaction_lookup
            if lower <= len(interaction) <= upper
        ]
        return InteractionValues(
            values=np.array([self[interaction] for interaction in selected]),
            index=self.index,
            max_order=upper,
            n_players=self.n_players,
            min_order=lower,
            interaction_lookup={interaction: pos for pos, interaction in enumerate(selected)},
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
661

662
    def get_subset(self, players: list[int]) -> InteractionValues:
        """Selects a subset of players from the InteractionValues object.

        An interaction is kept if every one of its players is contained in ``players`` (the
        empty interaction ``()`` therefore always qualifies). Player labels are not renumbered,
        and the kept interactions retain the iteration order of ``interaction_lookup``.

        Args:
            players: List of players to select from the InteractionValues object.

        Returns:
            InteractionValues: Filtered InteractionValues object containing only values related to
            selected players.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
            ...     index="SII",
            ...     max_order=2,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_subset([0, 1]).dict_values
            {(0,): 0.1, (1,): 0.2, (0, 1): 0.4}
            >>> interaction_values.get_subset([0, 2]).dict_values
            {(0,): 0.1, (2,): 0.3, (0, 2): 0.5}
            >>> interaction_values.get_subset([1]).dict_values
            {(1,): 0.2}
        """
        selected_players = set(players)  # O(1) membership per player
        kept_positions: list[int] = []
        new_interaction_lookup: dict[tuple[int, ...], int] = {}
        # Read scores through the positions stored in the lookup instead of its enumeration
        # index; the two only coincide when the lookup was built in values-vector order.
        for interaction, position in self.interaction_lookup.items():
            if all(player in selected_players for player in interaction):
                new_interaction_lookup[interaction] = len(kept_positions)
                kept_positions.append(position)
        new_values = self.values[np.asarray(kept_positions, dtype=int)]
        # NOTE(review): this is the number of players that were *not* selected, while the kept
        # interactions keep their original player labels; confirm whether ``len(players)`` or
        # ``self.n_players`` is the intended value.
        n_players = self.n_players - len(players)
        return InteractionValues(
            values=new_values,
            index=self.index,
            max_order=self.max_order,
            n_players=n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
710

711
    def save(self, path: Path, *, as_pickle: bool = False, as_npz: bool = False) -> None:
        """Write the InteractionValues object to ``path``.

        JSON (through ``to_json_file``) is the default format. The pickle and ``npz`` formats are
        deprecated: selecting either one calls ``raise_deprecation_warning`` before writing. When
        both flags are set, ``as_pickle`` takes precedence.

        Args:
            path: Destination file. Missing parent directories are created first.
            as_pickle: Write the object with :func:`pickle.dump` instead of JSON. Deprecated.
            as_npz: Write the object's fields with :func:`numpy.savez` instead of JSON.
                Deprecated.

        Note:
            ``load`` dispatches only on the ``.json`` and ``.npz`` file endings and has no pickle
            branch, so files written with ``as_pickle=True`` are not supported by ``load``.
        """
        target_dir = Path(path).parent
        if not target_dir.exists():
            with contextlib.suppress(FileNotFoundError):
                target_dir.mkdir(parents=True, exist_ok=True)

        # default path: JSON serialization
        if not (as_pickle or as_npz):
            self.to_json_file(path)
            return

        # both legacy formats share the same deprecation notice
        raise_deprecation_warning(
            message=SAVE_JSON_DEPRECATION_MSG, deprecated_in="1.3.1", removed_in="1.4.0"
        )
        if as_pickle:
            with Path(path).open("wb") as pickle_file:
                pickle.dump(self, pickle_file)
            return

        npz_fields = {
            "values": self.values,
            "index": self.index,
            "max_order": self.max_order,
            "n_players": self.n_players,
            "min_order": self.min_order,
            "interaction_lookup": self.interaction_lookup,
            "estimated": self.estimated,
            "estimation_budget": self.estimation_budget,
            "baseline_value": self.baseline_value,
        }
        np.savez(path, **npz_fields)
755

756
    @classmethod
    def load(cls, path: Path | str) -> InteractionValues:
        """Load an InteractionValues object from a file.

        Files ending in ``.json`` are read with ``from_json_file``. For every other path,
        ``raise_deprecation_warning`` is called first. Files ending in ``.npz`` are then still
        read (deprecated format).

        Args:
            path: The path to load the InteractionValues object from.

        Returns:
            The loaded InteractionValues object, as an instance of ``cls``.

        Raises:
            ValueError: If ``path`` ends neither with ``.json`` nor with ``.npz``.

        """
        path = Path(path)
        # check if path ends with .json
        if path.name.endswith(".json"):
            return cls.from_json_file(path)

        raise_deprecation_warning(
            SAVE_JSON_DEPRECATION_MSG, deprecated_in="1.3.1", removed_in="1.4.0"
        )

        # try loading as npz file
        if path.name.endswith(".npz"):
            # SECURITY: allow_pickle=True is needed because ``interaction_lookup`` (a dict) is
            # stored as an object array by ``save``. It also means that loading an untrusted
            # .npz file can execute arbitrary code, so only load files from trusted sources.
            # The context manager closes the underlying NpzFile handle after all fields are read.
            with np.load(path, allow_pickle=True) as data:
                return cls(
                    values=data["values"],
                    index=str(data["index"]),
                    max_order=int(data["max_order"]),
                    n_players=int(data["n_players"]),
                    min_order=int(data["min_order"]),
                    interaction_lookup=data["interaction_lookup"].item(),
                    estimated=bool(data["estimated"]),
                    estimation_budget=data["estimation_budget"].item(),
                    baseline_value=float(data["baseline_value"]),
                )
        msg = f"Path {path} does not end with .json or .npz. Cannot load InteractionValues."
        raise ValueError(msg)
792

793
    @classmethod
1✔
794
    def from_dict(cls, data: dict[str, Any]) -> InteractionValues:
1✔
795
        """Create an InteractionValues object from a dictionary.
796

797
        Args:
798
            data: The dictionary containing the data to create the InteractionValues object from.
799

800
        Returns:
801
            The InteractionValues object created from the dictionary.
802

803
        """
804
        return cls(
1✔
805
            values=data["values"],
806
            index=data["index"],
807
            max_order=data["max_order"],
808
            n_players=data["n_players"],
809
            min_order=data["min_order"],
810
            interaction_lookup=data["interaction_lookup"],
811
            estimated=data["estimated"],
812
            estimation_budget=data["estimation_budget"],
813
            baseline_value=data["baseline_value"],
814
        )
815

816
    def to_dict(self) -> dict:
1✔
817
        """Convert the InteractionValues object to a dictionary.
818

819
        Returns:
820
            The InteractionValues object as a dictionary.
821

822
        """
823
        return {
1✔
824
            "values": self.values,
825
            "index": self.index,
826
            "max_order": self.max_order,
827
            "n_players": self.n_players,
828
            "min_order": self.min_order,
829
            "interaction_lookup": self.interaction_lookup,
830
            "estimated": self.estimated,
831
            "estimation_budget": self.estimation_budget,
832
            "baseline_value": self.baseline_value,
833
        }
834

835
    def aggregate(
        self,
        others: Sequence[InteractionValues],
        aggregation: str = "mean",
    ) -> InteractionValues:
        """Combine this object with ``others`` into a single aggregated InteractionValues object.

        This object is placed first in the list handed to ``aggregate_interaction_values``, so
        properties taken from the first element (such as the index) come from ``self``.

        Args:
            others: Further InteractionValues objects to aggregate together with this one.
            aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
                ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

        Returns:
            The aggregated InteractionValues object.

        Note:
            For documentation on the aggregation methods, see the ``aggregate_interaction_values()``
            function.

        """
        to_combine = [self]
        to_combine.extend(others)
        return aggregate_interaction_values(to_combine, aggregation)
856

857
    def plot_network(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
1✔
858
        """Visualize InteractionValues on a graph.
859

860
        Note:
861
            For arguments, see :func:`shapiq.plot.network.network_plot` and
862
                :func:`shapiq.plot.si_graph.si_graph_plot`.
863

864
        Args:
865
            show: Whether to show the plot. Defaults to ``True``.
866

867
            **kwargs: Additional keyword arguments to pass to the plotting function.
868

869
        Returns:
870
            If show is ``False``, the function returns a tuple with the figure and the axis of the
871
                plot.
872
        """
873
        from shapiq.plot.network import network_plot
1✔
874

875
        if self.max_order > 1:
1✔
876
            return network_plot(
1✔
877
                interaction_values=self,
878
                show=show,
879
                **kwargs,
880
            )
881
        msg = (
1✔
882
            "InteractionValues contains only 1-order values,"
883
            "but requires also 2-order values for the network plot."
884
        )
885
        raise ValueError(msg)
1✔
886

887
    def plot_si_graph(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Draw the InteractionValues as an SI graph.

        Args:
            show: Forwarded to the plotting function. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded to
                :func:`shapiq.plot.si_graph.si_graph_plot`.

        Returns:
            The SI graph as a tuple containing the figure and the axes.

        """
        from shapiq.plot.si_graph import si_graph_plot as _draw_si_graph

        plot_options = {"show": show, **kwargs}
        return _draw_si_graph(self, **plot_options)
899

900
    def plot_stacked_bar(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Visualize InteractionValues as a stacked bar plot.

        Args:
            show: Forwarded to the plotting function. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded to ``shapiq.stacked_bar_plot``.

        Returns:
            The stacked bar plot as a tuple containing the figure and the axes.

        """
        from shapiq import stacked_bar_plot as _draw_stacked_bar

        plot_options = {"show": show, **kwargs}
        return _draw_stacked_bar(self, **plot_options)
912

913
    def plot_force(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        contribution_threshold: float = 0.05,
    ) -> Figure | None:
        """Visualize InteractionValues on a force plot.

        For further details, see ``shapiq.plot.force_plot``.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided, the
                feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            contribution_threshold: The threshold for contributions to be displayed in percent.
                Defaults to ``0.05``.

        Returns:
            The force plot as a matplotlib figure (if show is ``False``).

        """
        from .plot import force_plot

        return force_plot(
            self,
            feature_names=feature_names,
            show=show,
            abbreviate=abbreviate,
            contribution_threshold=contribution_threshold,
        )
946

947
    def plot_waterfall(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        max_display: int = 10,
    ) -> Axes | None:
        """Draws interaction values on a waterfall plot.

        Note:
            Requires the ``shap`` Python package to be installed.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided, the
                feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            max_display: The maximum number of interactions to display. Defaults to ``10``.

        Returns:
            Whatever ``shapiq.waterfall_plot`` returns; the annotated type is ``Axes | None``.
        """
        from shapiq import waterfall_plot

        return waterfall_plot(
            self,
            feature_names=feature_names,
            show=show,
            max_display=max_display,
            abbreviate=abbreviate,
        )
976

977
    def plot_sentence(
        self,
        words: list[str],
        *,
        show: bool = True,
        **kwargs: Any,
    ) -> tuple[Figure, Axes] | None:
        """Plot the first-order effects (attributions) of the words of a sentence or paragraph.

        Args:
            words: The words to display; passed positionally to the plotting function.
            show: Forwarded to the plotting function. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded to
                :func:`shapiq.plot.sentence.sentence_plot`.

        Returns:
            If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple
            with the figure and the axis of the plot.

        """
        from shapiq.plot.sentence import sentence_plot as _draw_sentence

        plot_options = {"show": show, **kwargs}
        return _draw_sentence(self, words, **plot_options)
996

997
    def plot_upset(self, *, show: bool = True, **kwargs: Any) -> Figure | None:
        """Draw the InteractionValues as an upset plot.

        Args:
            show: Forwarded to the plotting function. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded to
                :func:`shapiq.plot.upset.upset_plot`.

        Returns:
            The upset plot as a matplotlib figure (if show is ``False``).

        """
        from shapiq.plot.upset import upset_plot as _draw_upset

        plot_options = {"show": show, **kwargs}
        return _draw_upset(self, **plot_options)
1009

1010

1011
def aggregate_interaction_values(
    interaction_values: Sequence[InteractionValues],
    aggregation: str = "mean",
) -> InteractionValues:
    """Aggregates InteractionValues objects using a specific aggregation method.

    The result contains the union of the interactions of all inputs. Each interaction value and
    the baseline value are aggregated independently across the inputs.

    Args:
        interaction_values: The InteractionValues objects to aggregate. At least one object is
            required. Any iterable is accepted; it is materialized into a list once.
        aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
            ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

    Returns:
        The aggregated InteractionValues object. It is marked ``estimated=True``. It uses the
        largest ``max_order``, the smallest ``min_order``, and the largest ``n_players`` among the
        inputs, and it takes ``index`` and ``estimation_budget`` from the first input.

    Example:
        >>> iv1 = InteractionValues(
        ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=0.0,
        ... )
        >>> iv2 = InteractionValues(
        ...     values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]),  # this iv is missing the (1, 2) value
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4},  # no (1, 2)
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=1.0,
        ... )
        >>> aggregated = aggregate_interaction_values([iv1, iv2], "mean")
        >>> aggregated.baseline_value
        0.5
        >>> round(aggregated[(0, 1)], 2)
        0.45
        >>> round(aggregated[(0, 2)], 2)
        0.55
        >>> aggregated.estimated
        True

    Note:
        The index of the aggregated InteractionValues object is set to the index of the first
        InteractionValues object in the list.

    Raises:
        ValueError: If the aggregation method is not supported or if no InteractionValues objects
            are given.

    """
    aggregators = {
        "mean": np.mean,
        "median": np.median,
        "sum": np.sum,
        "max": np.max,
        "min": np.min,
    }
    # validate up front instead of failing only once the first value is aggregated
    if aggregation not in aggregators:
        msg = f"Aggregation method {aggregation} is not supported."
        raise ValueError(msg)
    aggregate_fn = aggregators[aggregation]

    # materialize once: the inputs are traversed several times below
    interaction_values = list(interaction_values)
    if not interaction_values:
        msg = "At least one InteractionValues object is required for aggregation."
        raise ValueError(msg)

    def _aggregate(vals: list[float]) -> float:
        """Aggregates ``vals`` with the selected method and returns a plain Python float."""
        return float(aggregate_fn(vals))

    # union of all interactions, in a deterministic (sorted) order
    all_keys = sorted({key for iv in interaction_values for key in iv.interaction_lookup})

    new_values = np.zeros(len(all_keys), dtype=float)
    new_lookup = {}
    for i, key in enumerate(all_keys):
        new_lookup[key] = i
        # NOTE(review): keys absent from an object are resolved by that object's __getitem__
        # (presumably returning 0.0) — confirm.
        new_values[i] = _aggregate([iv[key] for iv in interaction_values])

    return InteractionValues(
        values=new_values,
        index=interaction_values[0].index,
        max_order=max(iv.max_order for iv in interaction_values),
        n_players=max(iv.n_players for iv in interaction_values),
        min_order=min(iv.min_order for iv in interaction_values),
        interaction_lookup=new_lookup,
        estimated=True,
        estimation_budget=interaction_values[0].estimation_budget,
        baseline_value=_aggregate([iv.baseline_value for iv in interaction_values]),
    )
1112

1113

1114
def finalize_computed_interactions(
    interactions: InteractionValues,
    target_index: str | None = None,
) -> InteractionValues:
    """Post-processes freshly computed InteractionValues so they are interpretable.

    Two steps are applied:
        - Aggregation: if ``target_index`` is an aggregated index (per ``is_index_aggregated``)
            that differs from ``interactions.index``, the values are converted with
            ``aggregate_base_interaction`` (e.g. from SII to k-SII).
        - Empty value vs. baseline: if the empty interaction ``()`` is stored and differs from
            ``baseline_value`` (and the index is not ``"SII"``), the two are reconciled. For
            indices where ``is_empty_value_the_baseline`` holds, the stored empty value is
            replaced by the baseline. Otherwise the baseline is replaced by the empty value. For
            Shapley indices the baseline is the prediction without any features, while Banzhaf
            does not fulfill efficiency, so its baseline is not the empty prediction. If ``()`` is
            not stored but ``min_order`` is 0, the baseline value is appended as the empty
            interaction.

    Args:
        interactions: The InteractionValues to finalize.
        target_index: The index to finalize to. Defaults to ``None``, meaning the index of
            ``interactions`` itself.

    Returns:
        The finalized interaction values. When no aggregation takes place, this is the input
        object, which may have been modified in place.

    Note:
        If you develop new approximators and computation methods, you should finalize the
        InteractionValues object before returning it to the user.

    """
    from .game_theory.aggregation import aggregate_base_interaction

    resolved_index = interactions.index if target_index is None else target_index

    # only convert when an aggregated index different from the current one is requested
    if is_index_aggregated(resolved_index) and resolved_index != interactions.index:
        interactions = aggregate_base_interaction(interactions)

    lookup = interactions.interaction_lookup
    if () not in lookup:
        # empty not in interactions but min_order is 0 (should be in the interactions)
        if interactions.min_order == 0:
            # TODO(mmshlk): this might not be what we really want to do always: what if empty and baseline are different?
            # https://github.com/mmschlk/shapiq/issues/385
            lookup[()] = len(lookup)
            interactions.values = np.concatenate(
                (interactions.values, np.array([interactions.baseline_value])),
            )
        return interactions

    empty_position = lookup[()]
    empty_value = interactions[empty_position]
    values_disagree = empty_value != interactions.baseline_value
    if values_disagree and interactions.index != "SII":
        if is_empty_value_the_baseline(interactions.index):
            # the baseline is authoritative: write it into the stored empty value
            interactions[empty_position] = interactions.baseline_value
        else:
            # the stored empty value is authoritative: copy it into the baseline
            interactions.baseline_value = interactions[empty_position]
    return interactions
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc