mmschlk / shapiq / 18499875864

14 Oct 2025 02:26PM UTC coverage: 92.799% (-0.7%) from 93.522%

push

github

web-flow
Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed introduced bugs in game and interaction_values

* Pyright Save Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer. Approximator, Imputer and ExactComputer can either exist or not, depending on dynamic Type

* Fixed a bug caused by refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the step to reinstall Python on Windows due to issues with tkinter. The linked GitHub issue was solved. Doing this as a first try.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for Windows tests. The prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line
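
The percentages in this summary follow directly from the reported line counts. The snippet below is a minimal sketch that recomputes them; `coverage_percent` is a hypothetical helper written for illustration and is not part of Coveralls or shapiq.

```python
def coverage_percent(covered: int, relevant: int) -> float:
    """Return line coverage as a percentage of relevant lines."""
    if relevant <= 0:
        raise ValueError("relevant must be a positive line count")
    return 100.0 * covered / relevant


# Counts taken from the summary above.
new_lines = coverage_percent(304, 360)    # new or added lines
overall = coverage_percent(4987, 5374)    # all relevant lines
delta = overall - 93.522                  # change against the previous build

print(f"{new_lines:.2f}%")  # 84.44%
print(f"{overall:.3f}%")    # 92.799%
print(f"{delta:+.1f}%")     # -0.7%
```

The overall figure rounds to the 92.8% shown in the summary line and to the 92.799% shown in the build header.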

Source File

79.81
/src/shapiq/explainer/utils.py
1
"""This module contains utility functions for the explainer module."""
2

3
from __future__ import annotations
1✔
4

5
import re
1✔
6
from typing import TYPE_CHECKING, Literal, cast, get_args
1✔
7

8
if TYPE_CHECKING:
1✔
9
    from collections.abc import Callable
×
10
    from typing import Any
×
11

12
    import numpy as np
×
13

14
    from shapiq.explainer.base import Explainer
×
15
    from shapiq.game import Game
×
16
    from shapiq.typing import Model
×
17

18

19
WARNING_NO_CLASS_INDEX = (
1✔
20
    "No class_index provided. "
21
    "Explaining the 2nd '1' class for classification models. "
22
    "Please provide the class_index to explain a different class. "
23
    "Disregard this warning for regression models."
24
)
25

26
ExplainerTypes = Literal["tabular", "tree", "tabpfn", "game"]
1✔
27

28

29
def get_explainers() -> dict[ExplainerTypes, type[Explainer]]:
1✔
30
    """Return a dictionary of all available explainer classes.
31

32
    Returns:
33
        A dictionary of all available explainer classes.
34

35
    """
36
    import shapiq.explainer.agnostic as ag
1✔
37
    import shapiq.explainer.tabpfn as tp
1✔
38
    import shapiq.explainer.tabular as tb
1✔
39
    import shapiq.explainer.tree.explainer as tr
1✔
40

41
    return {
1✔
42
        "tabular": tb.TabularExplainer,
43
        "tree": tr.TreeExplainer,
44
        "tabpfn": tp.TabPFNExplainer,
45
        "game": ag.AgnosticExplainer,
46
    }
47

48

49
def get_predict_function_and_model_type(
1✔
50
    model: Model | Game | Callable[[np.ndarray], np.ndarray],
51
    model_class: str | None = None,
52
    class_index: int | None = None,
53
) -> tuple[
54
    Callable[..., np.ndarray] | RuntimeError,
55
    ExplainerTypes,
56
]:
57
    """Get the predict function and model type for a given model.
58

59
    The prediction function is used in the explainer to predict the model's output for a given data
60
    point. The function has the following signature: ``predict_function(model, data)``.
61

62
    Args:
63
        model: The model to explain. Can be any model object or callable function. We try to infer
64
            the model type from the model object.
65

66
        model_class: The class of the model as a string. If not provided, it will be inferred from
67
            the model object.
68

69
        class_index: The class index of the model to explain. Defaults to ``None``, which will set
70
            the class index to ``1`` per default for classification models and is ignored for
71
            regression models.
72

73
    Returns:
74
        A tuple of the predict function and the model type.
75

76
    """
77
    from shapiq.game import Game
1✔
78

79
    from .tree import TreeModel
1✔
80

81
    if model_class is None:
1✔
82
        model_class = print_class(model)
1✔
83

84
    _model_type = "tabular"  # default
1✔
85
    _predict_function: Any = None
1✔
86

87
    if isinstance(model, Game) or model_class == "shapiq.games.base.Game":
1✔
88
        _predict_function = RuntimeError("Games cannot be used for prediction.")
1✔
89
        return _predict_function, "game"
1✔
90

91
    if callable(model):
1✔
92
        _predict_function = predict_callable
1✔
93

94
    # sklearn
95
    if model_class in [
1✔
96
        "sklearn.tree.DecisionTreeRegressor",
97
        "sklearn.tree._classes.DecisionTreeRegressor",
98
        "sklearn.tree.DecisionTreeClassifier",
99
        "sklearn.tree._classes.DecisionTreeClassifier",
100
        "sklearn.ensemble.RandomForestClassifier",
101
        "sklearn.ensemble._forest.RandomForestClassifier",
102
        "sklearn.ensemble.ExtraTreesClassifier",
103
        "sklearn.ensemble._forest.ExtraTreesClassifier",
104
        "sklearn.ensemble.RandomForestRegressor",
105
        "sklearn.ensemble._forest.RandomForestRegressor",
106
        "sklearn.ensemble.ExtraTreesRegressor",
107
        "sklearn.ensemble._forest.ExtraTreesRegressor",
108
        "sklearn.ensemble.IsolationForest",
109
        "sklearn.ensemble._iforest.IsolationForest",
110
    ]:
111
        _model_type = "tree"
1✔
112

113
    # lightgbm
114
    if model_class in [
1✔
115
        "lightgbm.basic.Booster",
116
        "lightgbm.sklearn.LGBMRegressor",
117
        "lightgbm.sklearn.LGBMClassifier",
118
    ]:
119
        _model_type = "tree"
1✔
120

121
    # xgboost
122
    if model_class == "xgboost.core.Booster":
1✔
123
        _predict_function = predict_xgboost
×
124
    if model_class in [
1✔
125
        "xgboost.core.Booster",
126
        "xgboost.sklearn.XGBRegressor",
127
        "xgboost.sklearn.XGBClassifier",
128
    ]:
129
        _model_type = "tree"
1✔
130

131
    # pytorch
132
    if model_class in [
1✔
133
        "torch.nn.modules.container.Sequential",
134
        "torch.nn.modules.module.Module",
135
        "torch.nn.modules.container.ModuleList",
136
        "torch.nn.modules.container.ModuleDict",
137
    ]:
138
        _model_type = "tabular"
1✔
139
        _predict_function = predict_torch
1✔
140

141
    # tensorflow
142
    if model_class in [
1✔
143
        "tensorflow.python.keras.engine.sequential.Sequential",
144
        "tensorflow.python.keras.engine.training.Model",
145
        "tensorflow.python.keras.engine.functional.Functional",
146
        "keras.engine.sequential.Sequential",
147
        "keras.engine.training.Model",
148
        "keras.engine.functional.Functional",
149
        "keras.src.models.sequential.Sequential",
150
    ]:
151
        _model_type = "tabular"
×
152
        _predict_function = predict_tensorflow
×
153

154
    if model_class in [
1✔
155
        "tabpfn.classifier.TabPFNClassifier",
156
        "tabpfn.regressor.TabPFNRegressor",
157
    ]:
158
        _model_type = "tabpfn"
1✔
159

160
    # default extraction (sklearn api)
161
    if _predict_function is None and hasattr(model, "predict_proba"):
1✔
162
        _predict_function = predict_proba
1✔
163
    elif _predict_function is None and hasattr(model, "predict"):
1✔
164
        _predict_function = predict
1✔
165
    # extraction for tree models
166
    elif isinstance(model, TreeModel):  # test scenario
1✔
167
        _predict_function = model.compute_empty_prediction
1✔
168
        _model_type = "tree"
1✔
169
    elif isinstance(model, list) and all(isinstance(m, TreeModel) for m in model):
1✔
170
        _predict_function = model[0].compute_empty_prediction
×
171
        _model_type = "tree"
×
172
    elif _predict_function is None:
1✔
173
        msg = (
1✔
174
            f"`model` is of unsupported type: {model_class}.\n"
175
            "Please, raise a new issue at https://github.com/mmschlk/shapiq/issues if you want this model type\n"
176
            "to be handled automatically by shapiq.Explainer. Otherwise, use one of the supported explainers:\n"
177
            f"{', '.join(print_classes_nicely(get_explainers()))}"
178
        )
179
        raise TypeError(msg)
1✔
180

181
    if class_index is None:
1✔
182
        class_index = 1
1✔
183

184
    def _predict_function_with_class_index(model: Model, data: np.ndarray) -> np.ndarray:
1✔
185
        """A wrapper prediction function to retrieve class_index predictions for classifiers.
186

187
        Note:
188
            Regression models are not affected by this function.
189

190
        Args:
191
            model: The model to predict with.
192
            data: The data to predict on.
193

194
        Returns:
195
            The model's prediction for the given data point as a vector.
196

197
        """
198
        predictions = _predict_function(model, data)  # pyright: ignore[reportCallIssue] TODO: We assume here that the predict function takes at least two arguments, yet we also define it as None and () -> Float
1✔
199
        if predictions.ndim == 1:
1✔
200
            return predictions
1✔
201
        if predictions.shape[1] == 1:
1✔
202
            return predictions[:, 0]
1✔
203
        return predictions[:, class_index]
1✔
204

205
    # validate model type before returning
206
    if _model_type not in list(get_args(ExplainerTypes)):
1✔
207
        msg = f"Model type {_model_type} is not supported."
×
208
        raise ValueError(msg)
×
209
    _model_type_literal = cast(ExplainerTypes, _model_type)
1✔
210

211
    return _predict_function_with_class_index, _model_type_literal
1✔
212

213

214
def predict_callable(model: Model, data: np.ndarray) -> np.ndarray:
1✔
215
    """Makes predictions with a model that is callable."""
216
    return model(data)
1✔
217

218

219
def predict(model: Model, data: np.ndarray) -> np.ndarray:
1✔
220
    """Makes predictions with a model that has a ``predict`` method."""
221
    return model.predict(data)
1✔
222

223

224
def predict_proba(model: Model, data: np.ndarray) -> np.ndarray:
1✔
225
    """Makes predictions with a model that has a ``predict_proba`` method."""
226
    return model.predict_proba(data)
1✔
227

228

229
def predict_xgboost(model: Model, data: np.ndarray) -> np.ndarray:
1✔
230
    """Makes predictions with an XGBoost model."""
231
    from xgboost import DMatrix
×
232

233
    return model.predict(DMatrix(data))
×
234

235

236
def predict_tensorflow(model: Model, data: np.ndarray) -> np.ndarray:
1✔
237
    """Makes predictions with a TensorFlow model."""
238
    return model.predict(data, verbose=0)
×
239

240

241
def predict_torch(model: Model, data: np.ndarray) -> np.ndarray:
1✔
242
    """Makes predictions with a PyTorch model."""
243
    import torch
1✔
244

245
    return model(torch.from_numpy(data).float()).detach().numpy()
1✔
246

247

248
def print_classes_nicely(obj: list[Any] | dict[ExplainerTypes, Any]) -> list[str]:
1✔
249
    """Converts a collection of classes into *user-readable* class names.
250

251
    I/O examples:
252
        - ``[shapiq.explainer._base.Explainer]`` -> ``['shapiq.Explainer']``
253
        - ``{'tree': shapiq.explainer.tree.explainer.TreeExplainer}``  -> ``['shapiq.TreeExplainer']``
254
        - ``{'tree': shapiq.TreeExplainer}``  -> ``['shapiq.TreeExplainer']``.
255

256
    Args:
257
        obj: The objects as a list or dictionary to convert. Can be a class or a class type.
258
        Can be a list or dictionary of classes or class types.
259

260
    Returns:
261
        The user-readable class names as a list. If the input is not a list or dictionary, returns
262
            an empty list.
263

264
    """
265
    if isinstance(obj, dict):
1✔
266
        return [".".join([print_class(v).split(".")[i] for i in (0, -1)]) for _, v in obj.items()]
1✔
267
    if isinstance(obj, list):
×
268
        return [".".join([print_class(v).split(".")[i] for i in (0, -1)]) for v in obj]
×
NEW
269
    return []
×
270

271

272
def print_class(obj: object) -> str:
1✔
273
    """Converts a class or class type into a *user-readable* class name.
274

275
    I/O Examples:
276
        - ``sklearn.ensemble._forest.RandomForestRegressor`` -> ``'sklearn.ensemble._forest.RandomForestRegressor'``
277
        - ``type(sklearn.ensemble._forest.RandomForestRegressor)`` -> ``'sklearn.ensemble._forest.RandomForestRegressor'``
278
        - ``shapiq.explainer.tree.explainer.TreeExplainer`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
279
        - ``shapiq.TreeExplainer`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
280
        - ``type(shapiq.TreeExplainer)`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
281

282
    Args:
283
        obj: The object to convert. Can be a class or a class type.
284

285
    Returns:
286
        The user-readable class name.
287

288
    Raises:
289
        ValueError: If the class name cannot be determined.
290
    """
291
    msg = f"Could not determine class name for object: {obj}"
1✔
292
    # TODO(advueu963): Might even want to just ignore it here # noqa: TD003
293
    if isinstance(obj, type):
1✔
294
        search = re.search("(?<=<class ').*(?='>)", str(obj))
1✔
295
        if not search:
1✔
NEW
296
            raise ValueError(msg)
×
297
        return search[0]
1✔
298
    search = re.search("(?<=<class ').*(?='>)", str(type(obj)))
1✔
299
    if not search:
1✔
NEW
300
        raise ValueError(msg)
×
301
    return search[0]
1✔
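
The name extraction in `print_class` and the shortening done in `print_classes_nicely` can be tried in isolation on a standard-library class. What follows is a minimal sketch under stated assumptions, not the shapiq implementation: `qualified_name` and `short_name` are hypothetical helper names, and `email.message.Message` only stands in for a model or explainer class.

```python
import re
from email.message import Message

# Same idea as the listing: pull the dotted path out of "<class '...'>".
_CLASS_REPR = re.compile(r"(?<=<class ').*(?='>)")


def qualified_name(obj: object) -> str:
    """Return the dotted class path for a class or for an instance of it."""
    cls = obj if isinstance(obj, type) else type(obj)
    match = _CLASS_REPR.search(str(cls))
    if match is None:
        msg = f"Could not determine class name for object: {obj}"
        raise ValueError(msg)
    return match[0]


def short_name(obj: object) -> str:
    """Keep only the top-level package and the class name."""
    parts = qualified_name(obj).split(".")
    return ".".join([parts[0], parts[-1]])


full_from_class = qualified_name(Message)
full_from_instance = qualified_name(Message())
short = short_name(Message)

print(full_from_class)     # email.message.Message
print(full_from_instance)  # email.message.Message
print(short)               # email.Message
```

Resolving the class once and then applying a single regular expression avoids repeating the search in two branches. Whether that restructuring suits the real module is a judgment for its maintainers.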