
mmschlk / shapiq — build 16345920221 (push, via GitHub / web-flow)
17 Jul 2025 01:05PM UTC — coverage: 77.589% (-13.5%) from 91.075%
🔨 Refactors library into a src structure. (#415)

* moves shapiq into a src folder

* moves shapiq tests into tests_shapiq subfolder in tests

* refactors tests to work properly

* removes pickle support and closes #413

* changes unit tests to only run the unit tests

* adds workflow for running shapiq_games

* updates coverage to only run for shapiq

* update workflow to check for shapiq_games import

* update CHANGELOG.md

* fixes install-import.yml

* fixes version in docs

* moved deprecated tests out of the main test suite

* moves fixtures into the correct test suite

* installs libomp on macos runner (try bugfix)

* correct spelling

* removes libomp again

* moves os runs into individual workflows for easier debugging

* runs macOS on py3.13

* renames workflows

* installs libomp again on macOS

* downgraded to 3.11 and reinstall python

* try different uv version

* adds libomp

* changes skip to xfail in integration tests with wrong index/order combinations

* moves test out for debugging CI

* removes outdated test

* adds concurrency for quicker testing

* re-adds randomly

* don't reset seed

* removed pytest-randomly again

* adds the tests back in

3 of 21 new or added lines in 19 files covered. (14.29%)

5536 of 7135 relevant lines covered (77.59%)

0.78 hits per line

Source File: /src/shapiq/explainer/tree/conversion/sklearn.py — 91.23% of lines covered in this build. Per the per-line report, the only uncovered statements are the TYPE_CHECKING-only import, the body of average_path_length, and the normalize branch of isotree_value_traversal.
"""Functions for converting scikit-learn decision trees to the format used by shapiq."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from shapiq.explainer.tree.base import TreeModel
from shapiq.utils import safe_isinstance

if TYPE_CHECKING:
    from shapiq.typing import Model


def convert_sklearn_forest(
    tree_model: Model,
    class_label: int | None = None,
) -> list[TreeModel]:
    """Transforms a scikit-learn random forest to the format used by shapiq.

    Args:
        tree_model: The scikit-learn random forest model to convert.
        class_label: The class label of the model to explain. Only used for classification models.
            Defaults to ``1``.

    Returns:
        The converted random forest model.

    """
    scaling = 1.0 / len(tree_model.estimators_)
    return [
        convert_sklearn_tree(tree, scaling=scaling, class_label=class_label)
        for tree in tree_model.estimators_
    ]
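As a quick illustration of the forest converter above, here is a minimal usage sketch; the dataset, model settings, and the import path (inferred from the file location shown in this report) are assumptions, not part of this file:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_forest

# toy classification data and a small forest (illustrative values only)
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
forest = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# one TreeModel per estimator, each scaled by 1 / n_estimators
tree_models = convert_sklearn_forest(forest, class_label=1)
assert len(tree_models) == forest.n_estimators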

def convert_sklearn_tree(
    tree_model: Model,
    class_label: int | None = None,
    scaling: float = 1.0,
) -> TreeModel:
    """Convert a scikit-learn decision tree to the format used by shapiq.

    Args:
        tree_model: The scikit-learn decision tree model to convert.
        class_label: The class label of the model to explain. Only used for classification models.
            Defaults to ``1``.
        scaling: The scaling factor for the tree values.

    Returns:
        The converted decision tree model.

    """
    output_type = "raw"
    tree_values = tree_model.tree_.value.copy()
    # set class label if not given and model is a classifier
    if (
        safe_isinstance(tree_model, "sklearn.tree.DecisionTreeClassifier")
        or safe_isinstance(tree_model, "sklearn.tree._classes.DecisionTreeClassifier")
    ) and class_label is None:
        class_label = 1

    if class_label is not None:
        # turn node values into probabilities
        if len(tree_values.shape) == 3:
            tree_values = tree_values[:, 0, :]
        tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True)
        tree_values = tree_values[:, class_label]
        output_type = "probability"
    tree_values = tree_values.flatten()
    tree_values *= scaling
    return TreeModel(
        children_left=tree_model.tree_.children_left,
        children_right=tree_model.tree_.children_right,
        features=tree_model.tree_.feature,
        thresholds=tree_model.tree_.threshold,
        values=tree_values,
        node_sample_weight=tree_model.tree_.weighted_n_node_samples,
        empty_prediction=None,  # compute empty prediction later
        original_output_type=output_type,
    )
85
def average_path_length(isolation_forest: Model) -> float:
1✔
86
    """Compute the average path length of the isolation forest.
87

88
    Args:
89
        isolation_forest: The isolation forest model.
90

91
    Returns:
92
        The average path length of the isolation forest.
93

94
    """
95
    from sklearn.ensemble._iforest import _average_path_length
×
96

97
    max_samples = isolation_forest._max_samples  # noqa: SLF001
×
98
    return _average_path_length([max_samples])
×
99

100

101
def convert_sklearn_isolation_forest(
1✔
102
    tree_model: Model,
103
) -> list[TreeModel]:
104
    """Transforms a scikit-learn isolation forest to the format used by shapiq.
105

106
    Args:
107
        tree_model: The scikit-learn isolation forest model to convert.
108

109
    Returns:
110
        The converted isolation forest model.
111

112
    """
113
    scaling = 1.0 / len(tree_model.estimators_)
1✔
114

115
    return [
1✔
116
        convert_isolation_tree(tree, features, scaling=scaling)
117
        for tree, features in zip(
118
            tree_model.estimators_,
119
            tree_model.estimators_features_,
120
            strict=False,
121
        )
122
    ]
123

124

125
def convert_isolation_tree(
1✔
126
    tree_model: Model,
127
    tree_features: np.ndarray,
128
    scaling: float = 1.0,
129
) -> TreeModel:
130
    """Convert a scikit-learn decision tree to the format used by shapiq.
131

132
    Args:
133
        tree_model: The scikit-learn decision tree model to convert.
134
        tree_features: The features used in the tree.
135
        scaling: The scaling factor for the tree values.
136

137
    Returns:
138
        The converted decision tree model.
139

140
    """
141
    output_type = "raw"
1✔
142
    features_updated, values_updated = isotree_value_traversal(
1✔
143
        tree_model.tree_,
144
        tree_features,
145
        normalize=False,
146
        scaling=1.0,
147
    )
148
    values_updated = values_updated * scaling
1✔
149
    values_updated = values_updated.flatten()
1✔
150

151
    return TreeModel(
1✔
152
        children_left=tree_model.tree_.children_left,
153
        children_right=tree_model.tree_.children_right,
154
        features=features_updated,
155
        thresholds=tree_model.tree_.threshold,
156
        values=values_updated,
157
        node_sample_weight=tree_model.tree_.weighted_n_node_samples,
158
        empty_prediction=None,  # compute empty prediction later
159
        original_output_type=output_type,
160
    )
161

162

163
def isotree_value_traversal(
1✔
164
    tree: Model,
165
    tree_features: np.ndarray,
166
    *,
167
    normalize: bool = False,
168
    scaling: float = 1.0,
169
) -> tuple[np.ndarray, np.ndarray]:
170
    """Traverse the tree and calculate the average path length for each node.
171

172
    Args:
173
        tree: The tree to traverse.
174
        tree_features: The features used in the tree.
175
        normalize: Whether to normalize the values.
176
        scaling: The scaling factor for the values.
177

178
    Returns:
179
        The updated features and values.
180

181
    """
182
    from sklearn.ensemble._iforest import _average_path_length
1✔
183

184
    features = tree.feature.copy()
1✔
185
    corrected_values = tree.value.copy()
1✔
186
    if safe_isinstance(tree, "sklearn.tree._tree.Tree"):
1✔
187

188
        def _recalculate_value(tree: Model, i: int, level: int = 0) -> float:
1✔
189
            if tree.children_left[i] == -1 and tree.children_right[i] == -1:
1✔
190
                value = level + _average_path_length(np.array([tree.n_node_samples[i]]))[0]
1✔
191
                corrected_values[i, 0] = value
1✔
192
                return value * tree.n_node_samples[i]
1✔
193
            value_left = _recalculate_value(tree, tree.children_left[i], level + 1)
1✔
194
            value_right = _recalculate_value(tree, tree.children_right[i], level + 1)
1✔
195
            corrected_values[i, 0] = (value_left + value_right) / tree.n_node_samples[i]
1✔
196
            return value_left + value_right
1✔
197

198
        _recalculate_value(tree, 0, 0)
1✔
199
        if normalize:
1✔
200
            corrected_values = (corrected_values.T / corrected_values.sum(1)).T
×
201
        corrected_values = corrected_values * scaling
1✔
202
        # re-number the features if each tree gets a different set of features
203
        features = np.where(features >= 0, tree_features[features], features)
1✔
204
    return features, corrected_values
1✔
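To make the leaf correction in isotree_value_traversal concrete, the sketch below (with made-up numbers) computes the value assigned to a leaf: its depth plus scikit-learn's expected average path length for the samples it holds:

import numpy as np
from sklearn.ensemble._iforest import _average_path_length

level, n_samples = 3, 17  # hypothetical leaf depth and sample count
# same formula as in _recalculate_value: level + c(n_samples)
leaf_value = level + _average_path_length(np.array([n_samples]))[0]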