• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 16647002974

31 Jul 2025 10:52AM UTC coverage: 93.906% (+0.005%) from 93.901%
16647002974

push

github

mmschlk
fix: Remove defense checks

4916 of 5235 relevant lines covered (93.91%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.37
/src/shapiq/explainer/tree/explainer.py
1
"""Implementation of the TreeExplainer class.
2

3
The :class:`~shapiq.explainer.tree.explainer.TreeExplainer` uses the
4
:class:`~shapiq.explainer.tree.treeshapiq.TreeSHAPIQ` algorithm for computing any-order Interactions
5
for tree ensembles.
6
"""
7

8
from __future__ import annotations
1✔
9

10
import copy
1✔
11
from typing import TYPE_CHECKING, Any
1✔
12

13
from shapiq.explainer.base import Explainer
1✔
14

15
from .treeshapiq import TreeSHAPIQ, TreeSHAPIQIndices
1✔
16
from .validation import validate_tree_model
1✔
17

18
if TYPE_CHECKING:
1✔
19
    import numpy as np
×
20

21
    from shapiq.interaction_values import InteractionValues
×
22
    from shapiq.typing import Model
×
23

24
    from .base import TreeModel
×
25

26

27
class TreeExplainer(Explainer):
    """The TreeExplainer class for tree-based models.

    The explainer for tree-based models using the
    :class:`~shapiq.explainer.tree.treeshapiq.TreeSHAPIQ` algorithm. For details, refer to
    `Muschalik et al. (2024)` [Mus24]_.

    TreeSHAP-IQ is an algorithm for computing Shapley Interaction values for tree-based models.
    It is based on the Linear TreeSHAP algorithm by `Yu et al. (2022)` [Yu22]_, but extended to
    compute Shapley Interaction values up to a given order. TreeSHAP-IQ needs to visit each node
    only once and makes use of polynomial arithmetic to compute the Shapley Interaction values
    efficiently.

    The TreeExplainer can be used with a variety of tree-based models, including
    ``scikit-learn``, ``XGBoost``, and ``LightGBM``. The explainer can handle both regression and
    classification models.

    References:
        .. [Yu22] Peng Yu, Chao Xu, Albert Bifet, Jesse Read. (2022). Linear Tree Shap. In: Proceedings of 36th Conference on Neural Information Processing Systems. https://openreview.net/forum?id=OzbkiUo24g
        .. [Mus24] Maximilian Muschalik, Fabian Fumagalli, Barbara Hammer, & Eyke Hüllermeier (2024). Beyond TreeSHAP: Efficient Computation of Any-Order Shapley Interactions for Tree Ensembles. In: Proceedings of the AAAI Conference on Artificial Intelligence, 38(13), 14388-14396. https://doi.org/10.1609/aaai.v38i13.29352

    """

    def __init__(
        self,
        model: dict | TreeModel | list[TreeModel] | Model,
        *,
        max_order: int = 2,
        min_order: int = 0,
        index: TreeSHAPIQIndices = "k-SII",
        class_index: int | None = None,
        **kwargs: Any,  # noqa: ARG002
    ) -> None:
        """Initializes the TreeExplainer.

        Args:
            model: A tree-based model to explain.

            max_order: The maximum interaction order to be computed. An interaction order of ``1``
                corresponds to the Shapley value. Any value higher than ``1`` computes the Shapley
                interaction values up to that order. Defaults to ``2``.

            min_order: The minimum interaction order to be computed. With ``0``, the empty
                interaction (the baseline value) is included in the explanation. Defaults to
                ``0``.

            index: The type of interaction to be computed. It can be one of
                ``["k-SII", "SII", "STII", "FSII", "BII", "SV"]``. All indices apart from ``"BII"``
                will reduce to the ``"SV"`` (Shapley value) for order 1. Defaults to ``"k-SII"``.

            class_index: The class index of the model to explain. Defaults to ``None``, which will
                set the class index to ``1`` per default for classification models and is ignored
                for regression models.

            **kwargs: Additional keyword arguments are ignored.

        """
        super().__init__(model, index=index, max_order=max_order)

        # validate and parse model; deep-copied so later mutations of the user's model
        # do not affect the explainer (and vice versa)
        validated_model = validate_tree_model(model, class_label=class_index)
        self._trees: list[TreeModel] | TreeModel = copy.deepcopy(validated_model)
        if not isinstance(self._trees, list):
            # normalize single-tree models to a one-element list
            self._trees = [self._trees]
        self._n_trees = len(self._trees)

        self._min_order: int = min_order
        self._class_label: int | None = class_index

        # setup one TreeSHAP-IQ explainer per tree in the ensemble
        self._treeshapiq_explainers: list[TreeSHAPIQ] = [
            TreeSHAPIQ(model=_tree, max_order=self._max_order, index=index) for _tree in self._trees
        ]
        self.baseline_value = self._compute_baseline_value()

    def explain_function(
        self,
        x: np.ndarray,
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Computes the Shapley Interaction values for a single instance.

        Args:
            x: The instance to explain as a 1-dimensional array.
            **kwargs: Additional keyword arguments are ignored.

        Returns:
            The interaction values for the instance.

        Raises:
            TypeError: If ``x`` is not a single instance (i.e. not 1-dimensional).

        """
        if len(x.shape) != 1:
            msg = "explain expects a single instance, not a batch."
            raise TypeError(msg)
        # run treeshapiq for all trees
        interaction_values: list[InteractionValues] = []
        for explainer in self._treeshapiq_explainers:
            tree_explanation = explainer.explain(x)
            interaction_values.append(tree_explanation)

        # combine the explanations for all trees (ensemble explanation is the sum of the
        # per-tree explanations); the slice is empty for a single tree
        final_explanation = interaction_values[0]
        for tree_explanation in interaction_values[1:]:
            final_explanation += tree_explanation

        if self._min_order == 0 and final_explanation.min_order == 1:
            final_explanation.min_order = 0
            # Add the baseline value to the empty prediction
            # might break for some edge cases
            final_explanation.interactions[()] = float(final_explanation.baseline_value)

        return final_explanation

    def _compute_baseline_value(self) -> float:
        """Computes the baseline value for the explainer.

        The baseline value is the sum of the empty predictions of all trees in the ensemble.

        Returns:
            The baseline value for the explainer.

        """
        # generator avoids materializing an intermediate list
        return sum(treeshapiq.empty_prediction for treeshapiq in self._treeshapiq_explainers)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc