• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 16647002974

31 Jul 2025 10:52AM UTC coverage: 93.906% (+0.005%) from 93.901%
16647002974

push

github

mmschlk
fix: Remove defense checks

4916 of 5235 relevant lines covered (93.91%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.94
/src/shapiq/interaction_values.py
1
"""InteractionValues data-class, which is used to store the interaction scores."""
2

3
from __future__ import annotations
1✔
4

5
import contextlib
1✔
6
import copy
1✔
7
import json
1✔
8
import pickle
1✔
9
from pathlib import Path
1✔
10
from typing import TYPE_CHECKING
1✔
11
from warnings import warn
1✔
12

13
import numpy as np
1✔
14

15
from .game_theory.indices import (
1✔
16
    ALL_AVAILABLE_INDICES,
17
    get_index_from_computation_index,
18
    is_empty_value_the_baseline,
19
    is_index_aggregated,
20
    is_index_valid,
21
)
22
from .utils.errors import raise_deprecation_warning
1✔
23
from .utils.sets import generate_interaction_lookup
1✔
24

25
if TYPE_CHECKING:
1✔
26
    from collections.abc import Sequence
×
27
    from typing import Any
×
28

29
    from matplotlib.axes import Axes
×
30
    from matplotlib.figure import Figure
×
31

32
    from shapiq.typing import JSONType
×
33

34

35
SAVE_JSON_DEPRECATION_MSG = (
1✔
36
    "Saving InteractionValues not as a JSON file is deprecated. "
37
    "The parameters `as_pickle` and `as_npz` will be removed in the future. "
38
)
39

40

41
class InteractionValues:
1✔
42
    """This class contains the interaction values as estimated by an approximator.
43

44
    Attributes:
45
        values: The interaction values of the model in vectorized form.
46
        index: The interaction index estimated. All available indices are defined in
47
            ``ALL_AVAILABLE_INDICES``.
48
        max_order: The order of the approximation.
49
        n_players: The number of players.
50
        min_order: The minimum order of the approximation. Defaults to ``0``.
51
        interaction_lookup: A dictionary that maps interactions to their index in the values
52
            vector. If ``interaction_lookup`` is not provided, it is computed from the ``n_players``,
53
            ``min_order``, and `max_order` parameters. Defaults to ``None``.
54
        estimated: Whether the interaction values are estimated or not. Defaults to ``True``.
55
        estimation_budget: The budget used for the estimation. Defaults to ``None``.
56
        baseline_value: The value of the baseline interaction also known as 'empty prediction' or
57
            ``'empty value'`` since it denotes the value of the empty coalition (empty set). If not
58
            provided it is searched for in the values vector (raising an Error if not found).
59
            Defaults to ``None``.
60

61
    Raises:
62
        UserWarning: If the index is not a valid index as defined in ``ALL_AVAILABLE_INDICES``.
63
        TypeError: If the baseline value is not a number.
64

65
    """
66

67
    interactions: dict[tuple[int, ...], float]
1✔
68
    """The interactions as a dictionary mapping interactions to their values."""
1✔
69

70
    def __init__(
1✔
71
        self,
72
        values: np.ndarray | dict[tuple[int, ...], float],
73
        *,
74
        index: str,
75
        max_order: int,
76
        n_players: int,
77
        min_order: int,
78
        interaction_lookup: dict[tuple[int, ...], int] | None = None,  # type: ignore[assignment]
79
        estimated: bool = True,
80
        estimation_budget: int | None = None,  # type: ignore[assignment]
81
        baseline_value: float | np.number = 0.0,
82
        target_index: str | None = None,
83
    ) -> None:
84
        """Initialize the InteractionValues object.
85

86
        Args:
87
            values: The interaction values as a numpy array or a dictionary mapping interactions to their
88
                values.
89

90
            index: The index of the interaction values. This should be one of the indices defined in
91
            ALL_AVAILABLE_INDICES. It is used to determine how the interaction values are interpreted.
92
            max_order: The maximum order of the interactions.
93
            n_players: The number of players in the game.
94
            min_order: The minimum order of the interactions. Defaults to 0.
95
            interaction_lookup: A dictionary mapping interactions to their index in the values vector.
96
            Defaults to None, which means it will be generated from the n_players, min_order, and max_order parameters.
97
            estimated: Whether the interaction values are estimated or not. Defaults to True.
98
            estimation_budget: The budget used for the estimation. Defaults to None.
99
            baseline_value: The baseline value of the interaction values, also known as the empty prediction or empty value.
100
            target_index: The index to which the InteractionValues should be finalized. Defaults to None, which means that
101
            target_index = index
102
        """
103
        if not isinstance(baseline_value, (int | float | np.number)):
1✔
104
            msg = f"Baseline value must be provided as a number. Got {type(baseline_value)}."
1✔
105
            raise TypeError(msg)
1✔
106
        self.baseline_value = baseline_value
1✔
107
        if not is_index_valid(index, raise_error=False):
1✔
108
            warn(
1✔
109
                f"Index `{index}` is not a valid interaction index. "
110
                f"Valid indices are: {', '.join(ALL_AVAILABLE_INDICES)}.",
111
                stacklevel=2,
112
            )
113
        index = get_index_from_computation_index(index, max_order)
1✔
114
        if target_index is None:
1✔
115
            target_index = index
1✔
116

117
        interactions = _validate_and_return_interactions(
1✔
118
            values=values,
119
            interaction_lookup=interaction_lookup,
120
            n_players=n_players,
121
            min_order=min_order,
122
            max_order=max_order,
123
            baseline_value=baseline_value,
124
        )
125

126
        interactions, index, min_order, baseline_value = _update_interactions_for_index(
1✔
127
            interactions=interactions,
128
            index=index,
129
            target_index=target_index,
130
            min_order=min_order,
131
            max_order=max_order,
132
            baseline_value=baseline_value,
133
        )
134

135
        self.interactions = interactions
1✔
136
        self.index = index
1✔
137
        self.max_order = max_order
1✔
138
        self.n_players = n_players
1✔
139
        self.min_order = min_order
1✔
140
        self.estimated = estimated
1✔
141
        self.estimation_budget = estimation_budget
1✔
142

143
    @property
1✔
144
    def dict_values(self) -> dict[tuple[int, ...], float]:
1✔
145
        """Getter for the dict directly mapping from all interactions to scores."""
146
        return self.interactions
1✔
147

148
    @property
1✔
149
    def values(self) -> np.ndarray:
1✔
150
        """Getter for the values of the InteractionValues object.
151

152
        Returns:
153
            The values of the InteractionValues object as a numpy array.
154

155
        """
156
        return np.array(list(self.interactions.values()))
1✔
157

158
    @property
1✔
159
    def interaction_lookup(self) -> dict[tuple[int, ...], int]:
1✔
160
        """Getter for the interaction lookup of the InteractionValues object.
161

162
        Returns:
163
            The interaction lookup of the InteractionValues object as a dictionary mapping interactions
164
            to their index in the values vector.
165

166
        """
167
        return {
1✔
168
            interaction: index for index, (interaction, _) in enumerate(self.interactions.items())
169
        }
170

171
    def to_json_file(
1✔
172
        self,
173
        path: Path,
174
        *,
175
        desc: str | None = None,
176
        created_from: object | None = None,
177
        **kwargs: JSONType,
178
    ) -> None:
179
        """Saves the InteractionValues object to a JSON file.
180

181
        Args:
182
            path: The path to the JSON file.
183
            desc: A description of the InteractionValues object. Defaults to ``None``.
184
            created_from: An object from which the InteractionValues object was created. Defaults to
185
                ``None``.
186
            **kwargs: Additional parameters to store in the metadata of the JSON file.
187
        """
188
        from shapiq.utils.saving import (
1✔
189
            interactions_to_dict,
190
            make_file_metadata,
191
            save_json,
192
        )
193

194
        file_metadata = make_file_metadata(
1✔
195
            object_to_store=self,
196
            data_type="interaction_values",
197
            desc=desc,
198
            created_from=created_from,
199
            parameters=kwargs,
200
        )
201
        json_data = {
1✔
202
            **file_metadata,
203
            "metadata": {
204
                "n_players": self.n_players,
205
                "index": self.index,
206
                "max_order": self.max_order,
207
                "min_order": self.min_order,
208
                "estimated": self.estimated,
209
                "estimation_budget": self.estimation_budget,
210
                "baseline_value": self.baseline_value,
211
            },
212
            "data": interactions_to_dict(interactions=self.dict_values),
213
        }
214
        save_json(json_data, path)
1✔
215

216
    @classmethod
1✔
217
    def from_json_file(cls, path: Path) -> InteractionValues:
1✔
218
        """Loads an InteractionValues object from a JSON file.
219

220
        Args:
221
            path: The path to the JSON file. Note that the path must end with `'.json'`.
222

223
        Returns:
224
            The InteractionValues object loaded from the JSON file.
225

226
        Raises:
227
            ValueError: If the path does not end with `'.json'`.
228
        """
229
        from shapiq.utils.saving import dict_to_lookup_and_values
1✔
230

231
        if not path.name.endswith(".json"):
1✔
232
            msg = f"Path {path} does not end with .json. Cannot load InteractionValues."
×
233
            raise ValueError(msg)
×
234

235
        with path.open("r", encoding="utf-8") as file:
1✔
236
            json_data = json.load(file)
1✔
237

238
        metadata = json_data["metadata"]
1✔
239
        interaction_dict = json_data["data"]
1✔
240
        interaction_lookup, values = dict_to_lookup_and_values(interaction_dict)
1✔
241

242
        return cls(
1✔
243
            values=values,
244
            index=metadata["index"],
245
            max_order=metadata["max_order"],
246
            n_players=metadata["n_players"],
247
            min_order=metadata["min_order"],
248
            interaction_lookup=interaction_lookup,
249
            estimated=metadata["estimated"],
250
            estimation_budget=metadata["estimation_budget"],
251
            baseline_value=metadata["baseline_value"],
252
        )
253

254
    def sparsify(self, threshold: float = 1e-3) -> None:
1✔
255
        """Manually sets values close to zero actually to zero (removing values).
256

257
        Args:
258
            threshold: The threshold value below which interactions are zeroed out. Defaults to
259
                1e-3.
260

261
        """
262
        # find interactions to remove in self.interactions
263
        sparse_interactions = copy.deepcopy(self.interactions)
1✔
264
        for interaction, value in self.interactions.items():
1✔
265
            if np.abs(value) < threshold:
1✔
266
                del sparse_interactions[interaction]
1✔
267
        self.interactions = sparse_interactions
1✔
268

269
    def get_top_k_interactions(self, k: int) -> InteractionValues:
1✔
270
        """Returns the top k interactions.
271

272
        Args:
273
            k: The number of top interactions to return.
274

275
        Returns:
276
            The top k interactions as an InteractionValues object.
277

278
        """
279
        top_k_indices = np.argsort(np.abs(self.values))[::-1][:k]
1✔
280
        new_values = np.zeros(k, dtype=float)
1✔
281
        new_interaction_lookup = {}
1✔
282
        for interaction_pos, interaction in enumerate(self.interaction_lookup):
1✔
283
            if interaction_pos in top_k_indices:
1✔
284
                new_position = len(new_interaction_lookup)
1✔
285
                new_values[new_position] = float(self[interaction_pos])
1✔
286
                new_interaction_lookup[interaction] = new_position
1✔
287
        return InteractionValues(
1✔
288
            values=new_values,
289
            index=self.index,
290
            max_order=self.max_order,
291
            n_players=self.n_players,
292
            min_order=self.min_order,
293
            interaction_lookup=new_interaction_lookup,
294
            estimated=self.estimated,
295
            estimation_budget=self.estimation_budget,
296
            baseline_value=self.baseline_value,
297
        )
298

299
    def get_top_k(
1✔
300
        self, k: int, *, as_interaction_values: bool = True
301
    ) -> InteractionValues | tuple[dict, list[tuple]]:
302
        """Returns the top k interactions.
303

304
        Args:
305
            k: The number of top interactions to return.
306
            as_interaction_values: Whether to return the top `k` interactions as an InteractionValues
307
                object. Defaults to ``False``.
308

309
        Returns:
310
            The top k interactions as a dictionary and a sorted list of tuples.
311

312
        Examples:
313
            >>> interaction_values = InteractionValues(
314
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
315
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
316
            ...     index="SII",
317
            ...     max_order=2,
318
            ...     n_players=3,
319
            ...     min_order=1,
320
            ...     baseline_value=0.0,
321
            ... )
322
            >>> top_k_interactions, sorted_top_k_interactions = interaction_values.get_top_k(2, False)
323
            >>> top_k_interactions
324
            {(0, 2): 0.5, (1, 0): 0.6}
325
            >>> sorted_top_k_interactions
326
            [((1, 0), 0.6), ((0, 2), 0.5)]
327

328
        """
329
        if as_interaction_values:
1✔
330
            return self.get_top_k_interactions(k)
1✔
331
        top_k_indices = np.argsort(np.abs(self.values))[::-1][:k]
1✔
332
        top_k_interactions = {}
1✔
333
        for interaction, index in self.interaction_lookup.items():
1✔
334
            if index in top_k_indices:
1✔
335
                top_k_interactions[interaction] = self.values[index]
1✔
336
        sorted_top_k_interactions = [
1✔
337
            (interaction, top_k_interactions[interaction])
338
            for interaction in sorted(top_k_interactions, key=top_k_interactions.get, reverse=True)
339
        ]
340
        return top_k_interactions, sorted_top_k_interactions
1✔
341

342
    def __repr__(self) -> str:
1✔
343
        """Returns the representation of the InteractionValues object."""
344
        representation = "InteractionValues(\n"
1✔
345
        representation += (
1✔
346
            f"    index={self.index}, max_order={self.max_order}, min_order={self.min_order}"
347
            f", estimated={self.estimated}, estimation_budget={self.estimation_budget},\n"
348
            f"    n_players={self.n_players}, baseline_value={self.baseline_value}\n)"
349
        )
350
        return representation
1✔
351

352
    def __str__(self) -> str:
1✔
353
        """Returns the string representation of the InteractionValues object."""
354
        representation = self.__repr__()
1✔
355
        representation = representation[:-2]  # remove the last "\n)" and add values
1✔
356
        _, sorted_top_10_interactions = self.get_top_k(
1✔
357
            10, as_interaction_values=False
358
        )  # get top 10 interactions
359
        # add values to string representation
360
        representation += ",\n    Top 10 interactions:\n"
1✔
361
        for interaction, value in sorted_top_10_interactions:
1✔
362
            representation += f"        {interaction}: {value}\n"
1✔
363
        representation += ")"
1✔
364
        return representation
1✔
365

366
    def __len__(self) -> int:
1✔
367
        """Returns the length of the InteractionValues object."""
368
        return len(self.values)  # might better to return the theoretical no. of interactions
1✔
369

370
    def __iter__(self) -> np.nditer:
1✔
371
        """Returns an iterator over the values of the InteractionValues object."""
372
        return np.nditer(self.values)
1✔
373

374
    def __getitem__(self, item: int | tuple[int, ...]) -> float:
1✔
375
        """Returns the score for the given interaction.
376

377
        Args:
378
            item: The interaction as a tuple of integers for which to return the score. If ``item`` is
379
                an integer it serves as the index to the values vector.
380

381
        Returns:
382
            The interaction value. If the interaction is not present zero is returned.
383

384
        """
385
        if isinstance(item, int):
1✔
386
            return float(self.values[item])
1✔
387
        item = tuple(sorted(item))
1✔
388
        try:
1✔
389
            return float(self.interactions[item])
1✔
390
        except KeyError:
1✔
391
            return 0.0
1✔
392

393
    def __setitem__(self, item: int | tuple[int, ...], value: float) -> None:
1✔
394
        """Sets the score for the given interaction.
395

396
        Args:
397
            item: The interaction as a tuple of integers for which to set the score. If ``item`` is an
398
                integer it serves as the index to the values vector.
399
            value: The value to set for the interaction.
400

401
        Raises:
402
            KeyError: If the interaction is not found in the InteractionValues object.
403

404
        """
405
        try:
1✔
406
            if isinstance(item, int):
1✔
407
                # dict.items() preserves the order of insertion, so we can use it to set the value
408
                for i, (interaction, _) in enumerate(self.interactions.items()):
1✔
409
                    if i == item:
1✔
410
                        self.interactions[interaction] = value
1✔
411
                        break
1✔
412
            else:
413
                item = tuple(sorted(item))
1✔
414
                if self.interactions[item] is not None:
1✔
415
                    # if the interaction is already present, update its value. Otherwise KeyError is raised
416
                    self.interactions[item] = value
1✔
417
        except Exception as e:
1✔
418
            msg = f"Interaction {item} not found in the InteractionValues. Unable to set a value."
1✔
419
            raise KeyError(msg) from e
1✔
420

421
    def __eq__(self, other: object) -> bool:
1✔
422
        """Checks if two InteractionValues objects are equal.
423

424
        Args:
425
            other: The other InteractionValues object.
426

427
        Returns:
428
            True if the two objects are equal, False otherwise.
429

430
        """
431
        if not isinstance(other, InteractionValues):
1✔
432
            msg = "Cannot compare InteractionValues with other types."
1✔
433
            raise TypeError(msg)
1✔
434
        if (
1✔
435
            self.index != other.index
436
            or self.max_order != other.max_order
437
            or self.min_order != other.min_order
438
            or self.n_players != other.n_players
439
            or self.baseline_value != other.baseline_value
440
        ):
441
            return False
1✔
442
        if not np.allclose(self.values, other.values):
1✔
443
            return False
1✔
444
        return self.interaction_lookup == other.interaction_lookup
1✔
445

446
    def __ne__(self, other: object) -> bool:
1✔
447
        """Checks if two InteractionValues objects are not equal.
448

449
        Args:
450
            other: The other InteractionValues object.
451

452
        Returns:
453
            True if the two objects are not equal, False otherwise.
454

455
        """
456
        return not self.__eq__(other)
1✔
457

458
    def __hash__(self) -> int:
1✔
459
        """Returns the hash of the InteractionValues object."""
460
        return hash(
1✔
461
            (
462
                self.index,
463
                self.max_order,
464
                self.min_order,
465
                self.n_players,
466
                tuple(self.values.flatten()),
467
            ),
468
        )
469

470
    def __copy__(self) -> InteractionValues:
1✔
471
        """Returns a copy of the InteractionValues object."""
472
        return InteractionValues(
1✔
473
            values=copy.deepcopy(self.values),
474
            index=self.index,
475
            max_order=self.max_order,
476
            estimated=self.estimated,
477
            estimation_budget=self.estimation_budget,
478
            n_players=self.n_players,
479
            interaction_lookup=copy.deepcopy(self.interaction_lookup),
480
            min_order=self.min_order,
481
            baseline_value=self.baseline_value,
482
        )
483

484
    def __add__(self, other: InteractionValues | float) -> InteractionValues:
1✔
485
        """Adds two InteractionValues objects together or a scalar."""
486
        n_players, min_order, max_order = self.n_players, self.min_order, self.max_order
1✔
487
        if isinstance(other, InteractionValues):
1✔
488
            if self.index != other.index:  # different indices
1✔
489
                msg = (
1✔
490
                    f"Cannot add InteractionValues with different indices {self.index} and "
491
                    f"{other.index}."
492
                )
493
                raise ValueError(msg)
1✔
494
            if (
1✔
495
                self.interaction_lookup != other.interaction_lookup
496
                or self.n_players != other.n_players
497
                or self.min_order != other.min_order
498
                or self.max_order != other.max_order
499
            ):  # different interactions but addable
500
                added_interactions = self.interactions.copy()
1✔
501
                for interaction in other.interactions:
1✔
502
                    if interaction not in added_interactions:
1✔
503
                        added_interactions[interaction] = other.interactions[interaction]
1✔
504
                    else:
505
                        added_interactions[interaction] += other.interactions[interaction]
1✔
506
                interaction_lookup = {
1✔
507
                    interaction: i for i, interaction in enumerate(added_interactions)
508
                }
509
                # adjust n_players, min_order, and max_order
510
                n_players = max(self.n_players, other.n_players)
1✔
511
                min_order = min(self.min_order, other.min_order)
1✔
512
                max_order = max(self.max_order, other.max_order)
1✔
513
                baseline_value = self.baseline_value + other.baseline_value
1✔
514
            else:  # basic case with same interactions
515
                added_interactions = {
1✔
516
                    interaction: self.interactions[interaction] + other.interactions[interaction]
517
                    for interaction in self.interactions
518
                }
519
                interaction_lookup = self.interaction_lookup
1✔
520
                baseline_value = self.baseline_value + other.baseline_value
1✔
521
        elif isinstance(other, int | float):
1✔
522
            added_interactions = {
1✔
523
                interaction: self.interactions[interaction] + other
524
                for interaction in self.interactions
525
            }
526
            interaction_lookup = self.interaction_lookup.copy()
1✔
527
            baseline_value = self.baseline_value + other
1✔
528
        else:
529
            msg = f"Cannot add InteractionValues with object of type {type(other)}."
1✔
530
            raise TypeError(msg)
1✔
531

532
        return InteractionValues(
1✔
533
            values=added_interactions,
534
            index=self.index,
535
            max_order=max_order,
536
            n_players=n_players,
537
            min_order=min_order,
538
            interaction_lookup=interaction_lookup,
539
            estimated=self.estimated,
540
            estimation_budget=self.estimation_budget,
541
            baseline_value=baseline_value,
542
        )
543

544
    def __radd__(self, other: InteractionValues | float) -> InteractionValues:
1✔
545
        """Adds two InteractionValues objects together or a scalar."""
546
        return self.__add__(other)
1✔
547

548
    def __neg__(self) -> InteractionValues:
1✔
549
        """Negates the InteractionValues object."""
550
        return InteractionValues(
1✔
551
            values=-self.values,
552
            index=self.index,
553
            max_order=self.max_order,
554
            n_players=self.n_players,
555
            min_order=self.min_order,
556
            interaction_lookup=self.interaction_lookup,
557
            estimated=self.estimated,
558
            estimation_budget=self.estimation_budget,
559
            baseline_value=-self.baseline_value,
560
        )
561

562
    def __sub__(self, other: InteractionValues | float) -> InteractionValues:
1✔
563
        """Subtracts two InteractionValues objects or a scalar."""
564
        return self.__add__(-other)
1✔
565

566
    def __rsub__(self, other: InteractionValues | float) -> InteractionValues:
1✔
567
        """Subtracts two InteractionValues objects or a scalar."""
568
        return (-self).__add__(other)
1✔
569

570
    def __mul__(self, other: float) -> InteractionValues:
1✔
571
        """Multiplies an InteractionValues object by a scalar."""
572
        interactions = {
1✔
573
            interaction: value * other for interaction, value in self.interactions.items()
574
        }
575
        return InteractionValues(
1✔
576
            values=interactions,
577
            index=self.index,
578
            max_order=self.max_order,
579
            n_players=self.n_players,
580
            min_order=self.min_order,
581
            interaction_lookup=self.interaction_lookup,
582
            estimated=self.estimated,
583
            estimation_budget=self.estimation_budget,
584
            baseline_value=self.baseline_value * other,
585
        )
586

587
    def __rmul__(self, other: float) -> InteractionValues:
1✔
588
        """Multiplies an InteractionValues object by a scalar."""
589
        return self.__mul__(other)
1✔
590

591
    def __abs__(self) -> InteractionValues:
1✔
592
        """Returns the absolute values of the InteractionValues object."""
593
        interactions = {interaction: abs(value) for interaction, value in self.interactions.items()}
1✔
594
        return InteractionValues(
1✔
595
            values=interactions,
596
            index=self.index,
597
            max_order=self.max_order,
598
            n_players=self.n_players,
599
            min_order=self.min_order,
600
            interaction_lookup=self.interaction_lookup,
601
            estimated=self.estimated,
602
            estimation_budget=self.estimation_budget,
603
            baseline_value=self.baseline_value,
604
        )
605

606
    def get_n_order_values(self, order: int) -> np.ndarray:
1✔
607
        """Returns the interaction values of a specific order as a numpy array.
608

609
        Note:
610
            Depending on the order and number of players the resulting array might be sparse and
611
            very large.
612

613
        Args:
614
            order: The order of the interactions to return.
615

616
        Returns:
617
            The interaction values of the specified order as a numpy array of shape ``(n_players,)``
618
            for order ``1`` and ``(n_players, n_players)`` for order ``2``, etc.
619

620
        Raises:
621
            ValueError: If the order is less than ``1``.
622

623
        """
624
        from itertools import permutations
1✔
625

626
        if order < 1:
1✔
627
            msg = "Order must be greater or equal to 1."
1✔
628
            raise ValueError(msg)
1✔
629
        values_shape = tuple([self.n_players] * order)
1✔
630
        values = np.zeros(values_shape, dtype=float)
1✔
631
        for interaction in self.interaction_lookup:
1✔
632
            if len(interaction) != order:
1✔
633
                continue
1✔
634
            # get all orderings of the interaction (e.g. (0, 1) and (1, 0) for interaction (0, 1))
635
            for perm in permutations(interaction):
1✔
636
                values[perm] = self[interaction]
1✔
637

638
        return values
1✔
639

640
    def get_n_order(
1✔
641
        self,
642
        order: int | None = None,
643
        min_order: int | None = None,
644
        max_order: int | None = None,
645
    ) -> InteractionValues:
646
        """Select particular order of interactions.
647

648
        Creates a new InteractionValues object containing only the interactions within the
649
        specified order range.
650

651
        You can specify:
652
            - `order`: to select interactions of a single specific order (e.g., all pairwise
653
                interactions).
654
            - `min_order` and/or `max_order`: to select a range of interaction orders.
655
            - If `order` and `min_order`/`max_order` are both set, `min_order` and `max_order` will
656
                override the `order` value.
657

658
        Example:
659
            >>> interaction_values = InteractionValues(
660
            ...     values=np.array([1, 2, 3, 4, 5, 6, 7]),
661
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5, (0, 1, 2): 6},
662
            ...     index="SII",
663
            ...     max_order=3,
664
            ...     n_players=3,
665
            ...     min_order=1,
666
            ...     baseline_value=0.0,
667
            ... )
668
            >>> interaction_values.get_n_order(order=1).dict_values
669
            {(0,): 1.0, (1,): 2.0, (2,): 3.0}
670
            >>> interaction_values.get_n_order(min_order=1, max_order=2).dict_values
671
            {(0,): 1.0, (1,): 2.0, (2,): 3.0, (0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0}
672
            >>> interaction_values.get_n_order(min_order=2).dict_values
673
            {(0, 1): 4.0, (0, 2): 5.0, (1, 2): 6.0, (0, 1, 2): 7.0}
674

675
        Args:
676
            order: The order of the interactions to return. Defaults to ``None`` which requires
677
                ``min_order`` or ``max_order`` to be set.
678
            min_order: The minimum order of the interactions to return. Defaults to ``None`` which
679
                sets it to the order.
680
            max_order: The maximum order of the interactions to return. Defaults to ``None`` which
681
                sets it to the order.
682

683
        Returns:
684
            The interaction values of the specified order.
685

686
        Raises:
687
            ValueError: If all three parameters are set to ``None``.
688
        """
689
        if order is None and min_order is None and max_order is None:
1✔
690
            msg = "Either order, min_order or max_order must be set."
1✔
691
            raise ValueError(msg)
1✔
692

693
        if order is not None:
1✔
694
            max_order = order if max_order is None else max_order
1✔
695
            min_order = order if min_order is None else min_order
1✔
696
        else:  # order is None
697
            min_order = self.min_order if min_order is None else min_order
1✔
698
            max_order = self.max_order if max_order is None else max_order
1✔
699

700
        if min_order > max_order:
1✔
701
            msg = f"min_order ({min_order}) must be less than or equal to max_order ({max_order})."
1✔
702
            raise ValueError(msg)
1✔
703

704
        new_values = []
1✔
705
        new_interaction_lookup = {}
1✔
706
        for interaction in self.interaction_lookup:
1✔
707
            if len(interaction) < min_order or len(interaction) > max_order:
1✔
708
                continue
1✔
709
            interaction_idx = len(new_interaction_lookup)
1✔
710
            new_values.append(self[interaction])
1✔
711
            new_interaction_lookup[interaction] = interaction_idx
1✔
712

713
        return InteractionValues(
1✔
714
            values=np.array(new_values),
715
            index=self.index,
716
            max_order=max_order,
717
            n_players=self.n_players,
718
            min_order=min_order,
719
            interaction_lookup=new_interaction_lookup,
720
            estimated=self.estimated,
721
            estimation_budget=self.estimation_budget,
722
            baseline_value=self.baseline_value,
723
        )
724

725
    def get_subset(self, players: list[int]) -> InteractionValues:
        """Selects a subset of players from the InteractionValues object.

        An interaction is kept if every player it contains is in ``players``. Player labels are not
        remapped, i.e. the kept interactions still refer to the original player indices.

        Args:
            players: List of players to select from the InteractionValues object.

        Returns:
            InteractionValues: Filtered InteractionValues object containing only values related to
            selected players.

        Example:
            >>> interaction_values = InteractionValues(
            ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
            ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
            ...     index="SII",
            ...     max_order=2,
            ...     n_players=3,
            ...     min_order=1,
            ...     baseline_value=0.0,
            ... )
            >>> interaction_values.get_subset([0, 1]).dict_values
            {(0,): 0.1, (1,): 0.2, (0, 1): 0.4}
            >>> interaction_values.get_subset([0, 2]).dict_values
            {(0,): 0.1, (2,): 0.3, (0, 2): 0.5}
            >>> interaction_values.get_subset([1]).dict_values
            {(1,): 0.2}

        """
        # A set gives O(1) membership tests instead of scanning ``players`` once per player.
        selected = set(players)
        positions: list[int] = []
        kept_keys: list[tuple[int, ...]] = []
        # ``position`` is the key's position in ``interaction_lookup``; it is used to index
        # ``self.values`` exactly as before.
        for position, key in enumerate(self.interaction_lookup):
            if selected.issuperset(key):
                positions.append(position)
                kept_keys.append(key)
        new_values = self.values[positions]
        new_interaction_lookup = {key: new_index for new_index, key in enumerate(kept_keys)}
        # NOTE(review): this is the number of *unselected* players, which does not match the
        # (unremapped) player labels kept above; ``len(players)`` or ``self.n_players`` may be the
        # intended value. Left unchanged because callers/tests may rely on it — confirm.
        n_players = self.n_players - len(players)
        return InteractionValues(
            values=new_values,
            index=self.index,
            max_order=self.max_order,
            n_players=n_players,
            min_order=self.min_order,
            interaction_lookup=new_interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
773

774
    def save(self, path: Path, *, as_pickle: bool = False, as_npz: bool = False) -> None:
        """Write the InteractionValues object to ``path``.

        JSON (written via ``to_json_file``) is the default format. The pickle and ``npz`` formats
        are deprecated; selecting either one signals a deprecation through
        ``raise_deprecation_warning`` before the file is written. Missing parent directories are
        created on a best-effort basis first.

        Args:
            path: Destination file path.
            as_pickle: If ``True``, pickle the whole object (deprecated). Takes precedence over
                ``as_npz``.
            as_npz: If ``True``, store the attributes with ``np.savez`` (deprecated).

        Raises:
            DeprecationWarning: If ``as_pickle`` or ``as_npz`` is ``True``.
        """
        parent_dir = Path(path).parent
        if not parent_dir.exists():
            # best effort: a FileNotFoundError while creating the directory is ignored
            with contextlib.suppress(FileNotFoundError):
                parent_dir.mkdir(parents=True, exist_ok=True)

        if not (as_pickle or as_npz):
            self.to_json_file(path)
            return

        # both legacy formats share the same deprecation notice
        raise_deprecation_warning(
            message=SAVE_JSON_DEPRECATION_MSG,
            deprecated_in="1.3.1",
            removed_in="1.4.0",
        )
        if as_pickle:
            with Path(path).open("wb") as handle:
                pickle.dump(self, handle)
            return

        np.savez(
            path,
            values=self.values,
            index=self.index,
            max_order=self.max_order,
            n_players=self.n_players,
            min_order=self.min_order,
            interaction_lookup=self.interaction_lookup,
            estimated=self.estimated,
            estimation_budget=self.estimation_budget,
            baseline_value=self.baseline_value,
        )
822

823
    @classmethod
1✔
824
    def load(cls, path: Path | str) -> InteractionValues:
1✔
825
        """Load an InteractionValues object from a file.
826

827
        Args:
828
            path: The path to load the InteractionValues object from.
829

830
        Returns:
831
            The loaded InteractionValues object.
832

833
        """
834
        path = Path(path)
1✔
835
        # check if path ends with .json
836
        if path.name.endswith(".json"):
1✔
837
            return cls.from_json_file(path)
1✔
838

839
        raise_deprecation_warning(
1✔
840
            SAVE_JSON_DEPRECATION_MSG, deprecated_in="1.3.1", removed_in="1.4.0"
841
        )
842

843
        # try loading as npz file
844
        if path.name.endswith(".npz"):
1✔
845
            data = np.load(path, allow_pickle=True)
1✔
846
            return InteractionValues(
1✔
847
                values=data["values"],
848
                index=str(data["index"]),
849
                max_order=int(data["max_order"]),
850
                n_players=int(data["n_players"]),
851
                min_order=int(data["min_order"]),
852
                interaction_lookup=data["interaction_lookup"].item(),
853
                estimated=bool(data["estimated"]),
854
                estimation_budget=data["estimation_budget"].item(),
855
                baseline_value=float(data["baseline_value"]),
856
            )
857
        msg = f"Path {path} does not end with .json or .npz. Cannot load InteractionValues."
×
858
        raise ValueError(msg)
×
859

860
    @classmethod
1✔
861
    def from_dict(cls, data: dict[str, Any]) -> InteractionValues:
1✔
862
        """Create an InteractionValues object from a dictionary.
863

864
        Args:
865
            data: The dictionary containing the data to create the InteractionValues object from.
866

867
        Returns:
868
            The InteractionValues object created from the dictionary.
869

870
        """
871
        return cls(
1✔
872
            values=data["values"],
873
            index=data["index"],
874
            max_order=data["max_order"],
875
            n_players=data["n_players"],
876
            min_order=data["min_order"],
877
            interaction_lookup=data["interaction_lookup"],
878
            estimated=data["estimated"],
879
            estimation_budget=data["estimation_budget"],
880
            baseline_value=data["baseline_value"],
881
        )
882

883
    def to_dict(self) -> dict:
1✔
884
        """Convert the InteractionValues object to a dictionary.
885

886
        Returns:
887
            The InteractionValues object as a dictionary.
888

889
        """
890
        return {
1✔
891
            "values": self.interactions,
892
            "index": self.index,
893
            "max_order": self.max_order,
894
            "n_players": self.n_players,
895
            "min_order": self.min_order,
896
            "interaction_lookup": self.interaction_lookup,
897
            "estimated": self.estimated,
898
            "estimation_budget": self.estimation_budget,
899
            "baseline_value": self.baseline_value,
900
        }
901

902
    def aggregate(
        self,
        others: Sequence[InteractionValues],
        aggregation: str = "mean",
    ) -> InteractionValues:
        """Combine this object with ``others`` into one aggregated InteractionValues object.

        ``self`` is placed first in the list handed to ``aggregate_interaction_values``, so the
        result takes its index from ``self``.

        Args:
            others: The other InteractionValues objects to aggregate with this one.
            aggregation: The aggregation method. Defaults to ``"mean"``; ``"median"``, ``"sum"``,
                ``"max"`` and ``"min"`` are also available.

        Returns:
            The aggregated InteractionValues object.

        Note:
            See ``aggregate_interaction_values()`` for details on the aggregation methods.

        """
        members = [self]
        members.extend(others)
        return aggregate_interaction_values(members, aggregation)
923

924
    def plot_network(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
1✔
925
        """Visualize InteractionValues on a graph.
926

927
        Note:
928
            For arguments, see :func:`shapiq.plot.network.network_plot` and
929
                :func:`shapiq.plot.si_graph.si_graph_plot`.
930

931
        Args:
932
            show: Whether to show the plot. Defaults to ``True``.
933

934
            **kwargs: Additional keyword arguments to pass to the plotting function.
935

936
        Returns:
937
            If show is ``False``, the function returns a tuple with the figure and the axis of the
938
                plot.
939
        """
940
        from shapiq.plot.network import network_plot
1✔
941

942
        if self.max_order > 1:
1✔
943
            return network_plot(
1✔
944
                interaction_values=self,
945
                show=show,
946
                **kwargs,
947
            )
948
        msg = (
1✔
949
            "InteractionValues contains only 1-order values,"
950
            "but requires also 2-order values for the network plot."
951
        )
952
        raise ValueError(msg)
1✔
953

954
    def plot_si_graph(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Draw the InteractionValues as an SI graph.

        Thin wrapper around :func:`shapiq.plot.si_graph.si_graph_plot`; see that function for the
        supported keyword arguments.

        Args:
            show: Forwarded to ``si_graph_plot``. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded unchanged to ``si_graph_plot``.

        Returns:
            Whatever ``si_graph_plot`` returns: per the annotation, a ``(figure, axes)`` tuple or
            ``None``.

        """
        from shapiq.plot.si_graph import si_graph_plot as draw_si_graph

        return draw_si_graph(self, show=show, **kwargs)
966

967
    def plot_stacked_bar(self, *, show: bool = True, **kwargs: Any) -> tuple[Figure, Axes] | None:
        """Draw the InteractionValues as a stacked bar plot.

        Thin wrapper around ``shapiq.stacked_bar_plot``; see that function for the supported
        keyword arguments.

        Args:
            show: Forwarded to ``stacked_bar_plot``. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded unchanged to ``stacked_bar_plot``.

        Returns:
            Whatever ``stacked_bar_plot`` returns: per the annotation, a ``(figure, axes)`` tuple
            or ``None``.

        """
        from shapiq import stacked_bar_plot as draw_stacked_bar

        return draw_stacked_bar(self, show=show, **kwargs)
979

980
    def plot_force(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        contribution_threshold: float = 0.05,
    ) -> Figure | None:
        """Visualize InteractionValues on a force plot.

        Thin wrapper around ``force_plot`` from the ``shapiq.plot`` package; see that function for
        details.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided,
                the feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            contribution_threshold: The threshold for contributions to be displayed in percent.
                Defaults to ``0.05``.

        Returns:
            The force plot as a matplotlib figure (if show is ``False``).

        """
        from .plot import force_plot

        return force_plot(
            self,
            feature_names=feature_names,
            show=show,
            abbreviate=abbreviate,
            contribution_threshold=contribution_threshold,
        )
1013

1014
    def plot_waterfall(
        self,
        feature_names: np.ndarray | None = None,
        *,
        show: bool = True,
        abbreviate: bool = True,
        max_display: int = 10,
    ) -> Axes | None:
        """Draws interaction values on a waterfall plot.

        Thin wrapper around ``shapiq.waterfall_plot``.

        Note:
            Requires the ``shap`` Python package to be installed.

        Args:
            feature_names: The feature names used for plotting. If no feature names are provided,
                the feature indices are used instead. Defaults to ``None``.
            show: Whether to show the plot. Defaults to ``True``.
            abbreviate: Whether to abbreviate the feature names or not. Defaults to ``True``.
            max_display: The maximum number of interactions to display. Defaults to ``10``.

        Returns:
            Whatever ``waterfall_plot`` returns: per the annotation, the plot's ``Axes`` or
            ``None`` (presumably ``None`` when ``show`` is ``True`` — confirm in
            ``waterfall_plot``).

        """
        from shapiq import waterfall_plot

        return waterfall_plot(
            self,
            feature_names=feature_names,
            show=show,
            max_display=max_display,
            abbreviate=abbreviate,
        )
1043

1044
    def plot_sentence(
        self,
        words: list[str],
        *,
        show: bool = True,
        **kwargs: Any,
    ) -> tuple[Figure, Axes] | None:
        """Plot the first-order attributions of a sentence or paragraph.

        Thin wrapper around :func:`shapiq.plot.sentence.sentence_plot`; see that function for the
        supported keyword arguments.

        Args:
            words: The words of the sentence, passed positionally to ``sentence_plot``.
            show: Forwarded to ``sentence_plot``. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded unchanged to ``sentence_plot``.

        Returns:
            If ``show`` is ``True``, the function returns ``None``. Otherwise, it returns a tuple
            with the figure and the axis of the plot.

        """
        from shapiq.plot.sentence import sentence_plot as draw_sentence

        return draw_sentence(self, words, show=show, **kwargs)
1063

1064
    def plot_upset(self, *, show: bool = True, **kwargs: Any) -> Figure | None:
        """Draw an upset plot of the InteractionValues.

        Thin wrapper around :func:`shapiq.plot.upset.upset_plot`; see that function for the
        supported keyword arguments.

        Args:
            show: Forwarded to ``upset_plot``. Defaults to ``True``.
            **kwargs: Further keyword arguments forwarded unchanged to ``upset_plot``.

        Returns:
            The upset plot as a matplotlib figure (if show is ``False``).

        """
        from shapiq.plot.upset import upset_plot as draw_upset

        return draw_upset(self, show=show, **kwargs)
1076

1077

1078
def aggregate_interaction_values(
    interaction_values: Sequence[InteractionValues],
    aggregation: str = "mean",
) -> InteractionValues:
    """Aggregates InteractionValues objects using a specific aggregation method.

    The result contains the union of all interactions present in any of the inputs, ordered by
    sorting the interaction tuples. For every interaction, the values ``iv[interaction]`` of all
    inputs are aggregated. For inputs that do not contain the interaction, this is whatever
    ``InteractionValues.__getitem__`` returns for a missing interaction. The baseline values are
    aggregated with the same method.

    Args:
        interaction_values: A list of InteractionValues objects to aggregate. Must not be empty.
        aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
            ``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

    Returns:
        The aggregated InteractionValues object. Its ``max_order`` and ``n_players`` are the maxima
        and its ``min_order`` the minimum over the inputs. It is marked as ``estimated=True``.

    Example:
        >>> iv1 = InteractionValues(
        ...     values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=0.0,
        ... )
        >>> iv2 = InteractionValues(
        ...     values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]),  # this iv is missing the (1, 2) value
        ...     interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4},  # no (1, 2)
        ...     index="SII",
        ...     max_order=2,
        ...     n_players=3,
        ...     min_order=1,
        ...     baseline_value=1.0,
        ... )
        >>> aggregated = aggregate_interaction_values([iv1, iv2], "mean")
        >>> float(aggregated.baseline_value)
        0.5
        >>> float(aggregated[(0, 1)])
        0.45

    Note:
        The index and the estimation budget of the aggregated InteractionValues object are taken
        from the first InteractionValues object in the list.

    Raises:
        ValueError: If ``interaction_values`` is empty or the aggregation method is not supported.

    """
    # All methods return plain Python floats so the result type does not depend on the method.
    aggregators = {
        "mean": lambda vals: float(np.mean(vals)),
        "median": lambda vals: float(np.median(vals)),
        "sum": lambda vals: float(np.sum(vals)),
        "max": lambda vals: float(np.max(vals)),
        "min": lambda vals: float(np.min(vals)),
    }
    if aggregation not in aggregators:
        msg = f"Aggregation method {aggregation} is not supported."
        raise ValueError(msg)
    if len(interaction_values) == 0:
        msg = "At least one InteractionValues object is required for aggregation."
        raise ValueError(msg)
    reduce_values = aggregators[aggregation]

    # union of the interactions of all objects, in sorted (lexicographic tuple) order
    all_keys = sorted({key for iv in interaction_values for key in iv.interaction_lookup})

    new_values = np.zeros(len(all_keys), dtype=float)
    for position, key in enumerate(all_keys):
        new_values[position] = reduce_values([iv[key] for iv in interaction_values])
    new_lookup = {key: position for position, key in enumerate(all_keys)}

    first = interaction_values[0]
    return InteractionValues(
        values=new_values,
        index=first.index,
        max_order=max(iv.max_order for iv in interaction_values),
        n_players=max(iv.n_players for iv in interaction_values),
        min_order=min(iv.min_order for iv in interaction_values),
        interaction_lookup=new_lookup,
        estimated=True,
        estimation_budget=first.estimation_budget,
        baseline_value=reduce_values([iv.baseline_value for iv in interaction_values]),
    )
1179

1180

1181
def _validate_and_return_interactions(
1✔
1182
    values: np.ndarray | dict[tuple[int, ...], float],
1183
    interaction_lookup: dict[tuple[int, ...], int] | None,
1184
    n_players: int,
1185
    min_order: int,
1186
    max_order: int,
1187
    baseline_value: float | np.number,
1188
) -> dict[tuple[int, ...], float]:
1189
    """Check the interactions for validity and consistency.
1190

1191
    Args:
1192
        values (np.ndarray | dict[tuple[int, ...], float]): The interaction values.
1193
        interaction_lookup (dict[tuple[int, ...], int]): A mapping from interactions to their indices.
1194
        n_players (int): The number of players.
1195
        min_order (int): The minimum order of interactions.
1196
        max_order (int): The maximum order of interactions.
1197
        baseline_value (float | np.number): The baseline value to use for empty interactions.
1198

1199
    Raises:
1200
        TypeError: If the values or interaction_lookup are not of the expected types.
1201
    """
1202
    interactions: dict[tuple[int, ...], float] = {}
1✔
1203
    if interaction_lookup is None:
1✔
1204
        interaction_lookup = generate_interaction_lookup(
1✔
1205
            players=n_players,
1206
            min_order=min_order,
1207
            max_order=max_order,
1208
        )
1209
    if interaction_lookup is not None and not isinstance(interaction_lookup, dict):
1✔
1210
        msg = f"Interaction lookup must be a dictionary. Got {type(interaction_lookup)}."
×
1211
        raise TypeError(msg)
×
1212

1213
    if isinstance(values, dict):
1✔
1214
        interactions = copy.deepcopy(values)
1✔
1215
    else:
1216
        interactions = {
1✔
1217
            interaction: values[index].item() for interaction, index in interaction_lookup.items()
1218
        }
1219

1220
    if min_order == 0 and () not in interactions:
1✔
1221
        interactions[()] = float(baseline_value)
1✔
1222
    return interactions
1✔
1223

1224

1225
def _update_interactions_for_index(
    interactions: dict[tuple[int, ...], float],
    index: str,
    target_index: str,
    max_order: int,
    min_order: int,
    baseline_value: float | np.number,
) -> tuple[dict[tuple[int, ...], float], str, int, float]:
    """Optionally aggregate the interactions, then reconcile the empty interaction with the baseline.

    If ``target_index`` is an aggregated index (per ``is_index_aggregated``) that differs from
    ``index``, the interactions are converted with ``aggregate_base_attributions``. That call also
    yields the new ``index`` and ``min_order``.

    Afterwards, the empty interaction ``()`` is reconciled with ``baseline_value``:

    * If ``()`` is present, its value differs from ``baseline_value`` and ``index`` is not
      ``"SII"``:
        - when ``is_empty_value_the_baseline(index)`` holds, ``()`` is overwritten with the
          baseline;
        - otherwise the baseline is replaced by the value of ``()``.
    * If ``()`` is absent and ``min_order`` is ``0``, it is inserted with the baseline value.

    The ``interactions`` dictionary may be modified in place.

    Returns:
        A tuple ``(interactions, index, min_order, baseline_value)`` with the baseline as
        ``float``.
    """
    from .game_theory.aggregation import aggregate_base_attributions

    needs_aggregation = is_index_aggregated(target_index) and target_index != index
    if needs_aggregation:
        interactions, index, min_order = aggregate_base_attributions(
            interactions=interactions,
            index=index,
            order=max_order,
            min_order=min_order,
            baseline_value=float(baseline_value),
        )

    if () not in interactions:
        if min_order == 0:
            # TODO(mmshlk): this might not be what we really want to do always: what if empty and baseline are different?
            # https://github.com/mmschlk/shapiq/issues/385
            interactions[()] = float(baseline_value)
        return interactions, index, min_order, float(baseline_value)

    empty_value = interactions[()]
    if empty_value != baseline_value and index != "SII":
        if is_empty_value_the_baseline(index):
            # the given baseline wins: write it into the empty interaction
            interactions[()] = float(baseline_value)
        else:
            # the empty interaction wins: adopt its value as the baseline
            baseline_value = empty_value
    return interactions, index, min_order, float(baseline_value)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc