• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 16345920221

17 Jul 2025 01:05PM UTC coverage: 77.589% (-13.5%) from 91.075%
16345920221

push

github

web-flow
🔨 Refactors library into a src structure. (#415)

* moves shapiq into a src folder

* moves shapiq tests into tests_shapiq subfolder in tests

* refactors tests to work properly

* removes pickle support and closes #413

* changes unit tests to only run the unit tests

* adds workflow for running shapiq_games

* updates coverage to only run for shapiq

* update workflow to check for shapiq_games import

* update CHANGELOG.md

* fixes install-import.yml

* fixes version in docs

* moved deprecated tests out of the main test suite

* moves fixtures in the correct test suite

* installs libomp on macos runner (try bugfix)

* correct spelling

* removes libomp again

* moves os runs into individual workflows for easier debugging

* runs macOS on py3.13

* renames workflows

* installs libomp again on macOS

* downgraded to 3.11 and reinstall python

* try different uv version

* adds libomp

* changes skip to xfail in integration tests with wrong index/order combinations

* moves test out for debugging CI

* removes outdated test

* adds concurrency for quicker testing

* re-adds randomly

* don't reset seed

* removed pytest-randomly again

* adds the tests back in

3 of 21 new or added lines in 19 files covered. (14.29%)

5536 of 7135 relevant lines covered (77.59%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.41
/src/shapiq/explainer/utils.py
1
"""This module contains utility functions for the explainer module."""
2

3
from __future__ import annotations
1✔
4

5
import re
1✔
6
from typing import TYPE_CHECKING, Literal, cast, get_args
1✔
7

8
if TYPE_CHECKING:
1✔
9
    from collections.abc import Callable
×
10
    from typing import Any
×
11

12
    import numpy as np
×
13

14
    from shapiq.explainer.base import Explainer
×
15
    from shapiq.games.base import Game
×
NEW
16
    from shapiq.typing import Model
×
17

18
# The explainer families this module can dispatch to; ``get_explainers`` maps each of these keys
# to an explainer class.
ExplainerTypes = Literal["tabular", "tree", "tabpfn", "game"]

# Message text describing the fallback used when no ``class_index`` is supplied.
WARNING_NO_CLASS_INDEX = (
    "No class_index provided. "
    "Explaining the 2nd '1' class for classification models. "
    "Please provide the class_index to explain a different class. "
    "Disregard this warning for regression models."
)
26

27

28
def get_explainers() -> dict[ExplainerTypes, type[Explainer]]:
    """Collect the explainer classes that can be selected by explainer type.

    The explainer modules are imported inside the function body, i.e., only when this function is
    called (presumably to avoid circular imports at module load time — confirm).

    Returns:
        A dictionary mapping each explainer type key (``"tabular"``, ``"tree"``, ``"tabpfn"``,
        ``"game"``) to the corresponding explainer class.

    """
    from shapiq.explainer.agnostic import AgnosticExplainer
    from shapiq.explainer.tabpfn import TabPFNExplainer
    from shapiq.explainer.tabular import TabularExplainer
    from shapiq.explainer.tree.explainer import TreeExplainer

    registry: dict[ExplainerTypes, type[Explainer]] = {}
    registry["tabular"] = TabularExplainer
    registry["tree"] = TreeExplainer
    registry["tabpfn"] = TabPFNExplainer
    registry["game"] = AgnosticExplainer
    return registry
46

47

48
def get_predict_function_and_model_type(
    model: Model | Game | Callable[[np.ndarray], np.ndarray],
    model_class: str | None = None,
    class_index: int | None = None,
) -> tuple[
    Callable[[Model, np.ndarray], np.ndarray] | RuntimeError,
    ExplainerTypes,
]:
    """Get the predict function and model type for a given model.

    The prediction function is used in the explainer to predict the model's output for a given data
    point. The function has the following signature: ``predict_function(model, data)``.

    The model type and prediction function are inferred from the fully qualified class name of the
    model (``model_class``). If the class is not recognized, the function falls back to the model's
    ``predict_proba`` or ``predict`` method, or to ``TreeModel`` objects.

    Args:
        model: The model to explain. Can be any model object or callable function. We try to infer
            the model type from the model object.

        model_class: The fully qualified class name of the model as a string. If not provided, it
            will be inferred from the model object via ``print_class``.

        class_index: The class index of the model to explain. Defaults to ``None``, which will set
            the class index to ``1`` per default for classification models and is ignored for
            regression models.

    Returns:
        A tuple of the predict function and the model type. For games, the first element is a
        ``RuntimeError`` instance (not a function), because games cannot be used for prediction,
        and the model type is ``"game"``. For all other models, the returned function wraps the
        inferred prediction function. It selects the ``class_index`` column from 2-D outputs and
        passes 1-D outputs through unchanged.

    Raises:
        TypeError: If no prediction function can be determined for ``model``.
        ValueError: If the inferred model type is not one of ``ExplainerTypes``.

    """
    from shapiq.games.base import Game

    from .tree import TreeModel

    if model_class is None:
        model_class = print_class(model)

    _model_type = "tabular"  # default
    _predict_function = None

    # Games are not predictive models: signal this with an exception object instead of a function.
    if isinstance(model, Game) or model_class == "shapiq.games.base.Game":
        _predict_function = RuntimeError("Games cannot be used for prediction.")
        return _predict_function, "game"

    # Generic fallback for callables; framework-specific branches below may override it.
    if callable(model):
        _predict_function = predict_callable

    # sklearn
    if model_class in [
        "sklearn.tree.DecisionTreeRegressor",
        "sklearn.tree._classes.DecisionTreeRegressor",
        "sklearn.tree.DecisionTreeClassifier",
        "sklearn.tree._classes.DecisionTreeClassifier",
        "sklearn.ensemble.RandomForestClassifier",
        "sklearn.ensemble._forest.RandomForestClassifier",
        "sklearn.ensemble.ExtraTreesClassifier",
        "sklearn.ensemble._forest.ExtraTreesClassifier",
        "sklearn.ensemble.RandomForestRegressor",
        "sklearn.ensemble._forest.RandomForestRegressor",
        "sklearn.ensemble.ExtraTreesRegressor",
        "sklearn.ensemble._forest.ExtraTreesRegressor",
        "sklearn.ensemble.IsolationForest",
        "sklearn.ensemble._iforest.IsolationForest",
    ]:
        _model_type = "tree"

    # lightgbm
    if model_class in [
        "lightgbm.basic.Booster",
        "lightgbm.sklearn.LGBMRegressor",
        "lightgbm.sklearn.LGBMClassifier",
    ]:
        _model_type = "tree"

    # xgboost: a raw Booster needs its input wrapped in a DMatrix (see predict_xgboost)
    if model_class == "xgboost.core.Booster":
        _predict_function = predict_xgboost
    if model_class in [
        "xgboost.core.Booster",
        "xgboost.sklearn.XGBRegressor",
        "xgboost.sklearn.XGBClassifier",
    ]:
        _model_type = "tree"

    # pytorch
    if model_class in [
        "torch.nn.modules.container.Sequential",
        "torch.nn.modules.module.Module",
        "torch.nn.modules.container.ModuleList",
        "torch.nn.modules.container.ModuleDict",
    ]:
        _model_type = "tabular"
        _predict_function = predict_torch

    # tensorflow / keras (Keras 2 paths and Keras 3 ``keras.src`` paths). Keras models are callable,
    # so without a match here they would silently fall back to ``predict_callable``.
    if model_class in [
        "tensorflow.python.keras.engine.sequential.Sequential",
        "tensorflow.python.keras.engine.training.Model",
        "tensorflow.python.keras.engine.functional.Functional",
        "keras.engine.sequential.Sequential",
        "keras.engine.training.Model",
        "keras.engine.functional.Functional",
        "keras.src.models.sequential.Sequential",
        "keras.src.models.functional.Functional",
        "keras.src.models.model.Model",
    ]:
        _model_type = "tabular"
        _predict_function = predict_tensorflow

    # tabpfn
    if model_class in [
        "tabpfn.classifier.TabPFNClassifier",
        "tabpfn.regressor.TabPFNRegressor",
    ]:
        _model_type = "tabpfn"

    # default extraction (sklearn api)
    if _predict_function is None and hasattr(model, "predict_proba"):
        _predict_function = predict_proba
    elif _predict_function is None and hasattr(model, "predict"):
        _predict_function = predict
    # extraction for tree models
    elif isinstance(model, TreeModel):  # test scenario
        _predict_function = model.compute_empty_prediction
        _model_type = "tree"
    elif isinstance(model, list) and all(isinstance(m, TreeModel) for m in model):
        _predict_function = model[0].compute_empty_prediction
        _model_type = "tree"
    elif _predict_function is None:
        msg = (
            f"`model` is of unsupported type: {model_class}.\n"
            "Please, raise a new issue at https://github.com/mmschlk/shapiq/issues if you want this model type\n"
            "to be handled automatically by shapiq.Explainer. Otherwise, use one of the supported explainers:\n"
            f"{', '.join(print_classes_nicely(get_explainers()))}"
        )
        raise TypeError(msg)

    if class_index is None:
        class_index = 1

    def _predict_function_with_class_index(model: Model, data: np.ndarray) -> np.ndarray:
        """A wrapper prediction function to retrieve class_index predictions for classifiers.

        Note:
            Regression models are not affected by this function.

        Args:
            model: The model to predict with.
            data: The data to predict on.

        Returns:
            The model's prediction for the given data point as a vector.

        """
        predictions = _predict_function(model, data)
        if predictions.ndim == 1:
            return predictions
        # a single output column is returned as-is, regardless of class_index
        if predictions.shape[1] == 1:
            return predictions[:, 0]
        return predictions[:, class_index]

    # validate model type before returning
    if _model_type not in list(get_args(ExplainerTypes)):
        msg = f"Model type {_model_type} is not supported."
        raise ValueError(msg)
    _model_type_literal = cast(ExplainerTypes, _model_type)

    return _predict_function_with_class_index, _model_type_literal
211

212

213
def predict_callable(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict by invoking a callable model directly on the data.

    Args:
        model: A callable model, e.g. a plain function or an object implementing ``__call__``.
        data: The input passed unchanged to ``model``.

    Returns:
        Whatever ``model(data)`` returns.
    """
    output = model(data)
    return output
216

217

218
def predict(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict through the model's ``predict`` method (sklearn-style API).

    Args:
        model: A model object exposing a ``predict`` method.
        data: The input passed unchanged to ``model.predict``.

    Returns:
        The return value of ``model.predict(data)``.
    """
    predict_method = model.predict
    return predict_method(data)
221

222

223
def predict_proba(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict through the model's ``predict_proba`` method (sklearn-style classifier API).

    Args:
        model: A model object exposing a ``predict_proba`` method.
        data: The input passed unchanged to ``model.predict_proba``.

    Returns:
        The return value of ``model.predict_proba(data)``.
    """
    proba_method = model.predict_proba
    return proba_method(data)
226

227

228
def predict_xgboost(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a raw XGBoost model, wrapping the input in an ``xgboost.DMatrix``.

    ``xgboost`` is imported only when this function is called, so it is not required unless an
    XGBoost model is actually explained.

    Args:
        model: An XGBoost model whose ``predict`` method accepts a ``DMatrix``.
        data: The input data converted to a ``DMatrix`` before prediction.

    Returns:
        The return value of ``model.predict`` on the wrapped data.
    """
    import xgboost

    dmatrix = xgboost.DMatrix(data)
    return model.predict(dmatrix)
233

234

235
def predict_tensorflow(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a TensorFlow/Keras model, suppressing its progress output.

    Args:
        model: A model exposing ``predict(data, verbose=...)``.
        data: The input passed to ``model.predict``.

    Returns:
        The return value of ``model.predict(data, verbose=0)``.
    """
    quiet_kwargs = {"verbose": 0}  # verbose=0 disables the per-call progress output
    return model.predict(data, **quiet_kwargs)
238

239

240
def predict_torch(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a PyTorch module on a NumPy array.

    The array is converted to a ``float32`` tensor, passed through ``model``, and the output is
    detached from the autograd graph and converted back to a NumPy array.

    Args:
        model: A callable PyTorch module.
        data: The NumPy input converted with ``torch.from_numpy``.

    Returns:
        The model output as a NumPy array.
    """
    import torch

    inputs = torch.from_numpy(data).float()
    outputs = model(inputs)
    return outputs.detach().numpy()
245

246

247
def print_classes_nicely(obj: list[Any] | dict[str, Any]) -> list[str] | None:
    """Converts a collection of classes into *user-readable* class names.

    Each name is shortened to ``<first component>.<last component>`` of the fully qualified name
    returned by ``print_class``. Names without any dot are returned unchanged.

    I/O examples:
        - ``[shapiq.explainer._base.Explainer]`` -> ``['shapiq.Explainer']``
        - ``{'tree': shapiq.explainer.tree.explainer.TreeExplainer}`` -> ``['shapiq.TreeExplainer']``
        - ``{'tree': shapiq.TreeExplainer}`` -> ``['shapiq.TreeExplainer']``

    Args:
        obj: The objects as a list or dictionary to convert. Can be a list of classes or class
            instances, or a dictionary whose values are classes or class instances (keys are
            ignored).

    Returns:
        The user-readable class names as a list. If the input is not a list or dictionary, returns
            ``None``.

    """
    if isinstance(obj, dict):
        items = obj.values()
    elif isinstance(obj, list):
        items = obj
    else:
        return None
    return [_short_class_name(item) for item in items]


def _short_class_name(obj: object) -> str:
    """Shorten ``print_class(obj)`` to its first and last dot-separated components."""
    full_name = print_class(obj)
    head, sep, _ = full_name.partition(".")
    if not sep:  # e.g. builtins such as ``int``: avoid producing ``int.int``
        return full_name
    return f"{head}.{full_name.rpartition('.')[2]}"
269

270

271
def print_class(obj: object) -> str:
    """Converts a class or class type into a *user-readable* class name.

    I/O Examples:
        - ``sklearn.ensemble._forest.RandomForestRegressor`` -> ``'sklearn.ensemble._forest.RandomForestRegressor'``
        - ``type(sklearn.ensemble._forest.RandomForestRegressor)`` -> ``'sklearn.ensemble._forest.RandomForestRegressor'``
        - ``shapiq.explainer.tree.explainer.TreeExplainer`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
        - ``shapiq.TreeExplainer`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
        - ``type(shapiq.TreeExplainer)`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``

    Args:
        obj: The object to convert. Can be a class or a class instance; for an instance, the name
            of its class is returned.

    Returns:
        The user-readable class name, i.e., the fully qualified class name (without a module
        prefix for builtins).

    """
    cls = obj if isinstance(obj, type) else type(obj)
    match = re.search("(?<=<class ').*(?='>)", str(cls))
    if match is not None:
        return match[0]
    # Some metaclasses customize the class repr so it does not follow the ``<class '...'>`` pattern
    # (e.g. ``enum.Enum`` subclasses render as ``<enum 'Name'>``). Fall back to the same
    # module/qualname composition the default ``type.__repr__`` uses.
    module = cls.__module__
    if module == "builtins":
        return cls.__qualname__
    return f"{module}.{cls.__qualname__}"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc