mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%
push

github

web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed introduced bugs in game and interaction_values

* Pyright Save Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer. Approximator, Imputer and ExactComputer can either exist or not, depending on dynamic Type

* Bug fix caused through refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step to reinstall Python on Windows due to issues with tkinter. The linked GitHub issue was solved. Doing this as a first try.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests. The prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line
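
The percentages in this summary can be reproduced from the line counts. The snippet below is a minimal sketch that assumes each figure is simply covered lines divided by relevant lines; the helper name `coverage_percent` is hypothetical and not part of Coveralls or shapiq.

```python
def coverage_percent(covered: int, relevant: int) -> float:
    """Return covered/relevant as a percentage (assumed formula)."""
    return 100.0 * covered / relevant


# Line counts taken from the build summary above.
overall = coverage_percent(4987, 5374)
new_lines = coverage_percent(304, 360)

print(f"overall:   {overall:.3f}%")    # overall:   92.799%
print(f"new lines: {new_lines:.2f}%")  # new lines: 84.44%
```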

Source File
/src/shapiq/explainer/tree/base.py (94.74% covered)
1
"""The base class for tree model conversion."""
2

3
from __future__ import annotations
1✔
4

5
from dataclasses import dataclass
1✔
6
from typing import TYPE_CHECKING
1✔
7

8
import numpy as np
1✔
9

10
from .utils import compute_empty_prediction
1✔
11

12
if TYPE_CHECKING:
1✔
13
    from numpy.typing import NDArray
×
14

15

16
@dataclass
1✔
17
class TreeModel:
1✔
18
    """A dataclass for storing the information of a tree model.
19

20
    The dataclass stores the information of a tree model in a way that is easy to access and
21
    manipulate. The dataclass is used to convert tree models from different libraries to a common
22
    format.
23

24
    Attributes:
25
        children_left: The left children of each node in a tree. Leaf nodes are ``-1``.
26
        children_right: The right children of each node in a tree. Leaf nodes are ``-1``.
27
        features: The feature indices of the decision nodes in a tree. Leaf nodes are set to
28
            ``-2`` during ``__post_init__``.
29
        thresholds: The thresholds of the decision nodes in a tree. Leaf nodes are set to ``np.nan``.
30
        values: The values of the leaf nodes in a tree.
31
        node_sample_weight: The sample weights of the nodes in a tree.
32
        empty_prediction: The empty prediction of the tree model. The default value is ``None``. Then
33
            the empty prediction is computed from the leaf values and the sample weights.
34
        leaf_mask: The boolean mask of the leaf nodes in a tree. The default value is ``None``. Then the
35
            leaf mask is computed from the children left and right arrays.
36
        n_features_in_tree: The number of features in the tree model. The default value is ``None``.
37
            Then the number of features in the tree model is computed from the unique feature
38
            indices in the features array.
39
        max_feature_id: The maximum feature index in the tree model. The default value is ``None``. Then
40
            the maximum feature index in the tree model is computed from the features array.
41
        feature_ids: The feature indices of the decision nodes in the tree model. The default value
42
            is ``None``. Then the feature indices of the decision nodes in the tree model are computed
43
            from the unique feature indices in the features array.
44
        root_node_id: The root node id of the tree model. The default value is ``None``. Then the root
45
            node id of the tree model is set to ``0``.
46
        n_nodes: The number of nodes in the tree model. The default value is ``None``. Then the number
47
            of nodes in the tree model is computed from the children left array.
48
        nodes: The node ids of the tree model. The default value is ``None``. Then the node ids of the
49
            tree model are computed from the number of nodes in the tree model.
50
        feature_map_original_internal: A mapping of feature indices from the original feature
51
            indices (as in the model) to the internal feature indices (as in the tree model).
52
        feature_map_internal_original: A mapping of feature indices from the internal feature
53
            indices (as in the tree model) to the original feature indices (as in the model).
54
        original_output_type: The original output type of the tree model. The default value is
55
            ``"raw"``.
56

57
    """
58

59
    children_left: NDArray[np.int_]
1✔
60
    children_right: NDArray[np.int_]
1✔
61
    features: NDArray[np.int_]
1✔
62
    thresholds: NDArray[np.floating]
1✔
63
    values: NDArray[np.floating]
1✔
64
    node_sample_weight: NDArray[np.floating]
1✔
65
    empty_prediction: float = None  # type: ignore[assignment]
1✔
66
    leaf_mask: NDArray[np.bool_] = None  # type: ignore[assignment]
1✔
67
    n_features_in_tree: int = None  # type: ignore[assignment]
1✔
68
    max_feature_id: int = None  # type: ignore[assignment]
1✔
69
    feature_ids: set = None  # type: ignore[assignment]
1✔
70
    root_node_id: int = None  # type: ignore[assignment]
1✔
71
    n_nodes: int = None  # type: ignore[assignment]
1✔
72
    nodes: NDArray[np.int_] = None  # type: ignore[assignment]
1✔
73
    feature_map_original_internal: dict[int, int] = None  # type: ignore[assignment]
1✔
74
    feature_map_internal_original: dict[int, int] = None  # type: ignore[assignment]
1✔
75
    original_output_type: str = "raw"  # not used at the moment
1✔
76

77
    def compute_empty_prediction(self) -> None:
1✔
78
        """Compute the empty prediction of the tree model.
79

80
        The method computes the empty prediction of the tree model by taking the weighted average of
81
        the leaf node values. The method modifies the tree model in place.
82
        """
83
        self.empty_prediction = compute_empty_prediction(
1✔
84
            self.values[self.leaf_mask],
85
            self.node_sample_weight[self.leaf_mask],
86
        )
87

88
    def __post_init__(self) -> None:
1✔
89
        """Clean-up after the initialization of the TreeModel dataclass.
90

91
        The method sets up the tree model with the information provided in the constructor.
92
        """
93
        # setup leaf mask
94
        if self.leaf_mask is None:
1✔
95
            self.leaf_mask = np.asarray(self.children_left == -1)
1✔
96
        # sanitize features
97
        self.features = np.where(self.leaf_mask, -2, self.features)
1✔
98
        self.features = self.features.astype(int)  # make features integer type
1✔
99
        # sanitize thresholds
100
        self.thresholds = np.where(self.leaf_mask, np.nan, self.thresholds)
1✔
101
        # setup empty prediction
102
        if self.empty_prediction is None:
1✔
103
            self.compute_empty_prediction()
1✔
104
        unique_features = set(np.unique(self.features))
1✔
105
        unique_features.discard(-2)  # remove leaf node "features"
1✔
106
        # setup number of features
107
        if self.n_features_in_tree is None:
1✔
108
            self.n_features_in_tree = int(len(unique_features))
1✔
109
        # setup max feature id
110
        if self.max_feature_id is None:
1✔
111
            self.max_feature_id = max(unique_features, default=-1)  # -1 if no decision nodes
1✔
112
        # setup feature ids
113
        if self.feature_ids is None:
1✔
114
            self.feature_ids = unique_features
1✔
115
        # setup root node id
116
        if self.root_node_id is None:
1✔
117
            self.root_node_id = 0
1✔
118
        # setup number of nodes
119
        if self.n_nodes is None:
1✔
120
            self.n_nodes = len(self.children_left)
1✔
121
        # setup nodes
122
        if self.nodes is None:
1✔
123
            self.nodes = np.arange(self.n_nodes)
1✔
124
        # setup original feature mapping
125
        if self.feature_map_original_internal is None:
1✔
126
            self.feature_map_original_internal = {i: i for i in unique_features}
1✔
127
        # setup new feature mapping
128
        if self.feature_map_internal_original is None:
1✔
129
            self.feature_map_internal_original = {i: i for i in unique_features}
1✔
130
        # flatten values if necessary
131
        if self.values.ndim > 1:
1✔
132
            if self.values.shape[1] != 1:
×
133
                msg = "Values array has more than one column."
×
134
                raise ValueError(msg)
×
135
            self.values = self.values.flatten()
×
136
        # set all values of non leaf nodes to zero
137
        self.values[~self.leaf_mask] = 0
1✔
138

139
    def reduce_feature_complexity(self) -> None:
1✔
140
        """Reduces the feature complexity of the tree model.
141

142
        The method reduces the feature complexity of the tree model by removing unused features and
143
        reindexing the feature indices of the decision nodes in the tree. The method modifies the
144
        tree model in place. To see the feature mappings, use the ``feature_map_original_internal``
145
        and ``feature_map_internal_original`` attributes.
146

147
        For example, consider a tree model with the following feature indices:
148

149
            [0, 1, 8]
150

151
        The method will remove the unused feature indices and reindex the feature indices of the
152
        decision nodes in the tree to the following:
153

154
            [0, 1, 2]
155

156
        Feature ``'8'`` is 'renamed' to ``'2'`` such that in the internal representation a one-hot vector
157
        (and matrices) of length ``3`` suffices to represent the feature indices.
158
        """
159
        if self.n_features_in_tree < self.max_feature_id + 1:
1✔
160
            new_feature_ids = set(range(self.n_features_in_tree))
1✔
161
            mapping_old_new = {old_id: new_id for new_id, old_id in enumerate(self.feature_ids)}
1✔
162
            mapping_new_old = dict(enumerate(self.feature_ids))
1✔
163
            new_features = np.zeros_like(self.features)
1✔
164
            for i, old_feature in enumerate(self.features):
1✔
165
                new_value = -2 if old_feature == -2 else mapping_old_new[old_feature]
1✔
166
                new_features[i] = new_value
1✔
167
            self.features = new_features
1✔
168
            self.feature_ids = new_feature_ids
1✔
169
            self.feature_map_original_internal = mapping_old_new
1✔
170
            self.feature_map_internal_original = mapping_new_old
1✔
171
            self.n_features_in_tree = len(new_feature_ids)
1✔
172
            self.max_feature_id = self.n_features_in_tree - 1
1✔
173

174
    def predict_one(self, x: np.ndarray) -> float:
1✔
175
        """Predicts the output of a single instance.
176

177
        Args:
178
            x: The instance to predict as a 1-dimensional array.
179

180
        Returns:
181
            The prediction of the instance with the tree model.
182

183
        """
184
        node = self.root_node_id
1✔
185
        is_leaf = self.leaf_mask[node]
1✔
186
        while not is_leaf:
1✔
187
            feature_id_internal = self.features[node]
1✔
188
            feature_id_original = self.feature_map_internal_original[feature_id_internal]
1✔
189
            if x[feature_id_original] <= self.thresholds[node]:
1✔
190
                node = self.children_left[node]
1✔
191
            else:
192
                node = self.children_right[node]
1✔
193
            is_leaf = self.leaf_mask[node]
1✔
194
        return float(self.values[node])
1✔
195

196

197
class EdgeTree:
1✔
198
    """A class for storing the information of an edge representation of the tree.
199

200
    The class stores the information of an edge representation of the tree in a way that is easy
201
    to access and manipulate for the TreeSHAP-IQ algorithm.
202
    """
203

204
    parents: np.ndarray
1✔
205
    ancestors: np.ndarray
1✔
206
    ancestor_nodes: dict[int, np.ndarray]
1✔
207
    p_e_values: np.ndarray
1✔
208
    p_e_storages: np.ndarray
1✔
209
    split_weights: np.ndarray
1✔
210
    empty_predictions: np.ndarray
1✔
211
    edge_heights: np.ndarray
1✔
212
    max_depth: int
1✔
213
    last_feature_node_in_path: np.ndarray
1✔
214
    interaction_height_store: dict[int, np.ndarray]
1✔
215
    has_ancestors: np.ndarray
1✔
216

217
    def __init__(
1✔
218
        self,
219
        parents: np.ndarray,
220
        ancestors: np.ndarray,
221
        ancestor_nodes: dict[int, np.ndarray],
222
        p_e_values: np.ndarray,
223
        p_e_storages: np.ndarray,
224
        split_weights: np.ndarray,
225
        empty_predictions: np.ndarray,
226
        edge_heights: np.ndarray,
227
        max_depth: int,
228
        last_feature_node_in_path: np.ndarray,
229
        interaction_height_store: dict[int, np.ndarray],
230
        *,
231
        has_ancestors: np.ndarray | None = None,
232
    ) -> None:
233
        """Initializes the EdgeTree."""
234
        self.parents = parents
1✔
235
        self.ancestors = ancestors
1✔
236
        self.ancestor_nodes = ancestor_nodes
1✔
237
        self.p_e_values = p_e_values
1✔
238
        self.p_e_storages = p_e_storages
1✔
239
        self.split_weights = split_weights
1✔
240
        self.empty_predictions = empty_predictions
1✔
241
        self.edge_heights = edge_heights
1✔
242
        self.max_depth = max_depth
1✔
243
        self.last_feature_node_in_path = last_feature_node_in_path
1✔
244
        self.interaction_height_store = interaction_height_store
1✔
245
        if has_ancestors is None:
1✔
246
            self.has_ancestors = self.ancestors > -1
1✔
247
        else:
NEW
248
            self.has_ancestors = has_ancestors
×