• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 16345920221

17 Jul 2025 01:05PM UTC coverage: 77.589% (-13.5%) from 91.075%
16345920221

push

github

web-flow
🔨 Refactors library into a src structure. (#415)

* moves shapiq into a src folder

* moves shapiq tests into tests_shapiq subfolder in tests

* refactors tests to work properly

* removes pickle support and closes #413

* changes unit tests to only run the unit tests

* adds workflow for running shapiq_games

* updates coverage to only run for shapiq

* update workflow to check for shapiq_games import

* update CHANGELOG.md

* fixes install-import.yml

* fixes version in docs

* moved deprecated tests out of the main test suite

* moves fixtures in the correct test suite

* installs libomp on macos runner (try bugfix)

* correct spelling

* removes libomp again

* moves os runs into individual workflows for easier debugging

* runs macOS on py3.13

* renames workflows

* installs libomp again on macOS

* downgraded to 3.11 and reinstall python

* try different uv version

* adds libomp

* changes skip to xfail in integration tests with wrong index/order combinations

* moves test out for debugging CI

* removes outdated test

* adds concurrency for quicker testing

* re-adds randomly

* don't reset seed

* removed pytest-randomly again

* adds the tests back in

3 of 21 new or added lines in 19 files covered. (14.29%)

5536 of 7135 relevant lines covered (77.59%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.89
/src/shapiq/explainer/tree/validation.py
1
"""Conversion functions for the tree explainer implementation."""
2

3
from __future__ import annotations
1✔
4

5
from typing import TYPE_CHECKING
1✔
6

7
from shapiq.utils.modules import safe_isinstance
1✔
8

9
from .base import TreeModel
1✔
10
from .conversion.lightgbm import convert_lightgbm_booster
1✔
11
from .conversion.sklearn import (
1✔
12
    convert_sklearn_forest,
13
    convert_sklearn_isolation_forest,
14
    convert_sklearn_tree,
15
)
16
from .conversion.xgboost import convert_xgboost_booster
1✔
17

18
if TYPE_CHECKING:
1✔
NEW
19
    from shapiq.typing import Model
×
20

21
# Fully qualified class paths of the third-party models handled by ``validate_tree_model``.
# Each class is listed under both its public path and its private implementation module
# (e.g. ``sklearn.tree`` and ``sklearn.tree._classes``), mirroring the ``safe_isinstance``
# checks in that function. The set is interpolated into the ``TypeError`` message raised
# for unsupported models.
SUPPORTED_MODELS = {
1✔
22
    "sklearn.tree.DecisionTreeRegressor",
23
    "sklearn.tree._classes.DecisionTreeRegressor",
24
    "sklearn.tree.DecisionTreeClassifier",
25
    "sklearn.tree._classes.DecisionTreeClassifier",
26
    "sklearn.ensemble.RandomForestClassifier",
27
    "sklearn.ensemble._forest.RandomForestClassifier",
28
    "sklearn.ensemble.ExtraTreesClassifier",
29
    "sklearn.ensemble._forest.ExtraTreesClassifier",
30
    "sklearn.ensemble.RandomForestRegressor",
31
    "sklearn.ensemble._forest.RandomForestRegressor",
32
    "sklearn.ensemble.ExtraTreesRegressor",
33
    "sklearn.ensemble._forest.ExtraTreesRegressor",
34
    "sklearn.ensemble.IsolationForest",
35
    "sklearn.ensemble._iforest.IsolationForest",
36
    "lightgbm.sklearn.LGBMRegressor",
37
    "lightgbm.sklearn.LGBMClassifier",
38
    "lightgbm.basic.Booster",
39
    "xgboost.sklearn.XGBRegressor",
40
    "xgboost.sklearn.XGBClassifier",
41
}
42

43

44
def validate_tree_model(
    model: Model,
    class_label: int | None = None,
) -> TreeModel | list[TreeModel]:
    """Validate a model and convert it into the ``TreeModel`` format.

    The following inputs are accepted:

    * a ``TreeModel`` (returned as is),
    * a ``list`` whose elements are all ``TreeModel`` objects,
    * a ``dict`` whose items are passed as keyword arguments to ``TreeModel``,
    * one of the sklearn, lightgbm or xgboost models listed in ``SUPPORTED_MODELS``.

    ``TreeModel``, ``list`` and ``dict`` inputs are recognized by the name of their exact
    type, so subclasses with a different class name do not match these branches.

    Args:
        model: The model to validate.
        class_label: The class label of the model to explain. Only used for classification
            models; it is forwarded to the converters that accept it.

    Returns:
        A single ``TreeModel`` if the validated model consists of exactly one tree, otherwise
        the list of tree models. An empty list is returned unchanged.

    Raises:
        TypeError: If the model type is not supported, or if ``model`` is a list that contains
            an element which is not a ``TreeModel``.

    """

    def is_instance_of(*class_paths: str) -> bool:
        # Checks the paths in order and stops at the first match.
        return any(safe_isinstance(model, path) for path in class_paths)

    model_type_name = type(model).__name__

    # direct returns for base tree models and dict as model
    # tree model (is already in the correct format)
    if model_type_name == "TreeModel":
        tree_model = model
    # list of tree models (every element must already be a TreeModel)
    elif model_type_name == "list":
        # Reject mixed lists explicitly; otherwise ``tree_model`` would stay unassigned and
        # the function would fail later with an UnboundLocalError.
        if not all(type(tree).__name__ == "TreeModel" for tree in model):
            msg = "All elements of a list passed as model must be TreeModel objects."
            raise TypeError(msg)
        tree_model = model
    # dict as model is parsed to TreeModel (the dict needs to have the correct format and names)
    elif model_type_name == "dict":
        tree_model = TreeModel(**model)
    # transformation of common machine learning libraries to TreeModel
    # sklearn decision trees
    elif is_instance_of(
        "sklearn.tree.DecisionTreeRegressor",
        "sklearn.tree._classes.DecisionTreeRegressor",
        "sklearn.tree.DecisionTreeClassifier",
        "sklearn.tree._classes.DecisionTreeClassifier",
    ):
        tree_model = convert_sklearn_tree(model, class_label=class_label)
    # sklearn random forests and extra trees
    elif is_instance_of(
        "sklearn.ensemble.RandomForestRegressor",
        "sklearn.ensemble._forest.RandomForestRegressor",
        "sklearn.ensemble.RandomForestClassifier",
        "sklearn.ensemble._forest.RandomForestClassifier",
        "sklearn.ensemble.ExtraTreesRegressor",
        "sklearn.ensemble._forest.ExtraTreesRegressor",
        "sklearn.ensemble.ExtraTreesClassifier",
        "sklearn.ensemble._forest.ExtraTreesClassifier",
    ):
        tree_model = convert_sklearn_forest(model, class_label=class_label)
    # sklearn isolation forests (converted without a class label)
    elif is_instance_of(
        "sklearn.ensemble.IsolationForest",
        "sklearn.ensemble._iforest.IsolationForest",
    ):
        tree_model = convert_sklearn_isolation_forest(model)
    # lightgbm sklearn wrappers: the underlying booster is converted
    elif is_instance_of(
        "lightgbm.sklearn.LGBMRegressor",
        "lightgbm.sklearn.LGBMClassifier",
    ):
        tree_model = convert_lightgbm_booster(model.booster_, class_label=class_label)
    # raw lightgbm booster
    elif is_instance_of("lightgbm.basic.Booster"):
        tree_model = convert_lightgbm_booster(model, class_label=class_label)
    # xgboost sklearn wrappers
    elif is_instance_of(
        "xgboost.sklearn.XGBRegressor",
        "xgboost.sklearn.XGBClassifier",
    ):
        tree_model = convert_xgboost_booster(model, class_label=class_label)
    # unsupported model
    else:
        msg = f"Unsupported model type. Supported models are: {SUPPORTED_MODELS}"
        raise TypeError(msg)

    # a list holding exactly one tree is unwrapped; everything else is returned unchanged
    if isinstance(tree_model, list) and len(tree_model) == 1:
        return tree_model[0]
    return tree_model
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc