
mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%

push | github | web-flow

Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed introduced bugs in game and interaction_values

* Pyright Save Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer. Approximator, Imputer and ExactComputer can either exist or not, depending on dynamic Type

* Bug fix caused through refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step to reinstall Python on Windows due to issues with tkinter. The linked GitHub issue was solved. Doing this as a first try.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests. prior commit did not help

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line

Source File: /src/shapiq/explainer/tree/treeshapiq.py (98.18% covered)

"""Implementation of the tree explainer."""

from __future__ import annotations

import copy
from math import factorial
from typing import TYPE_CHECKING, Literal

import numpy as np
import scipy as sp

from shapiq.game_theory.indices import get_computation_index
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import generate_interaction_lookup, powerset

from .conversion.edges import create_edge_tree
from .validation import validate_tree_model

if TYPE_CHECKING:
    from shapiq.typing import FloatVector, IntVector, Model  # not covered (marked NEW)

    from .base import EdgeTree, TreeModel  # not covered


TreeSHAPIQIndices = Literal["SV", "SII", "k-SII"]


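The two imports guarded by `if TYPE_CHECKING:` are the only import lines this report flags as not covered. That is expected: `typing.TYPE_CHECKING` is `False` at runtime, so only static checkers such as Pyright (which this commit adds to pre-commit) execute them, and `from __future__ import annotations` keeps every annotation as an unevaluated string. A minimal standalone sketch of the same pattern, with the standard-library `decimal.Decimal` standing in for the shapiq types:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Executed only by static type checkers, never at runtime, so coverage
    # tools report this import as not covered.
    from decimal import Decimal


def halve(value: Decimal) -> Decimal:
    """Annotations stay unevaluated strings, so Decimal is never needed at runtime."""
    return value / 2
```

At runtime `halve.__annotations__` holds the strings `"Decimal"`, and the function still works on real `Decimal` objects passed in by the caller.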
class TreeSHAPIQ:
    """The TreeSHAP-IQ computation class.

    This class implements the TreeSHAP-IQ algorithm for computing Shapley Interaction values for
    tree-based models. It is used internally by the
    :class:`~shapiq.explainer.tree.explainer.TreeExplainer`. The TreeSHAP-IQ algorithm is presented
    in `Muschalik et al. (2024)` [Mus24]_.

    TreeSHAP-IQ extends the Linear TreeSHAP algorithm by `Yu et al. (2022)` [Yu22]_ to compute
    Shapley Interaction values up to a given order. It visits each node only once and uses
    polynomial arithmetic to compute the Shapley Interaction values efficiently.

    Note:
        This class is not intended to be used directly. Instead, use the ``TreeExplainer`` class to
        explain tree-based models; it uses the TreeSHAP-IQ algorithm internally.

    References:
        .. [Yu22] Peng Yu, Chao Xu, Albert Bifet, & Jesse Read (2022). Linear Tree Shap. In: Proceedings of the 36th Conference on Neural Information Processing Systems. https://openreview.net/forum?id=OzbkiUo24g
        .. [Mus24] Maximilian Muschalik, Fabian Fumagalli, Barbara Hammer, & Eyke Hüllermeier (2024). Beyond TreeSHAP: Efficient Computation of Any-Order Shapley Interactions for Tree Ensembles. In: Proceedings of the AAAI Conference on Artificial Intelligence, 38(13), 14388-14396. https://doi.org/10.1609/aaai.v38i13.29352

    """

    def __init__(
        self,
        model: dict | TreeModel | Model,
        *,
        max_order: int = 2,
        min_order: int = 1,
        index: TreeSHAPIQIndices = "k-SII",
        verbose: bool = False,
    ) -> None:
        """Initializes the TreeSHAP-IQ explainer.

        Args:
            model: A single tree model to explain. Note that unlike the
                :class:`~shapiq.explainer.tree.explainer.TreeExplainer` class, TreeSHAP-IQ only
                supports a single tree. It can be a dictionary representation of the tree, a
                :class:`~shapiq.explainer.tree.base.TreeModel` object, or any other single tree
                model supported by the :meth:`~shapiq.explainer.tree.validation.validate_tree_model`
                function.

            max_order: The maximum interaction order to be computed. An interaction order of ``1``
                corresponds to the Shapley value. Any value higher than ``1`` computes the Shapley
                interaction values up to that order. Defaults to ``2``.

            min_order: The minimum interaction order to be computed. Defaults to ``1``. Note that
                setting min_order currently does not have any effect on the computation.

            index: The type of interaction index to be computed. One of ``"SV"``, ``"SII"``, or
                ``"k-SII"``. Defaults to ``"k-SII"``.

            verbose: Whether to print information about the tree during initialization. Defaults to
                ``False``.

        """
        # set parameters
        self._root_node_id = 0
        self.verbose = verbose
        if max_order < min_order or max_order < 1 or min_order < 1:
            msg = (
                "The maximum order must be greater than or equal to the minimum order and both "
                "must be greater than 0."
            )
            raise ValueError(msg)
        self._max_order: int = max_order
        self._min_order: int = min_order
        self._index: str = index
        self._base_index: str = get_computation_index(self._index)

        # validate and parse model
        validated_model = validate_tree_model(model)  # the parsed and validated model
        # TODO(mmshlk): add support for other sample weights https://github.com/mmschlk/shapiq/issues/99
        self._tree: TreeModel = copy.deepcopy(validated_model)[0]
        self._relevant_features: np.ndarray = np.array(list(self._tree.feature_ids), dtype=int)
        self._tree.reduce_feature_complexity()
        self._n_nodes: int = self._tree.n_nodes
        self._n_features_in_tree: int = self._tree.n_features_in_tree
        self._max_feature_id: int = self._tree.max_feature_id
        self._feature_ids: set = self._tree.feature_ids

        # precompute interaction lookup tables
        self._interactions_lookup_relevant: dict[tuple, int] = generate_interaction_lookup(
            self._relevant_features,
            self._min_order,
            self._max_order,
        )
        self._interactions_lookup: dict[int, dict[tuple, int]] = {}  # lookup for interactions
        self._interaction_update_positions: dict[int, dict[int, IntVector]] = {}  # lookup
        self._init_interaction_lookup_tables()

        # get the edge representation of the tree
        edge_tree = create_edge_tree(
            children_left=self._tree.children_left,
            children_right=self._tree.children_right,
            features=self._tree.features,
            node_sample_weight=self._tree.node_sample_weight,
            values=self._tree.values,
            max_interaction=self._max_order,
            n_features=self._max_feature_id + 1,
            n_nodes=self._n_nodes,
            subset_updates_pos_store=self._interaction_update_positions,
        )
        self._edge_tree: EdgeTree = copy.deepcopy(edge_tree)

        # compute the empty prediction
        computed_empty_prediction = float(
            np.sum(self._edge_tree.empty_predictions[self._tree.leaf_mask]),
        )
        tree_empty_prediction = self._tree.empty_prediction
        if tree_empty_prediction is None:
            tree_empty_prediction = computed_empty_prediction  # not covered
        self.empty_prediction: float = tree_empty_prediction

        # stores the interaction scores up to a given order
        self.subset_ancestors_store: dict = {}
        self.D_store: dict = {}
        self.D_powers_store: dict = {}
        self.Ns_id_store: dict = {}
        self.Ns_store: dict = {}
        self.n_interpolation_size = self._n_features_in_tree
        if self._index in ("SV", "SII", "k-SII"):  # SP is of order at most d_max
            self.n_interpolation_size = min(self._edge_tree.max_depth, self._n_features_in_tree)
        try:
            self._init_summary_polynomials()
            self._trivial_computation = False
        except ValueError:
            if self._n_features_in_tree == 1:
                self._trivial_computation = True  # for one feature the computation is trivial
            else:
                raise  # not covered

        # stores the nodes that are active in the tree for a given instance (new for each instance)
        self._activations: np.ndarray = np.zeros(self._n_nodes, dtype=bool)

        # print tree information
        if self.verbose:
            self._print_tree_info()

    def explain(self, x: np.ndarray) -> InteractionValues:
        """Computes the Shapley Interaction values for a given instance ``x``.

        Values are computed for each interaction order from ``min_order`` to ``max_order``.

        Note:
            This function is the main explanation function of this class.

        Args:
            x (np.ndarray): Instance to be explained.

        Returns:
            InteractionValues: The computed Shapley Interaction values.

        """
        x_relevant = x[self._relevant_features]
        n_players = max(x.shape[0], self._n_features_in_tree)

        if self._trivial_computation:
            interactions = self._compute_trivial_shapley_interaction_values(x)
        else:
            # compute the Shapley Interaction values
            interactions = np.asarray([], dtype=float)
            for order in range(self._min_order, self._max_order + 1):
                shapley_interactions = np.zeros(
                    int(sp.special.binom(self._n_features_in_tree, order)),
                    dtype=float,
                )
                self.shapley_interactions = shapley_interactions
                self._prepare_variables_for_order(interaction_order=order)
                self._compute_shapley_interaction_values(x_relevant, order=order, node_id=0)
                # append the computed Shapley Interaction values to the result
                interactions = np.append(interactions, self.shapley_interactions.copy())

        return InteractionValues(
            values=interactions,
            index=self._base_index,
            min_order=self._min_order,
            max_order=self._max_order,
            n_players=n_players,
            estimated=False,
            interaction_lookup=self._interactions_lookup_relevant,
            baseline_value=self.empty_prediction,
            target_index=self._index,
        )

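`explain` fills one block of `binom(n_features_in_tree, order)` values per order and concatenates the blocks from `min_order` to `max_order`, while `interaction_lookup` tells `InteractionValues` which coalition sits at which position. The sketch below shows that layout with a hypothetical `build_interaction_lookup` helper; it assumes `generate_interaction_lookup` enumerates coalitions order by order in `itertools.combinations` order, which this listing does not show.

```python
from itertools import combinations
from math import comb


def build_interaction_lookup(players, min_order, max_order):
    """Hypothetical stand-in: map each coalition tuple to its position in the values array."""
    lookup = {}
    for order in range(min_order, max_order + 1):
        # one contiguous block per order, mirroring the concatenation loop in explain()
        for coalition in combinations(players, order):
            lookup[coalition] = len(lookup)
    return lookup


lookup = build_interaction_lookup(range(3), min_order=1, max_order=2)
# each block has binom(n, order) entries, so the blocks tile the whole values array
block_sizes = [comb(3, order) for order in range(1, 3)]
```

With three players and orders 1 to 2, singletons occupy positions 0 to 2 and pairs occupy positions 3 to 5.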
    def _compute_trivial_shapley_interaction_values(self, x: np.ndarray) -> np.ndarray:
        """Computes the Shapley interactions for the case of only one feature in the tree.

        Computing the Shapley interactions for the case of only one feature in the tree is trivial
        since only the main effect of this feature is considered, i.e., the first order value of the
        single feature gets the full effect and all higher order values are zero.

        Args:
            x: The original instance to be explained.

        Returns:
            np.ndarray: The computed Shapley Interaction values.

        """
        full_prediction = self._tree.predict_one(x)
        main_effect = full_prediction - self.empty_prediction
        shapley_interactions = np.zeros(1, dtype=float)
        shapley_interactions[0] = main_effect
        return shapley_interactions

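For small trees, TreeSHAP-IQ output can be cross-checked by brute force. The sketch assumes TreeSHAP's path-dependent value function (a feature outside the coalition is marginalized by averaging its children with node sample weights) and stores a hypothetical two-feature tree in plain arrays named after the `TreeModel` attributes used in this file; whether shapiq's `TreeModel` encodes leaves the same way is an assumption. With a single feature the formula collapses to the shortcut above: the only player receives the full prediction minus the empty prediction.

```python
from itertools import combinations
from math import factorial

import numpy as np

# Hypothetical toy tree: node 0 splits on feature 0, node 1 splits on feature 1,
# nodes 2-4 are leaves (-1 marks a missing child).
CHILDREN_LEFT = np.array([1, 3, -1, -1, -1])
CHILDREN_RIGHT = np.array([2, 4, -1, -1, -1])
FEATURES = np.array([0, 1, -2, -2, -2])
THRESHOLDS = np.array([0.5, 0.5, np.nan, np.nan, np.nan])
VALUES = np.array([0.0, 0.0, 10.0, 1.0, 5.0])
NODE_SAMPLE_WEIGHT = np.array([100.0, 60.0, 40.0, 30.0, 30.0])


def path_dependent_value(x, coalition, node=0):
    """v(S): follow x on features in S, otherwise average the children by sample weight."""
    left, right = int(CHILDREN_LEFT[node]), int(CHILDREN_RIGHT[node])
    if left == -1:  # leaf
        return float(VALUES[node])
    feature = int(FEATURES[node])
    if feature in coalition:
        child = left if x[feature] <= THRESHOLDS[node] else right
        return path_dependent_value(x, coalition, child)
    weighted = (
        NODE_SAMPLE_WEIGHT[left] * path_dependent_value(x, coalition, left)
        + NODE_SAMPLE_WEIGHT[right] * path_dependent_value(x, coalition, right)
    )
    return float(weighted / NODE_SAMPLE_WEIGHT[node])


def exact_shapley_values(x, n_players):
    """Brute-force Shapley values: weighted marginal contributions over all coalitions."""
    phi = np.zeros(n_players)
    for i in range(n_players):
        others = [j for j in range(n_players) if j != i]
        for size in range(n_players):
            weight = factorial(size) * factorial(n_players - size - 1) / factorial(n_players)
            for subset in combinations(others, size):
                coalition = set(subset)
                with_i = path_dependent_value(x, coalition | {i})
                phi[i] += weight * (with_i - path_dependent_value(x, coalition))
    return phi


x = np.array([0.2, 0.9])
phi = exact_shapley_values(x, n_players=2)
```

For this `x` the result is approximately `[-2.4, 1.6]`. The values sum to `v({0, 1}) - v(∅) = 5 - 5.8`, and `v(∅) = 5.8`, the sample-weighted mean of the leaf values, plays the role of `empty_prediction`. Enumeration needs exponentially many value calls in the number of features, which is the cost the single-pass recursion of TreeSHAP-IQ is meant to avoid.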
    def _compute_shapley_interaction_values(
        self,
        x: np.ndarray,
        order: int = 1,
        node_id: int = 0,
        *,
        summary_poly_down: FloatVector | None = None,
        summary_poly_up: FloatVector | None = None,
        interaction_poly_down: FloatVector | None = None,
        quotient_poly_down: FloatVector | None = None,
        depth: int = 0,
    ) -> None:
        """Computes the Shapley Interaction values for a given instance x and interaction order.

        Note:
            This function is called recursively for each node in the tree.

        Args:
            x: The instance to be explained.

            order: The interaction order for which the Shapley Interaction values should be
                computed. Defaults to ``1``.

            node_id: The node ID of the current node in the tree. Defaults to ``0``.

            summary_poly_down: The summary polynomial for the current node. Defaults to ``None``
                (at init time).

            summary_poly_up: The summary polynomial propagated up the tree. Defaults to ``None``
                (at init time).

            interaction_poly_down: The interaction polynomial for the current node. Defaults to
                ``None`` (at init time).

            quotient_poly_down: The quotient polynomial for the current node. Defaults to ``None``
                (at init time).

            depth: The depth of the current node in the tree. Defaults to ``0``.

        """
        # fmt: off
        # manually formatted for better readability in formulas and equations
        # reset activations for new calculations
        if node_id == 0:
            self._activations.fill(False)  # noqa: FBT003

        # get polynomials if None
        polynomials = self._get_polynomials(
            order=order,
            summary_poly_down=summary_poly_down,
            summary_poly_up=summary_poly_up,
            interaction_poly_down=interaction_poly_down,
            quotient_poly_down=quotient_poly_down,
        )
        summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down = polynomials

        # get related (surrounding) nodes
        left_child = int(self._tree.children_left[node_id])
        right_child = int(self._tree.children_right[node_id])
        parent_id = int(self._edge_tree.parents[node_id])
        ancestor_id = int(self._edge_tree.ancestors[node_id])

        # get feature information
        feature_id = int(self._tree.features[parent_id])
        feature_threshold = self._tree.thresholds[node_id]
        child_edge_feature = self._tree.features[node_id]

        # get height of related nodes
        current_height = int(self._edge_tree.edge_heights[node_id])
        left_height = int(self._edge_tree.edge_heights[left_child])
        right_height = int(self._edge_tree.edge_heights[right_child])

        # get path information
        is_leaf = bool(self._tree.leaf_mask[node_id])
        has_ancestor = bool(self._edge_tree.has_ancestors[node_id])
        activations = self._activations

        # if feature_id > -1:
        try:
            interaction_sets = self.subset_updates_pos[feature_id]
        except KeyError:
            interaction_sets = np.array([], dtype=int)

        # if node is not a leaf -> set activations for children nodes accordingly
        if not is_leaf:
            if x[child_edge_feature] <= feature_threshold:
                activations[left_child], activations[right_child] = True, False
            else:
                activations[left_child], activations[right_child] = False, True

        # if node is not the root node -> calculate the summary polynomials
        if node_id != self._root_node_id:
            # set activations of current node in relation to the ancestor (for setting p_e to zero)
            if has_ancestor:
                activations[node_id] &= activations[ancestor_id]
            # if node is active get the correct p_e value
            p_e_current = self._edge_tree.p_e_values[node_id] if activations[node_id] else 0.0
            # update summary polynomial
            summary_poly_down[depth] = summary_poly_down[depth - 1] * (self.D + p_e_current)
            # update quotient polynomials
            quotient_poly_down[depth, :] = quotient_poly_down[depth - 1, :].copy()
            quotient_poly_down[depth, interaction_sets] = quotient_poly_down[depth, interaction_sets] * (self.D + p_e_current)
            # update interaction polynomial
            interaction_poly_down[depth, :] = interaction_poly_down[depth - 1, :].copy()
            interaction_poly_down[depth, interaction_sets] = interaction_poly_down[depth, interaction_sets] * (-self.D + p_e_current)
            # remove previous polynomial factor if node has ancestors
            if has_ancestor:
                p_e_ancestor = 0.0
                if activations[ancestor_id]:
                    p_e_ancestor = self._edge_tree.p_e_values[ancestor_id]
                # rescale the polynomials
                summary_poly_down[depth] = summary_poly_down[depth] / (self.D + p_e_ancestor)
                quotient_poly_down[depth, interaction_sets] = quotient_poly_down[depth, interaction_sets] / (self.D + p_e_ancestor)
                interaction_poly_down[depth, interaction_sets] = interaction_poly_down[depth, interaction_sets] / (-self.D + p_e_ancestor)

        # if node is leaf -> add the empty prediction to the summary polynomial and store it
        if is_leaf:  # recursion base case
            summary_poly_up[depth] = (
                summary_poly_down[depth] * self._edge_tree.empty_predictions[node_id]
            )
        else:  # not a leaf -> continue recursion
            # left child
            self._compute_shapley_interaction_values(
                x,
                order=order,
                node_id=left_child,
                summary_poly_down=summary_poly_down,
                summary_poly_up=summary_poly_up,
                interaction_poly_down=interaction_poly_down,
                quotient_poly_down=quotient_poly_down,
                depth=depth + 1,
            )
            summary_poly_up[depth] = (
                summary_poly_up[depth + 1] * self.D_powers[current_height - left_height]
            )
            # right child
            self._compute_shapley_interaction_values(
                x,
                order=order,
                node_id=right_child,
                summary_poly_down=summary_poly_down,
                summary_poly_up=summary_poly_up,
                interaction_poly_down=interaction_poly_down,
                quotient_poly_down=quotient_poly_down,
                depth=depth + 1,
            )
            summary_poly_up[depth] += (
                summary_poly_up[depth + 1] * self.D_powers[current_height - right_height]
            )

        # if node is not the root node -> calculate the Shapley Interaction values for the node
        if node_id != self._root_node_id:
            interactions_seen = interaction_sets[
                self._int_height[node_id][interaction_sets] == order
            ]
            if len(interactions_seen) > 0:
                if self._index not in ("SV", "SII", "k-SII"):  # for CII
                    D_power = self.D_powers[self._n_features_in_tree - current_height]
                    index_quotient = self._n_features_in_tree - order
                else:  # for SII and k-SII
                    D_power = self.D_powers[0]
                    index_quotient = current_height - order
                interaction_update = np.dot(
                    interaction_poly_down[depth, interactions_seen],
                    self.Ns_id[self.n_interpolation_size, : self.n_interpolation_size],
                )
                interaction_update *= self._psi(
                    summary_poly_up[depth, :],
                    D_power,
                    quotient_poly_down[depth, interactions_seen],
                    self.Ns,
                    index_quotient,
                )
                self.shapley_interactions[interactions_seen] += interaction_update

            # if node has ancestors -> adjust the Shapley Interaction values for the node
            ancestors_of_interactions = self.subset_ancestors[node_id][interaction_sets]
            if np.any(ancestors_of_interactions > -1):  # at least one ancestor exists (not -1)
                ancestor_node_id_exists = ancestors_of_interactions > -1  # get mask of ancestors
                interactions_with_ancestor = interaction_sets[ancestor_node_id_exists]
                cond_interaction_seen = (
                    self._int_height[parent_id][interactions_with_ancestor] == order
                )
                interactions_ancestors = ancestors_of_interactions[ancestor_node_id_exists]
                interactions_with_ancestor_to_update = interactions_with_ancestor[
                    cond_interaction_seen
                ]
                if len(interactions_with_ancestor_to_update) > 0:
                    ancestor_heights = self._edge_tree.edge_heights[
                        interactions_ancestors[cond_interaction_seen]
                    ]
                    if self._index not in ("SV", "SII", "k-SII"):  # for CII
                        D_power = self.D_powers[self._n_features_in_tree - current_height]
                        index_quotient = self._n_features_in_tree - order
                    else:  # for SII and k-SII
                        D_power = self.D_powers[ancestor_heights - current_height]
                        index_quotient = (ancestor_heights - order)
                    update = np.dot(
                        interaction_poly_down[depth - 1, interactions_with_ancestor_to_update],
                        self.Ns_id[self.n_interpolation_size, : self.n_interpolation_size],
                    )
                    to_update = self._psi_ancestor(
                        summary_poly_up[depth],
                        D_power,
                        quotient_poly_down[depth - 1, interactions_with_ancestor_to_update],
                        self.Ns,
                        index_quotient,
                    )
                    if to_update.shape == (1, 1):
                        update *= to_update[0]  # cast out shape of (1, 1) to float
                    else:
                        update *= to_update  # something errors here for CII
                    # fmt: on
                    self.shapley_interactions[interactions_with_ancestor_to_update] -= update

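In the recursion above, `summary_poly_down`, `quotient_poly_down` and `interaction_poly_down` store polynomials as their values at the interpolation points `self.D` (the Chebyshev points set up in `_init_summary_polynomials`). Multiplying by an edge factor `(y + p_e)` therefore becomes the elementwise product with `self.D + p_e`, and removing an ancestor's factor becomes an elementwise division. A small NumPy check of that idea, with made-up `p_e` values and four points, which determine any polynomial of degree at most 3:

```python
import numpy as np
from numpy.polynomial import polynomial as P

points = np.polynomial.chebyshev.chebpts2(4)  # 4 Chebyshev points of the second kind in [-1, 1]
p_e_values = [0.5, 0.25, 1.0]  # hypothetical edge weights along one root-to-leaf path

values = np.ones_like(points)  # value representation of the constant polynomial 1
coeffs = np.array([1.0])  # coefficient representation, kept only for comparison
for p_e in p_e_values:
    values = values * (points + p_e)  # multiply by (y + p_e) pointwise
    coeffs = P.polymul(coeffs, [p_e, 1.0])  # the same product on coefficients

# interpolation: solve the Vandermonde system to recover coefficients from the values
vandermonde = np.vander(points, N=4, increasing=True)
recovered = np.linalg.solve(vandermonde, values)

# dividing pointwise by (y + p_e) removes that edge's factor from the product again
without_second = values / (points + p_e_values[1])
```

The design choice this illustrates is that products and quotients of polynomials cost one vector operation per node, and coefficients are never needed explicitly.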
    @staticmethod
    def _psi_ancestor(
        E: np.ndarray,
        D_power: np.ndarray,
        quotient_poly: np.ndarray,
        Ns: np.ndarray,
        degree: int | IntVector,
    ) -> np.ndarray:
        """Variant of :meth:`_psi` for ancestor corrections with one degree per interaction."""
        d = degree + 1
        n = Ns[d].T  # Variant of _psi that can deal with multiple inputs in degree
        return np.diag((E * D_power / quotient_poly).dot(n)) / d

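`_psi_ancestor` evaluates several interactions with different degrees in one call: `Ns[d]` selects one row of `Ns` per interaction, and `np.diag(M.dot(Ns[d].T))` keeps only the row-wise dot products `M[i] @ Ns[d[i]]`. The check below, using random stand-in arrays whose shapes are assumptions rather than shapiq's, confirms that equivalence and shows `np.einsum` computing the same diagonal without building the full k-by-k product:

```python
import numpy as np

rng = np.random.default_rng(0)
k, m = 3, 5  # k interactions, m interpolation points (made-up sizes)
M = rng.normal(size=(k, m))  # stands in for E * D_power / quotient_poly
Ns = rng.normal(size=(m + 1, m))  # stands in for the Ns table, one row per degree
d = np.array([2, 4, 3])  # one (shifted) degree per interaction

via_diag = np.diag(M.dot(Ns[d].T))  # the expression used in _psi_ancestor
via_einsum = np.einsum("ij,ij->i", M, Ns[d])  # row-wise dot products only
via_loop = np.array([M[i] @ Ns[d[i]] for i in range(k)])  # explicit reference
```

For the small `k` seen per node the `np.diag` form is cheap; the `einsum` form avoids the quadratic intermediate if `k` grows.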
    @staticmethod
    def _psi(
        E: np.ndarray,
        D_power: np.ndarray,
        quotient_poly: np.ndarray,
        Ns: np.ndarray,
        degree: int,
    ) -> FloatVector:
        """Computes the psi function for the TreeSHAP-IQ algorithm.

        It scales the interaction polynomials with the summary polynomial and the quotient
        polynomial. For details, refer to `Muschalik et al. (2024) <https://doi.org/10.48550/arXiv.2401.12069>`_.

        Args:
            E: The summary polynomial.
            D_power: The power of the D polynomial.
            quotient_poly: The quotient polynomial.
            Ns: The Ns polynomial.
            degree: The degree of the interaction polynomial.

        Returns:
            np.ndarray: The computed psi function.

        """
        d = degree + 1
        n = Ns[d, :d]
        return ((E * D_power / quotient_poly)[:, :d]).dot(n) / d

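The division by `d = degree + 1` in `_psi` appears to correspond to the ψ operator of Linear TreeSHAP (Yu et al., 2022), which for a polynomial A(y) = Σ_k a_k y^k of degree d is ψ_d(A) = (1/(d+1)) Σ_k a_k / C(d, k). The weights come from the Beta integral ∫_0^1 t^k (1-t)^(d-k) dt = 1/((d+1) C(d, k)), which for d = n - 1 equals the Shapley weight k!(n-k-1)!/n!. How `Ns` encodes these weights on the interpolation grid is not shown in this listing, so the sketch below recovers coefficients by solving a Vandermonde system instead; treat it as an illustration of the operator under that assumed definition, not as shapiq's `_get_n_matrix`.

```python
from math import comb, factorial

import numpy as np


def psi_from_coefficients(coeffs):
    """psi_d(A) = sum_k a_k / C(d, k) / (d + 1) for A(y) = sum_k a_k * y**k of degree d."""
    d = len(coeffs) - 1
    return sum(a / comb(d, k) for k, a in enumerate(coeffs)) / (d + 1)


def psi_from_values(values, points):
    """The same operator, starting from the values of A at d + 1 distinct points."""
    d = len(points) - 1
    vandermonde = np.vander(points, N=d + 1, increasing=True)
    coeffs = np.linalg.solve(vandermonde, values)  # interpolate back to coefficients
    return psi_from_coefficients(coeffs)


def shapley_weight(coalition_size, n_players):
    """Classic Shapley weight |S|! (n - |S| - 1)! / n!."""
    numerator = factorial(coalition_size) * factorial(n_players - coalition_size - 1)
    return numerator / factorial(n_players)
```

Applied to the monomial y^k of degree n - 1, ψ returns exactly the Shapley weight for coalitions of size k, which is why a single polynomial evaluation can aggregate all coalition sizes at once.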
    def _init_summary_polynomials(self) -> None:
        """Initializes the summary polynomial variables.

        Note:
            This function is called once during the initialization of the explainer.
        """
        for order in range(1, self._max_order + 1):
            subset_ancestors: dict[int, np.ndarray] = self._precalculate_interaction_ancestors(
                interaction_order=order,
                n_features=self._n_features_in_tree,
            )
            self.subset_ancestors_store[order] = subset_ancestors

            # chebpts2 raises a ValueError for fewer than two points (single-feature trees);
            # __init__ catches it and falls back to the trivial computation
            self.D_store[order] = np.polynomial.chebyshev.chebpts2(self.n_interpolation_size)

            self.D_powers_store[order] = self._cache(self.D_store[order])
            if self._index in ("SV", "SII", "k-SII"):
                self.Ns_store[order] = self._get_n_matrix(self.D_store[order])
            else:
                self.Ns_store[order] = self._get_n_cii_matrix(self.D_store[order], order)
            self.Ns_id_store[order] = self._get_n_id_matrix(self.D_store[order])

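`np.polynomial.chebyshev.chebpts2(npts)` returns `npts` Chebyshev points of the second kind on [-1, 1] and raises `ValueError` when `npts < 2`. A single-feature tree gives `n_interpolation_size == 1`, so this call fails and the `except ValueError` branch in `__init__` switches to the trivial computation (the constructor re-raises the error for any other tree). A minimal sketch of that control flow, using a hypothetical helper name:

```python
import numpy as np


def interpolation_points_or_none(n_interpolation_size):
    """Hypothetical helper mirroring __init__: None signals the trivial one-feature path."""
    try:
        return np.polynomial.chebyshev.chebpts2(n_interpolation_size)
    except ValueError:
        # chebpts2 requires at least two points
        return None
```

For three points the grid is `[-1, 0, 1]` (up to floating-point rounding at the midpoint).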
    def _get_polynomials(
        self,
        order: int,
        summary_poly_down: FloatVector | None = None,
        summary_poly_up: FloatVector | None = None,
        interaction_poly_down: FloatVector | None = None,
        quotient_poly_down: FloatVector | None = None,
    ) -> tuple[FloatVector, FloatVector, FloatVector, FloatVector]:
        """Retrieves the polynomials for a given interaction order.

        This function initializes the polynomials for the first call of the recursive explanation
        function.

        Args:
            order: The interaction order for which the polynomials should be loaded.

            summary_poly_down: The summary polynomial for the current node. Defaults to ``None``.

            summary_poly_up: The summary polynomial propagated up the tree. Defaults to ``None``.

            interaction_poly_down: The interaction polynomial for the current node. Defaults to
                ``None``.

            quotient_poly_down: The quotient polynomial for the current node. Defaults to ``None``.

        Returns:
            The summary polynomial down, the summary polynomial up, the interaction polynomial down,
                and the quotient polynomial down.

        """
        if summary_poly_down is None:
            summary_poly_down = np.zeros((self._edge_tree.max_depth + 1, self.n_interpolation_size))
            summary_poly_down[0, :] = 1
        if summary_poly_up is None:
            summary_poly_up = np.zeros((self._edge_tree.max_depth + 1, self.n_interpolation_size))
        if interaction_poly_down is None:
            interaction_poly_down = np.zeros(
                (
                    self._edge_tree.max_depth + 1,
                    int(sp.special.binom(self._n_features_in_tree, order)),
                    self.n_interpolation_size,
                ),
            )
            interaction_poly_down[0, :] = 1
        if quotient_poly_down is None:
            quotient_poly_down = np.zeros(
                (
                    self._edge_tree.max_depth + 1,
                    int(sp.special.binom(self._n_features_in_tree, order)),
                    self.n_interpolation_size,
                ),
            )
            quotient_poly_down[0, :] = 1
        return summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down

565
    def _prepare_variables_for_order(self, interaction_order: int) -> None:
1✔
566
        """Retrieves the precomputed variables for a given interaction order.
567

568
        This function is called before the recursive explanation function is called.
569

570
        Args:
571
            interaction_order (int): The interaction order for which the storage variables should be
572
                loaded.
573

574
        """
575
        self.subset_updates_pos = self._interaction_update_positions[interaction_order]
1✔
576
        self.subset_ancestors = self.subset_ancestors_store[interaction_order]
1✔
577
        self.D = self.D_store[interaction_order]
1✔
578
        self.D_powers = self.D_powers_store[interaction_order]
1✔
579
        self._int_height = self._edge_tree.interaction_height_store[interaction_order]
1✔
580
        self.Ns_id = self.Ns_id_store[interaction_order]
1✔
581
        self.Ns = self.Ns_store[interaction_order]
1✔
582

    def _init_interaction_lookup_tables(self) -> None:
        """Initializes the lookup tables for the interaction subsets."""
        for order in range(1, self._max_order + 1):
            order_interactions_lookup = generate_interaction_lookup(
                self._n_features_in_tree,
                order,
                order,
            )
            self._interactions_lookup[order] = order_interactions_lookup
            _, interaction_update_positions = self._precompute_subsets_with_feature(
                interaction_order=order,
                n_features=self._n_features_in_tree,
                order_interactions_lookup=order_interactions_lookup,
            )
            self._interaction_update_positions[order] = interaction_update_positions

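`_init_interaction_lookup_tables` uses `generate_interaction_lookup` to give every interaction of one order a column in the interaction-values array. The sketch below builds such a table for a single order. `build_interaction_lookup` is a hypothetical helper. It assumes interactions are numbered in the lexicographic order produced by `itertools.combinations`, and that ordering is not confirmed here for shapiq's own helper.

```python
from itertools import combinations


def build_interaction_lookup(n_features: int, order: int) -> dict[tuple[int, ...], int]:
    """Map each interaction of exactly ``order`` features to its column index.

    Hypothetical stand-in for ``generate_interaction_lookup(n_features, order, order)``;
    assumes lexicographic numbering of the interactions.
    """
    return {
        interaction: position
        for position, interaction in enumerate(combinations(range(n_features), order))
    }


lookup = build_interaction_lookup(4, 2)
print(lookup)
# {(0, 1): 0, (0, 2): 1, (0, 3): 2, (1, 2): 3, (1, 3): 4, (2, 3): 5}
```

With 4 features and order 2, the table has C(4, 2) = 6 entries. That is the same size the surrounding code allocates with `sp.special.binom`.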
    @staticmethod
    def _precompute_subsets_with_feature(
        n_features: int,
        interaction_order: int,
        order_interactions_lookup: dict[tuple, int],
    ) -> tuple[dict[int, list[tuple]], dict[int, IntVector]]:
        """Precomputes the subsets of interactions that include a given feature.

        Args:
            n_features: The number of features in the model.
            interaction_order: The interaction order to be computed.
            order_interactions_lookup: The lookup table of interaction subsets to their positions
                in the interaction values array for a given interaction order (e.g. all 2-way
                interactions for order ``2``).

        Returns:
            interaction_updates: A dictionary (lookup table) containing the interaction subsets
                for each feature given an interaction order.
            interaction_update_positions: A dictionary (lookup table) containing the positions of
                the interaction subsets to update for each feature given an interaction order.

        """
        # stores interactions that include feature i (needs to be updated when feature i appears)
        interaction_updates: dict[int, list[tuple]] = {}
        # stores position of interactions that include feature i
        interaction_update_positions: dict[int, np.ndarray] = {}

        # prepare the interaction updates and positions
        for feature_i in range(n_features):
            positions = np.zeros(
                int(sp.special.binom(n_features - 1, interaction_order - 1)),
                dtype=int,
            )
            interaction_update_positions[feature_i] = positions.copy()
            interaction_updates[feature_i] = []

        # fill the interaction updates and positions
        position_counter = np.zeros(n_features, dtype=int)  # used to keep track of the position
        for interaction in powerset(
            range(n_features),
            min_size=interaction_order,
            max_size=interaction_order,
        ):
            for i in interaction:
                interaction_updates[i].append(interaction)
                position = position_counter[i]
                interaction_update_positions[i][position] = order_interactions_lookup[interaction]
                position_counter[i] += 1

        return interaction_updates, interaction_update_positions

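`_precompute_subsets_with_feature` inverts the interaction lookup. For each feature, it records the positions of all interactions that contain that feature, so a split on feature i only touches those entries. The self-contained sketch below follows the same two-pass idea. It uses `itertools.combinations` in place of shapiq's `powerset(..., min_size=order, max_size=order)`, which is assumed to enumerate the same sets in the same order.

```python
from itertools import combinations
from math import comb

import numpy as np


def precompute_update_positions(n_features: int, order: int) -> dict[int, np.ndarray]:
    """For each feature, collect the positions of all ``order``-sized interactions containing it."""
    lookup = {
        interaction: pos
        for pos, interaction in enumerate(combinations(range(n_features), order))
    }
    # every feature appears in exactly C(n - 1, order - 1) interactions of size ``order``
    positions = {
        i: np.zeros(comb(n_features - 1, order - 1), dtype=int) for i in range(n_features)
    }
    counter = np.zeros(n_features, dtype=int)  # next free slot per feature
    for interaction in combinations(range(n_features), order):
        for i in interaction:
            positions[i][counter[i]] = lookup[interaction]
            counter[i] += 1
    return positions


positions = precompute_update_positions(4, 2)
print(positions[0])  # [0 1 2]
print(positions[3])  # [2 4 5]
```

Preallocating with C(n - 1, order - 1) works because every slot is filled exactly once. No per-feature list has to grow during the second pass.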
    def _precalculate_interaction_ancestors(
        self,
        interaction_order: int,
        n_features: int,
    ) -> dict[int, np.ndarray]:
        """Computes the ancestors of the interactions for a given order of interactions.

        For every non-root node of the tree, computes the ancestor of each interaction of the
        given order as the maximum over the ancestors of the interaction's features.

        Args:
            interaction_order: The interaction order for which the ancestors should be computed.
            n_features: The number of features in the model.

        Returns:
            subset_ancestors: A dictionary containing the ancestors of the interactions for each
                non-root node in the tree.

        """
        # stores position of interactions
        subset_ancestors: dict[int, np.ndarray] = {}

        for node_id in self._tree.nodes[1:]:  # for all nodes except the root node
            subset_ancestors[node_id] = np.full(
                int(sp.special.binom(n_features, interaction_order)), -1, dtype=int
            )
        for i, S in enumerate(powerset(range(n_features), interaction_order, interaction_order)):
            for node_id in self._tree.nodes[1:]:  # for all nodes except the root node
                subset_ancestor = -1
                for feature in S:
                    subset_ancestor = max(
                        subset_ancestor,
                        self._edge_tree.ancestor_nodes[node_id][feature],
                    )
                subset_ancestors[node_id][i] = subset_ancestor
        return subset_ancestors

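In `_precalculate_interaction_ancestors`, the ancestor of an interaction S at a node is the maximum of the per-feature ancestors stored in `_edge_tree.ancestor_nodes`. The maximum picks the deepest such ancestor only if node ids grow with depth along every root-to-leaf path, as in preorder numbering. That numbering is an assumption here. The sketch handles a single node with a made-up row of feature ancestors; `interaction_ancestors` is a hypothetical helper.

```python
from itertools import combinations

import numpy as np


def interaction_ancestors(feature_ancestors: np.ndarray, order: int) -> np.ndarray:
    """Ancestor of each ``order``-sized interaction as the max over its features' ancestors.

    ``feature_ancestors[f]`` is the id of the closest ancestor that splits on feature ``f``,
    or -1 if there is none (hypothetical input for a single node).
    """
    n_features = feature_ancestors.shape[0]
    return np.array(
        [
            max([-1] + [int(feature_ancestors[f]) for f in interaction])
            for interaction in combinations(range(n_features), order)
        ],
        dtype=int,
    )


# feature 0 was split at node 0, feature 1 never, feature 2 at node 2
print(interaction_ancestors(np.array([0, -1, 2]), order=2))  # [0 2 2]
```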
    @staticmethod
    def _get_n_matrix(interpolated_poly: np.ndarray) -> np.ndarray:
        """Computes the N matrix for the Shapley interaction values.

        Args:
            interpolated_poly: The interpolated polynomial.

        Returns:
            The N matrix.

        """
        depth = interpolated_poly.shape[0]
        Ns = np.zeros((depth + 1, depth))
        for i in range(1, depth + 1):
            Ns[i, :i] = np.linalg.inv(np.vander(interpolated_poly[:i]).T).dot(
                1.0 / np.array([sp.special.binom(i - 1, k) for k in range(i)])
            )
        return Ns

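`_get_n_matrix` fills row i of N by solving a transposed Vandermonde system whose right-hand side is 1 / C(i - 1, k) for k = 0, ..., i - 1. The sketch below rebuilds the matrix in the same way, with two substitutions:

- `np.linalg.solve` replaces the explicit inverse, which is generally the more stable choice.
- `math.comb` replaces `sp.special.binom`.

The interpolation points in the example are arbitrary distinct values, not the ones shapiq uses.

```python
from math import comb

import numpy as np


def n_matrix(interpolated_poly: np.ndarray) -> np.ndarray:
    """Row i (1 <= i <= depth) solves vander(x[:i]).T @ w = 1 / C(i - 1, k); row 0 stays zero."""
    depth = interpolated_poly.shape[0]
    ns = np.zeros((depth + 1, depth))
    for i in range(1, depth + 1):
        vandermonde_t = np.vander(interpolated_poly[:i]).T
        rhs = 1.0 / np.array([comb(i - 1, k) for k in range(i)], dtype=float)
        ns[i, :i] = np.linalg.solve(vandermonde_t, rhs)
    return ns


points = np.array([0.9, -0.4, 0.3, -0.8])  # arbitrary distinct nodes
ns = n_matrix(points)

# same row as the explicit-inverse formulation used in _get_n_matrix
i = 3
explicit = np.linalg.inv(np.vander(points[:i]).T).dot(
    1.0 / np.array([comb(i - 1, k) for k in range(i)], dtype=float)
)
print(np.allclose(ns[i, :i], explicit))  # True
```

Distinct points keep every leading Vandermonde block nonsingular. Repeated points would make `solve` (and the original `inv`) fail.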
    def _get_n_cii_matrix(self, interpolated_poly: np.ndarray, order: int) -> np.ndarray:
        """Computes the N matrix for the CII index."""
        depth = interpolated_poly.shape[0]
        Ns = np.zeros((depth + 1, depth))
        for i in range(1, depth + 1):
            Ns[i, :i] = np.linalg.inv(np.vander(interpolated_poly[:i]).T).dot(
                i * np.array([self._get_subset_weight_cii(j, order) for j in range(i)]),
            )
        return Ns

    def _get_subset_weight_cii(self, t: int, order: int) -> float | None:
        """Computes the weight for a given subset size and interaction order.

        Args:
            t: The size of the subset.
            order: The interaction order.

        Returns:
            float | None: The weight for the subset, or None if the index is not supported.

        """
        if self._index == "STII":
            return self._max_order / (
                self._n_features_in_tree * sp.special.binom(self._n_features_in_tree - 1, t)
            )
        if self._index == "FSII":
            return (
                factorial(2 * self._max_order - 1)
                / factorial(self._max_order - 1) ** 2
                * factorial(self._max_order + t - 1)
                * factorial(self._n_features_in_tree - t - 1)
                / factorial(self._n_features_in_tree + self._max_order - 1)
            )
        if self._index == "BII":
            return 1 / (2 ** (self._n_features_in_tree - order))
        return None

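The three supported branches of `_get_subset_weight_cii` can be checked outside the class. The function below restates them with `math.factorial` and `math.comb`. The parameters `n` and `max_order` stand in for `self._n_features_in_tree` and `self._max_order`. The BII branch keeps the original's use of `order` rather than `max_order`. `cii_subset_weight` is a hypothetical name for illustration only.

```python
from __future__ import annotations

from math import comb, factorial


def cii_subset_weight(index: str, n: int, max_order: int, t: int, order: int) -> float | None:
    """Weight for a subset of size ``t``; mirrors the STII / FSII / BII branches."""
    if index == "STII":
        return max_order / (n * comb(n - 1, t))
    if index == "FSII":
        return (
            factorial(2 * max_order - 1)
            / factorial(max_order - 1) ** 2
            * factorial(max_order + t - 1)
            * factorial(n - t - 1)
            / factorial(n + max_order - 1)
        )
    if index == "BII":
        return 1 / (2 ** (n - order))
    return None  # unsupported index


print(cii_subset_weight("STII", n=4, max_order=2, t=1, order=2))  # 0.16666666666666666
print(cii_subset_weight("FSII", n=4, max_order=2, t=0, order=2))  # 0.3
print(cii_subset_weight("BII", n=4, max_order=2, t=0, order=2))  # 0.25
```

For FSII with n = 4, max_order = 2 and t = 0, the chain evaluates to 3! / 1!^2 * 1! * 3! / 5! = 36 / 120 = 0.3.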
    @staticmethod
    def _get_n_id_matrix(D: np.ndarray) -> np.ndarray:
        """Computes the N_id matrix."""
        depth = D.shape[0]
        Ns_id = np.zeros((depth + 1, depth))
        for i in range(1, depth + 1):
            Ns_id[i, :i] = np.linalg.inv(np.vander(D[:i]).T).dot(np.ones(i))
        return Ns_id

    @staticmethod
    def _cache(interpolated_poly: FloatVector) -> FloatVector:
        """Caches the powers of the interpolated polynomial.

        Args:
            interpolated_poly: The interpolated polynomial.

        Returns:
            The cached powers of the interpolated polynomial.

        """
        return np.vander(interpolated_poly + 1).T[::-1]

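`_cache` stores powers of the shifted interpolation points. After the transposed Vandermonde matrix is reversed row-wise, row k holds (x + 1) ** k. The standalone check below copies the one-liner into a free function, `cache_powers` (a name introduced only for this example), and shows that layout on a toy input.

```python
import numpy as np


def cache_powers(interpolated_poly: np.ndarray) -> np.ndarray:
    """Row k holds (interpolated_poly + 1) ** k for k = 0, ..., len(interpolated_poly) - 1."""
    return np.vander(interpolated_poly + 1).T[::-1]


x = np.array([0.0, 1.0, 2.0])
powers = cache_powers(x)
print(powers)
# [[1. 1. 1.]
#  [1. 2. 3.]
#  [1. 4. 9.]]

# equivalent formulation with explicit exponents
expected = (x + 1)[np.newaxis, :] ** np.arange(len(x))[:, np.newaxis]
print(np.allclose(powers, expected))  # True
```

`np.vander` orders columns by decreasing power by default, so the `.T[::-1]` step is what puts the zeroth power in the first row.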
    def _print_tree_info(self) -> None:
        """Prints information about the tree to be explained."""
        information = "Tree information:"
        information += f"\nNumber of nodes: {self._n_nodes}"
        information += f"\nNumber of features: {self._n_features_in_tree}"
        information += f"\nMaximum interaction order: {self._max_order}"
        information += f"\nInteraction index: {self._index}"
        information += f"\nEmpty prediction (from _tree): {self._tree.empty_prediction}"
        information += f"\nEmpty prediction (from self): {self.empty_prediction}"