
mmschlk / shapiq / 16345920221

17 Jul 2025 01:05PM UTC coverage: 77.589% (-13.5%) from 91.075%

push · github · web-flow
🔨 Refactors library into a src structure. (#415)

* moves shapiq into a src folder

* moves shapiq tests into tests_shapiq subfolder in tests

* refactors tests to work properly

* removes pickle support and closes #413

* changes unit tests to only run the unit tests

* adds workflow for running shapiq_games

* updates coverage to only run for shapiq

* update workflow to check for shapiq_games import

* update CHANGELOG.md

* fixes install-import.yml

* fixes version in docs

* moved deprecated tests out of the main test suite

* moves fixtures into the correct test suite

* installs libomp on macos runner (try bugfix)

* correct spelling

* removes libomp again

* moves os runs into individual workflows for easier debugging

* runs macOS on py3.13

* renames workflows

* installs libomp again on macOS

* downgrades to Python 3.11 and reinstalls Python

* try different uv version

* adds libomp

* changes skip to xfail in integration tests with wrong index/order combinations

* moves test out for debugging CI

* removes outdated test

* adds concurrency for quicker testing

* re-adds randomly

* don't reset seed

* removed pytest-randomly again

* adds the tests back in

3 of 21 new or added lines in 19 files covered (14.29%).

5536 of 7135 relevant lines covered (77.59%)

0.78 hits per line
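The summary figures above are plain ratios of covered lines to relevant lines. The following is a minimal sketch that reproduces them. It assumes each percentage is covered / relevant × 100, which matches every number shown. The helper name `coverage_percent` is hypothetical.

```python
def coverage_percent(covered: int, relevant: int) -> float:
    """Return line coverage as a percentage of relevant lines."""
    if relevant <= 0:
        raise ValueError("relevant must be positive")
    return 100.0 * covered / relevant


# Figures taken from this build's summary.
overall = coverage_percent(5536, 7135)  # all relevant lines
new_lines = coverage_percent(3, 21)     # new or added lines in this push
previous = 91.075                       # coverage of the previous build

print(f"overall: {overall:.3f}%")                   # overall: 77.589%
print(f"new lines: {new_lines:.2f}%")               # new lines: 14.29%
print(f"change: {overall - previous:+.1f} points")  # change: -13.5 points
```

The reported change (-13.5%) is a difference of percentage points between two builds, not a relative change.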

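Source File
/src/shapiq/explainer/tree/explainer.py
Coverage: 87.8% (36 of 41 relevant lines covered).

There are five uncovered lines:
* Three imports inside the `if TYPE_CHECKING:` block: numpy, shapiq.typing.Model (new in this build), and TreeModel. These never execute at runtime because typing.TYPE_CHECKING is False.
* Two statements in the TypeError branch of explain_function that rejects batched input.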
"""Implementation of the TreeExplainer class.

The :class:`~shapiq.explainer.tree.explainer.TreeExplainer` uses the
:class:`~shapiq.explainer.tree.treeshapiq.TreeSHAPIQ` algorithm to compute any-order interactions
for tree ensembles.
"""

from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any

from shapiq.explainer.base import Explainer
from shapiq.interaction_values import InteractionValues, finalize_computed_interactions

from .treeshapiq import TreeSHAPIQ, TreeSHAPIQIndices
from .validation import validate_tree_model

if TYPE_CHECKING:
    import numpy as np

    from shapiq.typing import Model

    from .base import TreeModel


class TreeExplainer(Explainer):
    """The TreeExplainer class for tree-based models.

    The explainer for tree-based models uses the
    :class:`~shapiq.explainer.tree.treeshapiq.TreeSHAPIQ` algorithm. For details, refer to
    `Muschalik et al. (2024)` [Mus24]_.

    TreeSHAP-IQ is an algorithm for computing Shapley interaction values for tree-based models.
    It is based on the Linear TreeSHAP algorithm by `Yu et al. (2022)` [Yu22]_, extended to
    compute Shapley interaction values up to a given order. TreeSHAP-IQ visits each node only
    once and uses polynomial arithmetic to compute the Shapley interaction values efficiently.

    The TreeExplainer supports tree-based models from ``scikit-learn``, ``XGBoost``, and
    ``LightGBM``, among others, for both regression and classification.

    References:
        .. [Yu22] Peng Yu, Chao Xu, Albert Bifet, and Jesse Read (2022). Linear Tree Shap. In: Proceedings of the 36th Conference on Neural Information Processing Systems (NeurIPS). https://openreview.net/forum?id=OzbkiUo24g
        .. [Mus24] Maximilian Muschalik, Fabian Fumagalli, Barbara Hammer, and Eyke Hüllermeier (2024). Beyond TreeSHAP: Efficient Computation of Any-Order Shapley Interactions for Tree Ensembles. In: Proceedings of the AAAI Conference on Artificial Intelligence, 38(13), 14388-14396. https://doi.org/10.1609/aaai.v38i13.29352

    """

    def __init__(
        self,
        model: dict | TreeModel | list[TreeModel] | Model,
        *,
        max_order: int = 2,
        min_order: int = 0,
        index: TreeSHAPIQIndices = "k-SII",
        class_index: int | None = None,
        **kwargs: Any,  # noqa: ARG002
    ) -> None:
        """Initializes the TreeExplainer.

        Args:
            model: The tree-based model to explain. Accepts a supported model object, a
                ``TreeModel``, a list of ``TreeModel`` objects, or a ``dict``. The input is
                validated and converted by ``validate_tree_model``.

            max_order: The maximum interaction order to compute. An order of ``1`` corresponds
                to the Shapley value. Any higher value computes Shapley interaction values up to
                that order. Defaults to ``2``.

            min_order: The minimum interaction order to compute. Defaults to ``0``.

            index: The type of interaction to compute. One of
                ``["k-SII", "SII", "STII", "FSII", "BII", "SV"]``. All indices except ``"BII"``
                reduce to the ``"SV"`` (Shapley value) for order 1. Defaults to ``"k-SII"``.

            class_index: The class index of the model to explain. Defaults to ``None``, which
                sets the class index to ``1`` for classification models. It is ignored for
                regression models.

            **kwargs: Additional keyword arguments are ignored.

        """
        super().__init__(model, index=index, max_order=max_order)

        # validate and parse the model into one or more TreeModel objects
        validated_model = validate_tree_model(model, class_label=class_index)
        self._trees: list[TreeModel] | TreeModel = copy.deepcopy(validated_model)
        if not isinstance(self._trees, list):
            self._trees = [self._trees]
        self._n_trees = len(self._trees)

        self._min_order: int = min_order
        self._class_label: int | None = class_index

        # set up one TreeSHAPIQ explainer per tree
        self._treeshapiq_explainers: list[TreeSHAPIQ] = [
            TreeSHAPIQ(model=_tree, max_order=self._max_order, index=index) for _tree in self._trees
        ]
        self.baseline_value = self._compute_baseline_value()

    def explain_function(
        self,
        x: np.ndarray,
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Computes the Shapley interaction values for a single instance.

        Args:
            x: The instance to explain as a one-dimensional array.
            **kwargs: Additional keyword arguments are ignored.

        Returns:
            The interaction values for the instance.

        Raises:
            TypeError: If ``x`` is not one-dimensional (e.g., a batch of instances).

        """
        if len(x.shape) != 1:
            msg = "explain expects a single instance, not a batch."
            raise TypeError(msg)

        # run TreeSHAP-IQ on every tree in the ensemble
        interaction_values: list[InteractionValues] = []
        for explainer in self._treeshapiq_explainers:
            tree_explanation = explainer.explain(x)
            interaction_values.append(tree_explanation)

        # the ensemble explanation is the sum of the per-tree explanations
        final_explanation = interaction_values[0]
        for tree_values in interaction_values[1:]:
            final_explanation += tree_values

        # lower the minimum order to 0 if the caller requested it
        if self._min_order == 0 and final_explanation.min_order == 1:
            final_explanation.min_order = 0
            final_explanation = finalize_computed_interactions(
                final_explanation,
                target_index=self._index,
            )
        return finalize_computed_interactions(
            final_explanation,
            target_index=self._index,
        )

    def _compute_baseline_value(self) -> float:
        """Computes the baseline value for the explainer.

        The baseline value is the sum of the empty predictions (``empty_prediction``) of all
        trees in the ensemble.

        Returns:
            The baseline value for the explainer.

        """
        return sum(treeshapiq.empty_prediction for treeshapiq in self._treeshapiq_explainers)
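TreeSHAP-IQ avoids enumerating coalitions. On a tiny tree, however, the quantities a tree explainer targets can be checked by brute force. The sketch below is exponential in the number of features and is not the TreeSHAP-IQ algorithm. It rests on these assumptions:
* It uses the path-dependent value function of the original TreeSHAP: a feature outside the coalition is marginalized by following both children, weighted by their training-sample cover.
* It uses a scikit-learn-style split rule: go left if x[f] <= threshold.
* The array layout of the tree and every name in it (value, shapley_values, pairwise_sii) are hypothetical.

For a single tree, the empty-coalition value equals that tree's empty prediction, the per-tree quantity that _compute_baseline_value sums.

```python
from itertools import combinations
from math import factorial

# Hypothetical two-feature tree in array form; leaves have FEATURE == -1.
FEATURE = [0, 1, 1, -1, -1, -1, -1]
THRESHOLD = [0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0]
LEFT = [1, 3, 5, -1, -1, -1, -1]
RIGHT = [2, 4, 6, -1, -1, -1, -1]
VALUE = [0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 10.0]  # only leaf entries are used
COVER = [100, 60, 40, 30, 30, 10, 30]         # training samples reaching each node


def value(x, coalition, node=0):
    """Path-dependent expected prediction when only `coalition` features are known."""
    f = FEATURE[node]
    if f < 0:
        return VALUE[node]
    left, right = LEFT[node], RIGHT[node]
    if f in coalition:
        # known feature: follow the branch x actually takes
        child = left if x[f] <= THRESHOLD[node] else right
        return value(x, coalition, child)
    # unknown feature: cover-weighted average of both subtrees
    return (COVER[left] * value(x, coalition, left)
            + COVER[right] * value(x, coalition, right)) / COVER[node]


def subsets(items):
    """Yield every subset of `items` as a frozenset."""
    for k in range(len(items) + 1):
        yield from (frozenset(c) for c in combinations(items, k))


def shapley_values(x, n):
    """Order-1 values: exact Shapley values by enumerating all coalitions."""
    phi = []
    for i in range(n):
        rest = [j for j in range(n) if j != i]
        total = 0.0
        for s in subsets(rest):
            w = factorial(len(s)) * factorial(n - len(s) - 1) / factorial(n)
            total += w * (value(x, s | {i}) - value(x, s))
        phi.append(total)
    return phi


def pairwise_sii(x, n, i, j):
    """Order-2 value: Shapley interaction index (SII) of the pair {i, j}."""
    rest = [k for k in range(n) if k not in (i, j)]
    total = 0.0
    for s in subsets(rest):
        w = factorial(len(s)) * factorial(n - len(s) - 2) / factorial(n - 1)
        delta = (value(x, s | {i, j}) - value(x, s | {i})
                 - value(x, s | {j}) + value(x, s))
        total += w * delta
    return total


x = [1.0, 1.0]
baseline = value(x, frozenset())         # empty prediction, approximately 4.2
phi = shapley_values(x, n=2)             # approximately [4.425, 1.375]
sii_01 = pairwise_sii(x, n=2, i=0, j=1)  # approximately 0.75
# efficiency: the baseline plus the Shapley values recovers the prediction
print(baseline + sum(phi), value(x, frozenset({0, 1})))
```

Enumerating coalitions costs O(2^n) value-function calls. That cost is what TreeSHAP-IQ's single pass over the tree is designed to avoid.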
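explain_function runs one TreeSHAPIQ explainer per tree and adds the results. Likewise, _compute_baseline_value sums the per-tree empty predictions. Summing is sound because Shapley values and Shapley interaction indices are linear in the value function, and an ensemble's value function is the sum of its trees' value functions.

Below is a minimal sketch of that aggregation step. Instead of shapiq's InteractionValues, it uses plain dictionaries mapping feature-index tuples to values, with the empty tuple holding each tree's empty prediction. The dictionary layout, the function name, and the numbers are hypothetical.

```python
from collections import defaultdict

# Hypothetical stand-in for InteractionValues: interaction tuple -> value.
Interactions = dict[tuple[int, ...], float]


def combine_tree_explanations(per_tree: list[Interactions]) -> Interactions:
    """Sum per-tree interaction values into ensemble interaction values."""
    if not per_tree:
        raise ValueError("need at least one tree explanation")
    combined: defaultdict[tuple[int, ...], float] = defaultdict(float)
    for tree_values in per_tree:
        for interaction, val in tree_values.items():
            # interactions missing from a tree contribute zero
            combined[interaction] += val
    return dict(combined)


# Two hypothetical trees; () holds each tree's empty prediction.
tree_a = {(): 4.2, (0,): 4.425, (1,): 1.375}
tree_b = {(): 1.0, (1,): -0.5, (2,): 0.25}

ensemble = combine_tree_explanations([tree_a, tree_b])
baseline = ensemble[()]  # sum of empty predictions, mirroring _compute_baseline_value
print(ensemble)
```

Interactions that appear in only one tree are carried over unchanged, because a tree that never splits on a feature contributes zero for it.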