• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18490684955

14 Oct 2025 08:39AM UTC coverage: 93.111% (-0.7%) from 93.845%
18490684955

Pull #430

github

web-flow
Merge f75cf495e into ddda4a730
Pull Request #430: Enhance type safety and fix bugs across the codebase

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5356 relevant lines covered (93.11%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.48
/src/shapiq/explainer/tree/conversion/edges.py
1
"""Conversion functions to parse a :class:`~shapiq.explainer.tree.base.TreeModel` into the :class:`~shapiq.explainer.tree.base.EdgeTree` format."""
2

3
from __future__ import annotations
1✔
4

5
from typing import TYPE_CHECKING
1✔
6

7
import numpy as np
1✔
8
from scipy.special import binom
1✔
9

10
from shapiq.explainer.tree.base import EdgeTree
1✔
11

12
if TYPE_CHECKING:
1✔
NEW
13
    from shapiq.typing import FloatVector, IntVector
×
14

15

16
def create_edge_tree(
    children_left: IntVector,
    children_right: IntVector,
    features: IntVector,
    node_sample_weight: FloatVector,
    values: FloatVector,
    n_nodes: int,
    n_features: int,
    max_interaction: int,
    subset_updates_pos_store: dict[int, dict[int, IntVector]],
) -> EdgeTree:
    """Extracts edge information recursively from the tree information.

    Parses the tree recursively, starting at node ``0``, to create an edge-based representation
    of the tree. It precalculates the ``p_e`` values of every edge and the per-order interaction
    heights for the interaction subsets up to order ``max_interaction``.

    Args:
        children_left (np.ndarray[int]): The left children of each node. Leaf nodes are denoted
            with ``-1``.
        children_right (np.ndarray[int]): The right children of each node. Leaf nodes are denoted
            with ``-1``.
        features (np.ndarray[int]): The feature used for splitting at each node. Leaf nodes have
            the value ``-2``.
        node_sample_weight (np.ndarray[float]): The sample weights of the tree.
        values (np.ndarray[float]): The output values at the leaf values of the tree.
        n_nodes (int): The number of nodes in the tree.
        n_features (int): The number of features of the dataset.
        max_interaction (int): The maximum interaction order to be computed. An interaction
            order of ``1`` corresponds to the Shapley value. Any value higher than ``1``
            computes the Shapley interaction values up to that order.
        subset_updates_pos_store (dict[int, dict[int, np.ndarray[int]]]): For each interaction
            order, a dictionary mapping a feature id to the positions of the interaction subsets
            whose height is incremented when that feature is first encountered on a path.

    Returns:
        EdgeTree: A dataclass containing the edge information of the tree.

    """
    # variables to be filled with recursive function
    parents = np.full(n_nodes, -1, dtype=int)
    ancestors: np.ndarray = np.full(n_nodes, -1, dtype=int)

    ancestor_nodes: dict[int, np.ndarray] = {}

    p_e_values: np.ndarray = np.ones(n_nodes, dtype=float)
    p_e_storages: np.ndarray = np.ones((n_nodes, n_features), dtype=float)
    split_weights: np.ndarray = np.ones(n_nodes, dtype=float)
    empty_predictions: np.ndarray = np.zeros(n_nodes, dtype=float)
    edge_heights: np.ndarray = np.full_like(children_left, -1, dtype=int)
    # single-element list so the nested function can update the value in place
    max_depth: list[int] = [0]
    # one (n_nodes, C(n_features, order)) integer array per interaction order
    interaction_height_store = {
        i: np.zeros((n_nodes, int(binom(n_features, i))), dtype=int)
        for i in range(1, max_interaction + 1)
    }

    last_feature_node_in_path: np.ndarray = np.full_like(
        children_left, fill_value=False, dtype=bool
    )

    def recursive_search(
        node_id: int = 0,
        depth: int = 0,
        prod_weight: float = 1.0,
        seen_features: np.ndarray | None = None,
    ) -> int:
        """Traverses the tree recursively and collects all relevant information.

        Args:
            node_id (int): The current node id.
            depth (int): The depth of the current node.
            prod_weight (float): The product of the node weights on the path to the current
                node.
            seen_features (np.ndarray[int]): The features seen on the path to the current node.
                Maps the feature id to the node id where the feature was last seen on the way.

        Returns:
            The edge height of the current node.

        """
        # if root node, initialize seen_features
        if seen_features is None:
            # map feature_id to ancestor node_id
            seen_features = np.full(n_features, -1, dtype=int)

        # update the maximum depth of the tree
        max_depth[0] = max(max_depth[0], depth)

        # set the parents of the children nodes
        left_child, right_child = children_left[node_id], children_right[node_id]
        is_leaf = left_child == -1
        if not is_leaf:
            parents[left_child], parents[right_child] = node_id, node_id

        # if root_node, step into the tree and end recursion
        if node_id == 0:
            if is_leaf:
                # Single-node tree: the root has no children. Following the ``-1`` child
                # marker would wrap around to the last array entry through negative indexing,
                # so the root is handled like a leaf with no features seen on its path.
                edge_heights[node_id] = 0
                empty_predictions[node_id] = prod_weight * values[node_id]
                return int(edge_heights[node_id])
            edge_heights_left = recursive_search(
                int(left_child),
                depth + 1,
                prod_weight,
                seen_features.copy(),
            )
            edge_heights_right = recursive_search(
                int(right_child),
                depth + 1,
                prod_weight,
                seen_features.copy(),
            )
            edge_heights[node_id] = max(edge_heights_left, edge_heights_right)
            return int(edge_heights[node_id])  # final return ending the recursion

        # node is not root node follow the path and compute weights

        # snapshot of the feature -> node mapping before this node's edge is registered
        ancestor_nodes[node_id] = seen_features.copy()

        # feature the parent splits on, i.e. the feature of the edge leading into this node
        feature_id = features[parents[node_id]]

        # Assume it is the last occurrence of feature
        last_feature_node_in_path[node_id] = True

        # compute prod_weight with node samples
        n_sample = node_sample_weight[node_id]
        n_parent = node_sample_weight[parents[node_id]]
        weight = n_sample / n_parent
        split_weights[node_id] = weight
        prod_weight *= weight

        # calculate the p_e value of the current node
        p_e = 1 / weight

        # copy parent height information
        for order in range(1, max_interaction + 1):
            interaction_height_store[order][node_id] = interaction_height_store[order][
                parents[node_id]
            ].copy()
        # correct if feature was seen before
        if seen_features[feature_id] > -1:  # feature has been seen before in the path
            ancestor_id = seen_features[feature_id]  # get ancestor node with same feature
            ancestors[node_id] = ancestor_id  # store ancestor node
            last_feature_node_in_path[ancestor_id] = False  # correct previous assumption
            p_e *= p_e_values[ancestor_id]  # add ancestor weight to p_e
        else:
            # first occurrence of the feature on this path: raise the height of all
            # interaction subsets registered for this feature
            for order in range(1, max_interaction + 1):
                indices_to_update = subset_updates_pos_store[order][int(feature_id)]
                interaction_height_store[order][node_id][indices_to_update] += 1

        # store the p_e value of the current node
        p_e_values[node_id] = p_e
        p_e_storages[node_id] = p_e_storages[parents[node_id]].copy()
        p_e_storages[node_id][feature_id] = p_e

        # update seen features with current node
        seen_features[feature_id] = node_id

        # update the edge heights
        if not is_leaf:  # if node is not a leaf, continue recursion
            edge_heights_left = recursive_search(
                int(left_child),
                depth + 1,
                prod_weight,
                seen_features.copy(),
            )
            edge_heights_right = recursive_search(
                int(right_child),
                depth + 1,
                prod_weight,
                seen_features.copy(),
            )
            edge_heights[node_id] = max(edge_heights_left, edge_heights_right)
        else:  # if node is a leaf, end recursion
            # height of a leaf edge is the number of distinct features on its path
            edge_heights[node_id] = np.sum(seen_features > -1)
            empty_predictions[node_id] = prod_weight * values[node_id]
        return int(edge_heights[node_id])  # return upwards in the recursion

    _ = recursive_search()
    return EdgeTree(
        parents=parents,
        ancestors=ancestors,
        ancestor_nodes=ancestor_nodes,
        p_e_values=p_e_values,
        p_e_storages=p_e_storages,
        split_weights=split_weights,
        empty_predictions=empty_predictions,
        edge_heights=edge_heights,
        max_depth=max_depth[0],
        last_feature_node_in_path=last_feature_node_in_path,
        interaction_height_store=interaction_height_store,
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc