mmschlk / shapiq / build 18499875864 (push, via GitHub, committer: web-flow)

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* Fixed bugs introduced in game and interaction_values

* Pyright-Safe Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate show-diff output in the code quality check

* changed uv sync in pre-commit hook

* made fixtures a local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer; Approximator, Imputer, and ExactComputer can each exist or not, depending on the dynamic type (see the generic-typing sketch after this list)

* Fixed a bug introduced by the refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step that reinstalled Python on Windows, which was originally added due to issues with tkinter. The linked GitHub issue has since been solved; this is a first attempt.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests; the prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments (see the headless-plotting sketch after this list)

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
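
The generic-typing change mentioned above follows a standard pattern. The sketch below is illustrative only and does not reproduce the shapiq classes; all names here are stand-ins. Binding a TypeVar to the approximator lets pyright track which concrete type an explainer holds, which is how such a change reduces the number of ignore comments.

from typing import Generic, TypeVar


class Approximator: ...          # stand-in base class, not the shapiq implementation


class KernelSHAP(Approximator): ...


A = TypeVar("A", bound=Approximator)


class Explainer(Generic[A]):
    """Toy explainer that remembers the concrete approximator type it was given."""

    def __init__(self, approximator: A) -> None:
        self.approximator: A = approximator


# pyright infers Explainer[KernelSHAP], so explainer.approximator is typed as KernelSHAP
# at type-check time and no "type: ignore" is needed when calling its methods.
explainer = Explainer(KernelSHAP())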
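
The headless-plotting change mentioned above is likewise a common pattern rather than anything specific to this repository. A typical way to keep matplotlib working on CI runners without a display or Tkinter, shown here only as a sketch, is to select a non-interactive backend before pyplot is imported:

# force a non-interactive backend so figures render without a display or Tkinter
import matplotlib

matplotlib.use("Agg")  # must happen before "import matplotlib.pyplot"

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
fig.savefig("example.png")  # works on headless CI runners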

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line

Source File: /src/shapiq/explainer/tree/conversion/sklearn.py (91.23% covered)
"""Functions for converting scikit-learn decision trees to the format used by shapiq."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from shapiq.explainer.tree.base import TreeModel
from shapiq.utils import safe_isinstance

if TYPE_CHECKING:
    from shapiq.typing import Model


def convert_sklearn_forest(
    tree_model: Model,
    class_label: int | None = None,
) -> list[TreeModel]:
    """Transforms a scikit-learn random forest to the format used by shapiq.

    Args:
        tree_model: The scikit-learn random forest model to convert.
        class_label: The class label of the model to explain. Only used for classification models.
            Defaults to ``1``.

    Returns:
        The converted random forest model.

    """
    scaling = 1.0 / len(tree_model.estimators_)
    return [
        convert_sklearn_tree(tree, scaling=scaling, class_label=class_label)
        for tree in tree_model.estimators_
    ]


def convert_sklearn_tree(
    tree_model: Model,
    class_label: int | None = None,
    scaling: float = 1.0,
) -> TreeModel:
    """Convert a scikit-learn decision tree to the format used by shapiq.

    Args:
        tree_model: The scikit-learn decision tree model to convert.
        class_label: The class label of the model to explain. Only used for classification models.
            Defaults to ``1``.
        scaling: The scaling factor for the tree values.

    Returns:
        The converted decision tree model.

    """
    output_type = "raw"
    tree_values = tree_model.tree_.value.copy()
    # set class label if not given and model is a classifier
    if (
        safe_isinstance(tree_model, "sklearn.tree.DecisionTreeClassifier")
        or safe_isinstance(tree_model, "sklearn.tree._classes.DecisionTreeClassifier")
    ) and class_label is None:
        class_label = 1

    if class_label is not None:
        # turn node values into probabilities
        if len(tree_values.shape) == 3:
            tree_values = tree_values[:, 0, :]
        tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True)
        tree_values = tree_values[:, class_label]
        output_type = "probability"
    tree_values = tree_values.flatten()
    tree_values *= scaling
    return TreeModel(
        children_left=tree_model.tree_.children_left,
        children_right=tree_model.tree_.children_right,
        features=tree_model.tree_.feature,
        thresholds=tree_model.tree_.threshold,
        values=tree_values,
        node_sample_weight=tree_model.tree_.weighted_n_node_samples,
        empty_prediction=None,  # pyright: ignore[reportArgumentType] compute empty prediction later
        original_output_type=output_type,
    )

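# Note (added commentary, not part of the repository file): for a classifier, a node's raw
# value stores per-class sample counts. With an illustrative leaf value of [[3., 1.]], the
# normalization in convert_sklearn_tree yields [0.75, 0.25], and class_label=1 keeps 0.25,
# so the stored leaf value becomes the class-1 probability at that node.
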
def average_path_length(isolation_forest: Model) -> float:
    """Compute the average path length of the isolation forest.

    Args:
        isolation_forest: The isolation forest model.

    Returns:
        The average path length of the isolation forest.

    """
    from sklearn.ensemble._iforest import (
        _average_path_length,  # pyright: ignore[reportAttributeAccessIssue]
    )

    max_samples = isolation_forest._max_samples  # noqa: SLF001
    return _average_path_length([max_samples]).item()

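# Background (added commentary, not part of the repository file): sklearn's
# _average_path_length implements the isolation-forest normalization constant
# c(n) = 2 * H(n - 1) - 2 * (n - 1) / n, where H(i) is the i-th harmonic number
# (approximately ln(i) + 0.5772). For example, c(256) is roughly 10.2, the expected
# path length of an unsuccessful search in a binary search tree over 256 samples.
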
def convert_sklearn_isolation_forest(
    tree_model: Model,
) -> list[TreeModel]:
    """Transforms a scikit-learn isolation forest to the format used by shapiq.

    Args:
        tree_model: The scikit-learn isolation forest model to convert.

    Returns:
        The converted isolation forest model.

    """
    scaling = 1.0 / len(tree_model.estimators_)

    return [
        convert_isolation_tree(tree, features, scaling=scaling)
        for tree, features in zip(
            tree_model.estimators_,
            tree_model.estimators_features_,
            strict=False,
        )
    ]


def convert_isolation_tree(
    tree_model: Model,
    tree_features: np.ndarray,
    scaling: float = 1.0,
) -> TreeModel:
    """Convert a scikit-learn decision tree to the format used by shapiq.

    Args:
        tree_model: The scikit-learn decision tree model to convert.
        tree_features: The features used in the tree.
        scaling: The scaling factor for the tree values.

    Returns:
        The converted decision tree model.

    """
    output_type = "raw"
    features_updated, values_updated = isotree_value_traversal(
        tree_model.tree_,
        tree_features,
        normalize=False,
        scaling=1.0,
    )
    values_updated = values_updated * scaling
    values_updated = values_updated.flatten()

    return TreeModel(
        children_left=tree_model.tree_.children_left,
        children_right=tree_model.tree_.children_right,
        features=features_updated,
        thresholds=tree_model.tree_.threshold,
        values=values_updated,
        node_sample_weight=tree_model.tree_.weighted_n_node_samples,
        empty_prediction=None,  # pyright: ignore[reportArgumentType] compute empty prediction later
        original_output_type=output_type,
    )


def isotree_value_traversal(
    tree: Model,
    tree_features: np.ndarray,
    *,
    normalize: bool = False,
    scaling: float = 1.0,
) -> tuple[np.ndarray, np.ndarray]:
    """Traverse the tree and calculate the average path length for each node.

    Args:
        tree: The tree to traverse.
        tree_features: The features used in the tree.
        normalize: Whether to normalize the values.
        scaling: The scaling factor for the values.

    Returns:
        The updated features and values.

    """
    from sklearn.ensemble._iforest import (
        _average_path_length,  # pyright: ignore[reportAttributeAccessIssue]
    )

    features = tree.feature.copy()
    corrected_values = tree.value.copy()
    if safe_isinstance(tree, "sklearn.tree._tree.Tree"):

        def _recalculate_value(tree: Model, i: int, level: int = 0) -> float:
            if tree.children_left[i] == -1 and tree.children_right[i] == -1:
                value = level + _average_path_length(np.array([tree.n_node_samples[i]]))[0]
                corrected_values[i, 0] = value
                return value * tree.n_node_samples[i]
            value_left = _recalculate_value(tree, tree.children_left[i], level + 1)
            value_right = _recalculate_value(tree, tree.children_right[i], level + 1)
            corrected_values[i, 0] = (value_left + value_right) / tree.n_node_samples[i]
            return value_left + value_right

        _recalculate_value(tree, 0, 0)
        if normalize:
            corrected_values = (corrected_values.T / corrected_values.sum(1)).T
        corrected_values = corrected_values * scaling
        # re-number the features if each tree gets a different set of features
        features = np.where(features >= 0, tree_features[features], features)
    return features, corrected_values
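
The converters above can be exercised directly. The following is a minimal, illustrative sketch (not part of the coverage report): it fits a small scikit-learn random forest on synthetic data and converts it with convert_sklearn_forest. The function name, module path, and signature come from the file shown above; the dataset and model settings are arbitrary.

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_forest

# fit a small classifier on synthetic data (illustrative values only)
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
forest = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# convert every estimator into shapiq's TreeModel representation; for classifiers the
# class label defaults to 1, so leaf values become class-1 probabilities
tree_models = convert_sklearn_forest(forest)

print(len(tree_models))      # one TreeModel per estimator in the forest -> 10
print(type(tree_models[0]))  # shapiq.explainer.tree.base.TreeModel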