mmschlk / shapiq / 16345920221

17 Jul 2025 01:05PM UTC coverage: 77.589% (-13.5%) from 91.075%

push

github

web-flow
🔨 Refactors library into a src structure. (#415)

* moves shapiq into a src folder

* moves shapiq tests into tests_shapiq subfolder in tests

* refactors tests to work properly

* removes pickle support and closes #413

* changes unit tests to only run the unit tests

* adds workflow for running shapiq_games

* updates coverage to only run for shapiq

* update workflow to check for shapiq_games import

* update CHANGELOG.md

* fixes install-import.yml

* fixes version in docs

* moved deprecated tests out of the main test suite

* moves fixtures into the correct test suite

* installs libomp on macos runner (try bugfix)

* correct spelling

* removes libomp again

* moves os runs into individual workflows for easier debugging

* runs macOS on py3.13

* renames workflows

* installs libomp again on macOS

* downgraded to 3.11 and reinstall python

* try different uv version

* adds libomp

* changes skip to xfail in integration tests with wrong index/order combinations

* moves test out for debugging CI

* removes outdated test

* adds concurrency for quicker testing

* re-adds randomly

* don't reset seed

* removed pytest-randomly again

* adds the tests back in
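One commit bullet switches integration tests with invalid index/order combinations from skip to xfail. On the order side, TreeSHAPIQ.__init__ in the file listed here rejects any setting where max_order < min_order or where either order is below 1. The sketch below restates that condition as hypothetical standalone helpers (check_orders, orders_are_valid) to show which (min_order, max_order) pairs pass. It is not part of shapiq's API and does not cover invalid index values.

```python
def check_orders(min_order: int, max_order: int) -> None:
    """Hypothetical helper mirroring the order check in TreeSHAPIQ.__init__."""
    if max_order < min_order or max_order < 1 or min_order < 1:
        msg = "max_order must be >= min_order and both orders must be >= 1"
        raise ValueError(msg)


def orders_are_valid(min_order: int, max_order: int) -> bool:
    """Return True if check_orders accepts the pair and False if it raises."""
    try:
        check_orders(min_order, max_order)
    except ValueError:
        return False
    return True


# (min_order, max_order) pairs: equal orders are accepted; inverted or zero orders are rejected
for pair in [(1, 1), (1, 2), (2, 3), (2, 1), (0, 2), (1, 0)]:
    print(pair, orders_are_valid(*pair))
```

Equal orders pass the check, so a pair like (2, 2) requests exactly one interaction order.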

3 of 21 new or added lines in 19 files covered. (14.29%)

5536 of 7135 relevant lines covered (77.59%)

0.78 hits per line
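The headline figures are ratios of covered lines to relevant lines. The check below assumes that this is how they are computed and that the change is the difference between the two overall percentages. The helper `percent` is a throwaway name and is not part of any coverage tool.

```python
def percent(covered: int, relevant: int) -> float:
    """Share of covered lines, in percent."""
    return 100.0 * covered / relevant


overall = percent(5536, 7135)  # relevant lines covered in this build
new_lines = percent(3, 21)  # new or added lines covered
change = round(overall, 3) - 91.075  # difference to the previous build's 91.075%

print(f"{overall:.3f}% {new_lines:.2f}% {change:+.1f}%")  # prints: 77.589% 14.29% -13.5%
```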

Source File

/src/shapiq/explainer/tree/treeshapiq.py (coverage: 98.18%)
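The file describes TreeSHAP-IQ as relying on polynomial arithmetic. Its `_init_summary_polynomials` method stores Chebyshev points of the second kind (`np.polynomial.chebyshev.chebpts2`) as the interpolation points `D`.

The sketch below shows only the general idea behind that choice. It assumes a polynomial is represented by its values at those points. Under that representation, a product of linear factors `(t + p_e)`, like the update `summary_poly_down[depth - 1] * (self.D + p_e_current)`, reduces to an elementwise product, and the coefficients can be recovered afterwards.

The function name `product_of_linear_factors` and the Vandermonde solve are illustrative choices. shapiq itself works with precomputed matrices (`Ns`, `Ns_id`) rather than a linear solve.

```python
import numpy as np


def product_of_linear_factors(p_values: np.ndarray) -> np.ndarray:
    """Coefficients (lowest degree first) of prod_e (t + p_e), built from point values.

    Illustrative sketch: the polynomial is represented by its values at Chebyshev
    points of the second kind and only converted to coefficients at the end.
    """
    n_points = len(p_values) + 1  # a degree-k polynomial needs k + 1 points
    nodes = np.polynomial.chebyshev.chebpts2(n_points)  # raises ValueError if n_points < 2
    values = np.ones(n_points)
    for p_e in p_values:
        values *= nodes + p_e  # multiply by the factor (t + p_e) at every node
    # solve V @ coeffs = values, where V[i, j] = nodes[i] ** j
    vandermonde = np.vander(nodes, N=n_points, increasing=True)
    return np.linalg.solve(vandermonde, values)


p = np.array([0.5, 0.25, 0.8])
coeffs = product_of_linear_factors(p)
expected = np.polynomial.polynomial.polyfromroots(-p)  # the root of (t + p_e) is -p_e
print(np.allclose(coeffs, expected))  # True
```

With an empty `p_values`, the sketch calls `chebpts2(1)`, which raises a `ValueError`. When a tree uses a single feature, `n_interpolation_size` is also 1. That appears to be the case `__init__` catches and routes to the trivial computation.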
"""Implementation of the tree explainer."""

from __future__ import annotations

import copy
from math import factorial
from typing import TYPE_CHECKING, Literal

import numpy as np
import scipy as sp

from shapiq.game_theory.indices import get_computation_index
from shapiq.interaction_values import InteractionValues, finalize_computed_interactions
from shapiq.utils.sets import generate_interaction_lookup, powerset

from .conversion.edges import create_edge_tree
from .validation import validate_tree_model

if TYPE_CHECKING:
    from shapiq.typing import Model  # new line, not covered

    from .base import EdgeTree, TreeModel  # not covered


TreeSHAPIQIndices = Literal["SV", "SII", "k-SII"]


class TreeSHAPIQ:
    """The TreeSHAP-IQ computation class.

    This class implements the TreeSHAP-IQ algorithm for computing Shapley Interaction values for
    tree-based models. It is used internally by the
    :class:`~shapiq.explainer.tree.explainer.TreeExplainer`. The TreeSHAP-IQ algorithm is presented
    in `Muschalik et al. (2024)` [Mus24]_.

    TreeSHAP-IQ builds on the Linear TreeSHAP algorithm by `Yu et al. (2022)` [Yu22]_ and extends
    it to compute Shapley Interaction values up to a given order. It visits each node only once
    and uses polynomial arithmetic to compute the Shapley Interaction values efficiently.

    Note:
        This class is not intended to be used directly. Instead, use the ``TreeExplainer`` class to
        explain tree-based models; it uses the TreeSHAP-IQ algorithm internally.

    References:
        .. [Yu22] Peng Yu, Chao Xu, Albert Bifet, & Jesse Read (2022). Linear TreeShap. In: Proceedings of the 36th Conference on Neural Information Processing Systems. https://openreview.net/forum?id=OzbkiUo24g
        .. [Mus24] Maximilian Muschalik, Fabian Fumagalli, Barbara Hammer, & Eyke Hüllermeier (2024). Beyond TreeSHAP: Efficient Computation of Any-Order Shapley Interactions for Tree Ensembles. In: Proceedings of the AAAI Conference on Artificial Intelligence, 38(13), 14388-14396. https://doi.org/10.1609/aaai.v38i13.29352

    """

    def __init__(
        self,
        model: dict | TreeModel | Model,
        *,
        max_order: int = 2,
        min_order: int = 1,
        index: TreeSHAPIQIndices = "k-SII",
        verbose: bool = False,
    ) -> None:
        """Initializes the TreeSHAP-IQ explainer.

        Args:
            model: A single tree model to explain. Note that unlike the
                :class:`~shapiq.explainer.tree.explainer.TreeExplainer` class, TreeSHAP-IQ only
                supports a single tree. It can be a dictionary representation of the tree, a
                :class:`~shapiq.explainer.tree.base.TreeModel` object, or any other single tree
                model supported by the :func:`~shapiq.explainer.tree.validation.validate_tree_model`
                function.

            max_order: The maximum interaction order to be computed. An interaction order of ``1``
                corresponds to the Shapley value. Any value higher than ``1`` computes the Shapley
                interaction values up to that order. Defaults to ``2``.

            min_order: The minimum interaction order to be computed. Defaults to ``1``. Note that
                setting min_order currently does not have any effect on the computation.

            index: The type of interaction to be computed. One of ``"SV"``, ``"SII"``, or
                ``"k-SII"``. Defaults to ``"k-SII"``.

            verbose: Whether to print information about the tree during initialization. Defaults to
                ``False``.

        """
        # set parameters
        self._root_node_id = 0
        self.verbose = verbose
        if max_order < min_order or max_order < 1 or min_order < 1:
            msg = (
                "The maximum order must be greater than or equal to the minimum order and both "
                "must be greater than 0."
            )
            raise ValueError(msg)
        self._max_order: int = max_order
        self._min_order: int = min_order
        self._index: str = index
        self._base_index: str = get_computation_index(self._index)

        # validate and parse model
        validated_model = validate_tree_model(model)  # the parsed and validated model
        # TODO(mmshlk): add support for other sample weights https://github.com/mmschlk/shapiq/issues/99
        self._tree: TreeModel = copy.deepcopy(validated_model)
        self._relevant_features: np.ndarray = np.array(list(self._tree.feature_ids), dtype=int)
        self._tree.reduce_feature_complexity()
        self._n_nodes: int = self._tree.n_nodes
        self._n_features_in_tree: int = self._tree.n_features_in_tree
        self._max_feature_id: int = self._tree.max_feature_id
        self._feature_ids: set = self._tree.feature_ids

        # precompute interaction lookup tables
        self._interactions_lookup_relevant: dict[tuple, int] = generate_interaction_lookup(
            self._relevant_features,
            self._min_order,
            self._max_order,
        )
        self._interactions_lookup: dict[int, dict[tuple, int]] = {}  # lookup for interactions
        self._interaction_update_positions: dict[int, dict[int, np.ndarray[int]]] = {}  # lookup
        self._init_interaction_lookup_tables()

        # get the edge representation of the tree
        edge_tree = create_edge_tree(
            children_left=self._tree.children_left,
            children_right=self._tree.children_right,
            features=self._tree.features,
            node_sample_weight=self._tree.node_sample_weight,
            values=self._tree.values,
            max_interaction=self._max_order,
            n_features=self._max_feature_id + 1,
            n_nodes=self._n_nodes,
            subset_updates_pos_store=self._interaction_update_positions,
        )
        self._edge_tree: EdgeTree = copy.deepcopy(edge_tree)

        # compute the empty prediction
        computed_empty_prediction = float(
            np.sum(self._edge_tree.empty_predictions[self._tree.leaf_mask]),
        )
        tree_empty_prediction = self._tree.empty_prediction
        if tree_empty_prediction is None:
            tree_empty_prediction = computed_empty_prediction  # not covered
        self.empty_prediction: float = tree_empty_prediction

        # per-order stores for precomputed quantities (ancestors, interpolation points, N matrices)
        self.subset_ancestors_store: dict = {}
        self.D_store: dict = {}
        self.D_powers_store: dict = {}
        self.Ns_id_store: dict = {}
        self.Ns_store: dict = {}
        self.n_interpolation_size = self._n_features_in_tree
        if self._index in ("SV", "SII", "k-SII"):  # SP is of order at most d_max
            self.n_interpolation_size = min(self._edge_tree.max_depth, self._n_features_in_tree)
        try:
            self._init_summary_polynomials()
            self._trivial_computation = False
        except ValueError:
            if self._n_features_in_tree == 1:
                self._trivial_computation = True  # for one feature the computation is trivial
            else:
                raise  # not covered

        # stores the nodes that are active in the tree for a given instance (new for each instance)
        self._activations: np.ndarray = np.zeros(self._n_nodes, dtype=bool)

        # print tree information
        if self.verbose:
            self._print_tree_info()

    def explain(self, x: np.ndarray) -> InteractionValues:
        """Computes the Shapley Interaction values for a given instance ``x`` and interaction order.

        Note:
            This function is the main explanation function of this class.

        Args:
            x (np.ndarray): Instance to be explained.

        Returns:
            InteractionValues: The computed Shapley Interaction values.

        """
        x_relevant = x[self._relevant_features]
        n_players = max(x.shape[0], self._n_features_in_tree)

        if self._trivial_computation:
            interactions = self._compute_trivial_shapley_interaction_values(x)
        else:
            # compute the Shapley Interaction values
            interactions = np.asarray([], dtype=float)
            for order in range(self._min_order, self._max_order + 1):
                shapley_interactions = np.zeros(
                    int(sp.special.binom(self._n_features_in_tree, order)),
                    dtype=float,
                )
                self.shapley_interactions = shapley_interactions
                self._prepare_variables_for_order(interaction_order=order)
                self._compute_shapley_interaction_values(x_relevant, order=order, node_id=0)
                # append the computed Shapley Interaction values to the result
                interactions = np.append(interactions, self.shapley_interactions.copy())

        shapley_interaction_values = InteractionValues(
            values=interactions,
            index=self._base_index,
            min_order=self._min_order,
            max_order=self._max_order,
            n_players=n_players,
            estimated=False,
            interaction_lookup=self._interactions_lookup_relevant,
            baseline_value=self.empty_prediction,
        )

        return finalize_computed_interactions(
            shapley_interaction_values,
            target_index=self._index,
        )

    def _compute_trivial_shapley_interaction_values(self, x: np.ndarray) -> np.ndarray:
        """Computes the Shapley interactions for the case of only one feature in the tree.

        Computing the Shapley interactions for the case of only one feature in the tree is trivial
        since only the main effect of this feature is considered, i.e., the first order value of the
        single feature gets the full effect and all higher order values are zero.

        Args:
            x: The original instance to be explained.

        Returns:
            np.ndarray: The computed Shapley Interaction values.

        """
        full_prediction = self._tree.predict_one(x)
        main_effect = full_prediction - self.empty_prediction
        shapley_interactions = np.zeros(1, dtype=float)
        shapley_interactions[0] = main_effect
        return shapley_interactions

    def _compute_shapley_interaction_values(
        self,
        x: np.ndarray,
        order: int = 1,
        node_id: int = 0,
        *,
        summary_poly_down: np.ndarray[float] = None,
        summary_poly_up: np.ndarray[float] = None,
        interaction_poly_down: np.ndarray[float] = None,
        quotient_poly_down: np.ndarray[float] = None,
        depth: int = 0,
    ) -> None:
        """Computes the Shapley Interaction values for a given instance x and interaction order.

        Note:
            This function is called recursively for each node in the tree.

        Args:
            x: The instance to be explained.

            order: The interaction order for which the Shapley Interaction values should be
                computed. Defaults to ``1``.

            node_id: The node ID of the current node in the tree. Defaults to ``0``.

            summary_poly_down: The summary polynomial for the current node. Defaults to ``None``
                (at init time).

            summary_poly_up: The summary polynomial propagated up the tree. Defaults to ``None``
                (at init time).

            interaction_poly_down: The interaction polynomial for the current node. Defaults to
                ``None`` (at init time).

            quotient_poly_down: The quotient polynomial for the current node. Defaults to ``None``
                (at init time).

            depth: The depth of the current node in the tree. Defaults to ``0``.

        """
        # fmt: off
        # manually formatted for better readability in formulas and equations
        # reset activations for new calculations
        if node_id == 0:
            self._activations.fill(False)  # noqa: FBT003

        # get polynomials if None
        polynomials = self._get_polynomials(
            order=order,
            summary_poly_down=summary_poly_down,
            summary_poly_up=summary_poly_up,
            interaction_poly_down=interaction_poly_down,
            quotient_poly_down=quotient_poly_down,
        )
        summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down = polynomials

        # get related (surrounding) nodes
        left_child = int(self._tree.children_left[node_id])
        right_child = int(self._tree.children_right[node_id])
        parent_id = int(self._edge_tree.parents[node_id])
        ancestor_id = int(self._edge_tree.ancestors[node_id])

        # get feature information
        feature_id = int(self._tree.features[parent_id])
        feature_threshold = self._tree.thresholds[node_id]
        child_edge_feature = self._tree.features[node_id]

        # get height of related nodes
        current_height = int(self._edge_tree.edge_heights[node_id])
        left_height = int(self._edge_tree.edge_heights[left_child])
        right_height = int(self._edge_tree.edge_heights[right_child])

        # get path information
        is_leaf = bool(self._tree.leaf_mask[node_id])
        has_ancestor = bool(self._edge_tree.has_ancestors[node_id])
        activations = self._activations

        # if feature_id > -1:
        try:
            interaction_sets = self.subset_updates_pos[feature_id]
        except KeyError:
            interaction_sets = np.array([], dtype=int)

        # if node is not a leaf -> set activations for children nodes accordingly
        if not is_leaf:
            if x[child_edge_feature] <= feature_threshold:
                activations[left_child], activations[right_child] = True, False
            else:
                activations[left_child], activations[right_child] = False, True

        # if node is not the root node -> calculate the summary polynomials
        if node_id != self._root_node_id:
            # set activations of current node in relation to the ancestor (for setting p_e to zero)
            if has_ancestor:
                activations[node_id] &= activations[ancestor_id]
            # if node is active get the correct p_e value
            p_e_current = self._edge_tree.p_e_values[node_id] if activations[node_id] else 0.0
            # update summary polynomial
            summary_poly_down[depth] = summary_poly_down[depth - 1] * (self.D + p_e_current)
            # update quotient polynomials
            quotient_poly_down[depth, :] = quotient_poly_down[depth - 1, :].copy()
            quotient_poly_down[depth, interaction_sets] = quotient_poly_down[depth, interaction_sets] * (self.D + p_e_current)
            # update interaction polynomial
            interaction_poly_down[depth, :] = interaction_poly_down[depth - 1, :].copy()
            interaction_poly_down[depth, interaction_sets] = interaction_poly_down[depth, interaction_sets] * (-self.D + p_e_current)
            # remove previous polynomial factor if node has ancestors
            if has_ancestor:
                p_e_ancestor = 0.0
                if activations[ancestor_id]:
                    p_e_ancestor = self._edge_tree.p_e_values[ancestor_id]
                # rescale the polynomials
                summary_poly_down[depth] = summary_poly_down[depth] / (self.D + p_e_ancestor)
                quotient_poly_down[depth, interaction_sets] = quotient_poly_down[depth, interaction_sets] / (self.D + p_e_ancestor)
                interaction_poly_down[depth, interaction_sets] = interaction_poly_down[depth, interaction_sets] / (-self.D + p_e_ancestor)

        # if node is leaf -> multiply the summary polynomial by the leaf's empty prediction and store it
        if is_leaf:  # recursion base case
            summary_poly_up[depth] = (
                summary_poly_down[depth] * self._edge_tree.empty_predictions[node_id]
            )
        else:  # not a leaf -> continue recursion
            # left child
            self._compute_shapley_interaction_values(
                x,
                order=order,
                node_id=left_child,
                summary_poly_down=summary_poly_down,
                summary_poly_up=summary_poly_up,
                interaction_poly_down=interaction_poly_down,
                quotient_poly_down=quotient_poly_down,
                depth=depth + 1,
            )
            summary_poly_up[depth] = (
                summary_poly_up[depth + 1] * self.D_powers[current_height - left_height]
            )
            # right child
            self._compute_shapley_interaction_values(
                x,
                order=order,
                node_id=right_child,
                summary_poly_down=summary_poly_down,
                summary_poly_up=summary_poly_up,
                interaction_poly_down=interaction_poly_down,
                quotient_poly_down=quotient_poly_down,
                depth=depth + 1,
            )
            summary_poly_up[depth] += (
                summary_poly_up[depth + 1] * self.D_powers[current_height - right_height]
            )

        # if node is not the root node -> calculate the Shapley Interaction values for the node
        if node_id != self._root_node_id:
            interactions_seen = interaction_sets[
                self._int_height[node_id][interaction_sets] == order
            ]
            if len(interactions_seen) > 0:
                if self._index not in ("SV", "SII", "k-SII"):  # for CII
                    D_power = self.D_powers[self._n_features_in_tree - current_height]
                    index_quotient = self._n_features_in_tree - order
                else:  # for SV, SII, and k-SII
                    D_power = self.D_powers[0]
                    index_quotient = current_height - order
                interaction_update = np.dot(
                    interaction_poly_down[depth, interactions_seen],
                    self.Ns_id[self.n_interpolation_size, : self.n_interpolation_size],
                )
                interaction_update *= self._psi(
                    summary_poly_up[depth, :],
                    D_power,
                    quotient_poly_down[depth, interactions_seen],
                    self.Ns,
                    index_quotient,
                )
                self.shapley_interactions[interactions_seen] += interaction_update

            # if node has ancestors -> adjust the Shapley Interaction values for the node
            ancestors_of_interactions = self.subset_ancestors[node_id][interaction_sets]
            if np.any(ancestors_of_interactions > -1):  # at least one ancestor exists (not -1)
                ancestor_node_id_exists = ancestors_of_interactions > -1  # get mask of ancestors
                interactions_with_ancestor = interaction_sets[ancestor_node_id_exists]
                cond_interaction_seen = (
                    self._int_height[parent_id][interactions_with_ancestor] == order
                )
                interactions_ancestors = ancestors_of_interactions[ancestor_node_id_exists]
                interactions_with_ancestor_to_update = interactions_with_ancestor[
                    cond_interaction_seen
                ]
                if len(interactions_with_ancestor_to_update) > 0:
                    ancestor_heights = self._edge_tree.edge_heights[
                        interactions_ancestors[cond_interaction_seen]
                    ]
                    if self._index not in ("SV", "SII", "k-SII"):  # for CII
                        D_power = self.D_powers[self._n_features_in_tree - current_height]
                        index_quotient = self._n_features_in_tree - order
                    else:  # for SV, SII, and k-SII
                        D_power = self.D_powers[ancestor_heights - current_height]
                        index_quotient = ancestor_heights - order
                    update = np.dot(
                        interaction_poly_down[depth - 1, interactions_with_ancestor_to_update],
                        self.Ns_id[self.n_interpolation_size, : self.n_interpolation_size],
                    )
                    to_update = self._psi_ancestor(
                        summary_poly_up[depth],
                        D_power,
                        quotient_poly_down[depth - 1, interactions_with_ancestor_to_update],
                        self.Ns,
                        index_quotient,
                    )
                    if to_update.shape == (1, 1):
                        update *= to_update[0]  # cast out shape of (1, 1) to float
                    else:
                        update *= to_update  # something errors here for CII
                    # fmt: on
                    self.shapley_interactions[interactions_with_ancestor_to_update] -= update

    @staticmethod
    def _psi_ancestor(
        E: np.ndarray,
        D_power: np.ndarray,
        quotient_poly: np.ndarray,
        Ns: np.ndarray,
        degree: int,
    ) -> np.ndarray:
        """Variant of :meth:`_psi` used for the ancestor correction.

        ``D_power`` and ``degree`` can hold one entry per interaction; the product is computed for
        all interactions at once and the matching entries are taken from its diagonal.

        """
        d = degree + 1
        n = Ns[d].T  # Variant of _psi that can deal with multiple inputs in degree
        return np.diag((E * D_power / quotient_poly).dot(n)) / d

    @staticmethod
    def _psi(
        E: np.ndarray,
        D_power: np.ndarray,
        quotient_poly: np.ndarray,
        Ns: np.ndarray,
        degree: int,
    ) -> np.ndarray[float]:
        """Computes the psi function for the TreeSHAP-IQ algorithm.

        It scales the interaction polynomials with the summary polynomial and the quotient
        polynomial. For details, refer to `Muschalik et al. (2024) <https://doi.org/10.48550/arXiv.2401.12069>`_.

        Args:
            E: The summary polynomial.
            D_power: The precomputed power of the interpolation points ``D``.
            quotient_poly: The quotient polynomial.
            Ns: The precomputed ``Ns`` matrix for the current order.
            degree: The degree of the interaction polynomial.

        Returns:
            np.ndarray: The computed psi function.

        """
        d = degree + 1
        n = Ns[d, :d]
        return ((E * D_power / quotient_poly)[:, :d]).dot(n) / d

    def _init_summary_polynomials(self) -> None:
        """Initializes the summary polynomial variables.

        Note:
            This function is called once during the initialization of the explainer.

        """
        for order in range(1, self._max_order + 1):
            subset_ancestors: dict[int, np.ndarray] = self._precalculate_interaction_ancestors(
                interaction_order=order,
                n_features=self._n_features_in_tree,
            )
            self.subset_ancestors_store[order] = subset_ancestors

            # interpolation points (Chebyshev points of the second kind); chebpts2 raises a
            # ValueError for fewer than two points, e.g. for a tree that uses a single feature,
            # which __init__ handles as the trivial case
            self.D_store[order] = np.polynomial.chebyshev.chebpts2(self.n_interpolation_size)

            self.D_powers_store[order] = self._cache(self.D_store[order])
            if self._index in ("SV", "SII", "k-SII"):
                self.Ns_store[order] = self._get_n_matrix(self.D_store[order])
            else:
                self.Ns_store[order] = self._get_n_cii_matrix(self.D_store[order], order)
            self.Ns_id_store[order] = self._get_n_id_matrix(self.D_store[order])

    def _get_polynomials(
        self,
        order: int,
        summary_poly_down: np.ndarray[float] | None = None,
        summary_poly_up: np.ndarray[float] | None = None,
        interaction_poly_down: np.ndarray[float] | None = None,
        quotient_poly_down: np.ndarray[float] | None = None,
    ) -> tuple[np.ndarray[float], np.ndarray[float], np.ndarray[float], np.ndarray[float]]:
        """Retrieves the polynomials for a given interaction order.

        Any polynomial passed as ``None`` (the case for the first, root-level call of the recursive
        explanation function) is initialized; all others are returned unchanged.

        Args:
            order: The interaction order for which the polynomials should be loaded.

            summary_poly_down: The summary polynomial for the current node. Defaults to ``None``.

            summary_poly_up: The summary polynomial propagated up the tree. Defaults to ``None``.

            interaction_poly_down: The interaction polynomial for the current node. Defaults to
                ``None``.

            quotient_poly_down: The quotient polynomial for the current node. Defaults to ``None``.

        Returns:
            The summary polynomial down, the summary polynomial up, the interaction polynomial down,
                and the quotient polynomial down.

        """
        if summary_poly_down is None:
            summary_poly_down = np.zeros((self._edge_tree.max_depth + 1, self.n_interpolation_size))
            summary_poly_down[0, :] = 1
        if summary_poly_up is None:
            summary_poly_up = np.zeros((self._edge_tree.max_depth + 1, self.n_interpolation_size))
        if interaction_poly_down is None:
            interaction_poly_down = np.zeros(
                (
                    self._edge_tree.max_depth + 1,
                    int(sp.special.binom(self._n_features_in_tree, order)),
                    self.n_interpolation_size,
                ),
            )
            interaction_poly_down[0, :] = 1
        if quotient_poly_down is None:
            quotient_poly_down = np.zeros(
                (
                    self._edge_tree.max_depth + 1,
                    int(sp.special.binom(self._n_features_in_tree, order)),
                    self.n_interpolation_size,
                ),
            )
            quotient_poly_down[0, :] = 1
        return summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down

    def _prepare_variables_for_order(self, interaction_order: int) -> None:
        """Retrieves the precomputed variables for a given interaction order.

        This function is called before the recursive explanation function is called.

        Args:
            interaction_order (int): The interaction order for which the storage variables should be
                loaded.

        """
        self.subset_updates_pos = self._interaction_update_positions[interaction_order]
        self.subset_ancestors = self.subset_ancestors_store[interaction_order]
        self.D = self.D_store[interaction_order]
        self.D_powers = self.D_powers_store[interaction_order]
        self._int_height = self._edge_tree.interaction_height_store[interaction_order]
        self.Ns_id = self.Ns_id_store[interaction_order]
        self.Ns = self.Ns_store[interaction_order]

    def _init_interaction_lookup_tables(self) -> None:
        """Initializes the lookup tables for the interaction subsets."""
        for order in range(1, self._max_order + 1):
            order_interactions_lookup = generate_interaction_lookup(
                self._n_features_in_tree,
                order,
                order,
            )
            self._interactions_lookup[order] = order_interactions_lookup
            _, interaction_update_positions = self._precompute_subsets_with_feature(
                interaction_order=order,
                n_features=self._n_features_in_tree,
                order_interactions_lookup=order_interactions_lookup,
            )
            self._interaction_update_positions[order] = interaction_update_positions

603
    @staticmethod
    def _precompute_subsets_with_feature(
        n_features: int,
        interaction_order: int,
        order_interactions_lookup: dict[tuple, int],
    ) -> tuple[dict[int, list[tuple]], dict[int, np.ndarray[int]]]:
        """Precomputes the subsets of interactions that include a given feature.

        Args:
            n_features: The number of features in the model.
            interaction_order: The interaction order to be computed.
            order_interactions_lookup: The lookup table of interaction subsets to their positions
                in the interaction values array for a given interaction order (e.g. all 2-way
                interactions for order ``2``).

        Returns:
            interaction_updates: A dictionary (lookup table) containing the interaction subsets
                for each feature given an interaction order.
            interaction_update_positions: A dictionary (lookup table) containing the positions of
                the interaction subsets to update for each feature given an interaction order.

        """
        # stores interactions that include feature i (needs to be updated when feature i appears)
        interaction_updates: dict[int, list[tuple]] = {}
        # stores position of interactions that include feature i
        interaction_update_positions: dict[int, np.ndarray] = {}

        # prepare the interaction updates and positions
        for feature_i in range(n_features):
            positions = np.zeros(
                int(sp.special.binom(n_features - 1, interaction_order - 1)),
                dtype=int,
            )
            interaction_update_positions[feature_i] = positions.copy()
            interaction_updates[feature_i] = []

        # fill the interaction updates and positions
        position_counter = np.zeros(n_features, dtype=int)  # next free slot for each feature
        for interaction in powerset(
            range(n_features),
            min_size=interaction_order,
            max_size=interaction_order,
        ):
            for i in interaction:
                interaction_updates[i].append(interaction)
                position = position_counter[i]
                interaction_update_positions[i][position] = order_interactions_lookup[interaction]
                position_counter[i] += 1

        return interaction_updates, interaction_update_positions

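The position bookkeeping in _precompute_subsets_with_feature can be illustrated with a small standalone sketch. It assumes that powerset(..., min_size=k, max_size=k) yields subsets in the same lexicographic order as itertools.combinations, and it uses a hypothetical lookup table that simply numbers the subsets in that order; the real table comes from generate_interaction_lookup, whose ordering is not shown here. The function name precompute_update_positions is illustrative only.

```python
from itertools import combinations
from math import comb

import numpy as np


def precompute_update_positions(n_features, interaction_order, lookup):
    """For each feature i, collect the lookup positions of all size-k subsets containing i."""
    # each feature appears in C(n - 1, k - 1) of the C(n, k) subsets of size k
    size = comb(n_features - 1, interaction_order - 1)
    positions = {i: np.zeros(size, dtype=int) for i in range(n_features)}
    counter = np.zeros(n_features, dtype=int)  # next free slot in each feature's array
    # assumption: same subset order as powerset(..., min_size=k, max_size=k)
    for subset in combinations(range(n_features), interaction_order):
        for i in subset:
            positions[i][counter[i]] = lookup[subset]
            counter[i] += 1
    return positions


n, k = 4, 2
# hypothetical lookup table: subsets numbered in lexicographic order
lookup = {S: idx for idx, S in enumerate(combinations(range(n), k))}
positions = precompute_update_positions(n, k, lookup)
print(positions[1])  # subsets (0, 1), (1, 2), (1, 3) -> [0 3 4]
```

Feature 1 appears in (0, 1), (1, 2) and (1, 3), which the hypothetical table numbers 0, 3 and 4, so every feature's array has exactly C(3, 1) = 3 entries.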
    def _precalculate_interaction_ancestors(
        self,
        interaction_order: int,
        n_features: int,
    ) -> dict[int, np.ndarray]:
        """Computes the ancestors of the interactions for a given order of interactions.

        For every non-root node and every interaction ``S`` of the given order, the stored
        ancestor is the maximum of ``ancestor_nodes[node_id][feature]`` over all features in
        ``S``, starting from ``-1``.

        Args:
            interaction_order: The interaction order for which the ancestors should be computed.
            n_features: The number of features in the model.

        Returns:
            subset_ancestors: A dictionary mapping each non-root node id to an array with one
                ancestor entry per interaction, in ``powerset`` order.

        """
        # maps each non-root node to one ancestor entry per interaction
        subset_ancestors: dict[int, np.ndarray] = {}

        for node_id in self._tree.nodes[1:]:  # for all nodes except the root node
            subset_ancestors[node_id] = np.full(
                int(sp.special.binom(n_features, interaction_order)), -1, dtype=int
            )
        for i, S in enumerate(powerset(range(n_features), interaction_order, interaction_order)):
            for node_id in self._tree.nodes[1:]:  # for all nodes except the root node
                subset_ancestor = -1
                for feature in S:
                    subset_ancestor = max(
                        subset_ancestor,
                        self._edge_tree.ancestor_nodes[node_id][feature],
                    )
                subset_ancestors[node_id][i] = subset_ancestor
        return subset_ancestors

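As a sketch of what _precalculate_interaction_ancestors stores, the standalone function below applies the same rule to a hypothetical ancestor table: for each node and each subset S of the given size, it keeps the largest ancestor entry over the features in S, starting from -1. The table layout (node id mapped to one entry per feature, with -1 meaning no ancestor) and the function name max_ancestor_per_subset are assumptions for illustration; in the explainer the values come from self._edge_tree.ancestor_nodes, and subsets are assumed to be enumerated in itertools.combinations order.

```python
from itertools import combinations

import numpy as np


def max_ancestor_per_subset(ancestor_nodes, nodes, n_features, order):
    """For each node, keep the largest ancestor entry over the features of each subset."""
    subsets = list(combinations(range(n_features), order))
    result = {}
    for node_id in nodes:
        row = np.full(len(subsets), -1, dtype=int)
        for i, S in enumerate(subsets):
            # start from -1 and take the maximum over the features in S
            row[i] = max([-1, *(ancestor_nodes[node_id][f] for f in S)])
        result[node_id] = row
    return result


# hypothetical ancestor table: node id -> one entry per feature (-1 = no ancestor)
ancestor_nodes = {
    1: np.array([0, -1, -1]),
    2: np.array([0, 1, -1]),
}
result = max_ancestor_per_subset(ancestor_nodes, nodes=[1, 2], n_features=3, order=2)
print(result[2])  # subsets (0, 1), (0, 2), (1, 2) -> [1 0 1]
```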
    @staticmethod
    def _get_n_matrix(interpolated_poly: np.ndarray) -> np.ndarray:
        """Computes the N matrix for the Shapley interaction values.

        Args:
            interpolated_poly: The interpolation points of the polynomial.

        Returns:
            A ``(depth + 1, depth)`` array. Row ``i`` (for ``i >= 1``) holds in its first ``i``
            entries the solution ``x`` of ``vander(interpolated_poly[:i]).T @ x = b`` with
            ``b[k] = 1 / binom(i - 1, k)`` for ``k = 0, ..., i - 1``; row ``0`` is all zeros.

        """
        depth = interpolated_poly.shape[0]
        Ns = np.zeros((depth + 1, depth))
        for i in range(1, depth + 1):
            Ns[i, :i] = np.linalg.inv(np.vander(interpolated_poly[:i]).T).dot(
                1.0 / np.array([sp.special.binom(i - 1, k) for k in range(i)])
            )
        return Ns

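_get_n_matrix fills row i with the solution of a small transposed Vandermonde system built on the first i interpolation points, with right-hand side 1 / binom(i - 1, k). The sketch below reproduces that construction on hypothetical points and checks that a row satisfies its system. It swaps the explicit np.linalg.inv(...).dot(...) for np.linalg.solve, which yields the same solution without forming the inverse; the function name n_matrix_sketch is illustrative.

```python
import numpy as np
from scipy.special import binom


def n_matrix_sketch(points):
    """Row i solves vander(points[:i]).T @ x = 1 / binom(i - 1, k), k = 0..i-1."""
    depth = points.shape[0]
    Ns = np.zeros((depth + 1, depth))
    for i in range(1, depth + 1):
        V = np.vander(points[:i])  # columns are points**(i-1), ..., points**0
        b = 1.0 / np.array([binom(i - 1, k) for k in range(i)])
        # solve() replaces the explicit inverse used in the original code
        Ns[i, :i] = np.linalg.solve(V.T, b)
    return Ns


points = np.array([0.5, -0.5, 0.0])  # hypothetical distinct interpolation points
Ns = n_matrix_sketch(points)
i = 3
V = np.vander(points[:i])
print(np.allclose(V.T @ Ns[i, :i], 1.0 / binom(i - 1, np.arange(i))))  # -> True
```

For the first two points the system reads 0.5x - 0.5y = 1 and x + y = 1, so row 2 is [1.5, -0.5]. The points must be distinct, otherwise the Vandermonde matrix is singular.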
    def _get_n_cii_matrix(self, interpolated_poly: np.ndarray, order: int) -> np.ndarray:
        """Computes the N matrix for the STII, FSII and BII indices.

        Same construction as ``_get_n_matrix``, but the right-hand side entry for subset size
        ``j`` is ``i * _get_subset_weight_cii(j, order)``.
        """
        depth = interpolated_poly.shape[0]
        Ns = np.zeros((depth + 1, depth))
        for i in range(1, depth + 1):
            Ns[i, :i] = np.linalg.inv(np.vander(interpolated_poly[:i]).T).dot(
                i * np.array([self._get_subset_weight_cii(j, order) for j in range(i)]),
            )
        return Ns

    def _get_subset_weight_cii(self, t: int, order: int) -> float | None:
        """Computes the weight for a given subset size and interaction order.

        Args:
            t: The size of the subset.
            order: The interaction order.

        Returns:
            float | None: The weight for the subset, or ``None`` if the index is not one of
                ``"STII"``, ``"FSII"`` or ``"BII"``.

        """
        if self._index == "STII":
            return self._max_order / (
                self._n_features_in_tree * sp.special.binom(self._n_features_in_tree - 1, t)
            )
        if self._index == "FSII":
            return (
                factorial(2 * self._max_order - 1)
                / factorial(self._max_order - 1) ** 2
                * factorial(self._max_order + t - 1)
                * factorial(self._n_features_in_tree - t - 1)
                / factorial(self._n_features_in_tree + self._max_order - 1)
            )
        if self._index == "BII":
            return 1 / (2 ** (self._n_features_in_tree - order))
        return None

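The three weight formulas in _get_subset_weight_cii can be evaluated outside the class. The standalone function below copies them, with n standing for self._n_features_in_tree and max_order for self._max_order (the function and parameter names are illustrative). A useful sanity check: for max_order = 1 the STII and FSII weights both reduce to the Shapley weight t!(n - t - 1)!/n!.

```python
from math import comb, factorial


def subset_weight(index, t, order, n, max_order):
    """Copy of the STII / FSII / BII weight formulas; n = number of features in the tree."""
    if index == "STII":
        return max_order / (n * comb(n - 1, t))
    if index == "FSII":
        return (
            factorial(2 * max_order - 1)
            / factorial(max_order - 1) ** 2
            * factorial(max_order + t - 1)
            * factorial(n - t - 1)
            / factorial(n + max_order - 1)
        )
    if index == "BII":
        return 1 / 2 ** (n - order)
    return None  # unsupported index


print(subset_weight("FSII", t=1, order=2, n=4, max_order=2))  # 3!/1!^2 * 2! * 2! / 5! -> 0.2
print(subset_weight("BII", t=1, order=2, n=4, max_order=2))  # 1 / 2**2 -> 0.25
```

For STII with n = 4, max_order = 2 and t = 1 the weight is 2 / (4 * 3) = 1/6.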
    @staticmethod
    def _get_n_id_matrix(D: np.ndarray) -> np.ndarray:
        """Computes the N_id matrix.

        Row ``i`` (for ``i >= 1``) holds in its first ``i`` entries the solution ``x`` of
        ``vander(D[:i]).T @ x = 1`` (all-ones right-hand side); row ``0`` is all zeros.
        """
        depth = D.shape[0]
        Ns_id = np.zeros((depth + 1, depth))
        for i in range(1, depth + 1):
            Ns_id[i, :i] = np.linalg.inv(np.vander(D[:i]).T).dot(np.ones(i))
        return Ns_id

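Because row i of _get_n_id_matrix solves vander(D[:i]).T @ x = 1, the weights satisfy sum_j x_j * D_j**p = 1 = 1**p for every power p < i. They therefore evaluate any polynomial of degree below i at z = 1 from its values at D[:i], which means x_j is the j-th Lagrange basis polynomial on D[:i] evaluated at 1. The sketch below checks this on hypothetical points, again using np.linalg.solve in place of the explicit inverse; the helper names are illustrative.

```python
import numpy as np


def n_id_matrix_sketch(D):
    """Row i solves vander(D[:i]).T @ x = ones(i)."""
    depth = D.shape[0]
    Ns_id = np.zeros((depth + 1, depth))
    for i in range(1, depth + 1):
        Ns_id[i, :i] = np.linalg.solve(np.vander(D[:i]).T, np.ones(i))
    return Ns_id


def lagrange_at_one(points):
    """Value at z = 1 of each Lagrange basis polynomial built on `points`."""
    out = np.ones(len(points))
    for j, pj in enumerate(points):
        for m, pm in enumerate(points):
            if m != j:
                out[j] *= (1.0 - pm) / (pj - pm)
    return out


D = np.array([0.5, -0.5, 0.0])  # hypothetical distinct points
Ns_id = n_id_matrix_sketch(D)
print(np.allclose(Ns_id[3], lagrange_at_one(D)))  # -> True
```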
    @staticmethod
    def _cache(interpolated_poly: np.ndarray[float]) -> np.ndarray[float]:
        """Precomputes the powers of the shifted interpolation points.

        Args:
            interpolated_poly: The interpolation points of the polynomial.

        Returns:
            An array whose row ``k`` holds ``(interpolated_poly + 1) ** k`` for
            ``k = 0, ..., len(interpolated_poly) - 1``.

        """
        return np.vander(interpolated_poly + 1).T[::-1]

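_cache relies on the column order of np.vander, which lists the highest power first; transposing and reversing the rows puts the powers of (interpolated_poly + 1) in increasing order. A quick check against explicit powers and against np.vander(..., increasing=True), on hypothetical points:

```python
import numpy as np


def cached_powers(points):
    # np.vander puts the highest power in the first column; .T[::-1] makes row k the k-th power
    return np.vander(points + 1).T[::-1]


points = np.array([0.5, -0.5, 0.0])  # hypothetical interpolation points
P = cached_powers(points)
expected = np.array([(points + 1) ** k for k in range(len(points))])
print(np.allclose(P, expected))  # -> True
```

The same matrix can be written as np.vander(points + 1, increasing=True).T, which avoids the row reversal.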
    def _print_tree_info(self) -> None:
        """Prints information about the tree to be explained."""
        information = "Tree information:"
        information += f"\nNumber of nodes: {self._n_nodes}"
        information += f"\nNumber of features: {self._n_features_in_tree}"
        information += f"\nMaximum interaction order: {self._max_order}"
        information += f"\nInteraction index: {self._index}"
        information += f"\nEmpty prediction (from _tree): {self._tree.empty_prediction}"
        information += f"\nEmpty prediction (from self): {self.empty_prediction}"

© 2026 Coveralls, Inc