• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18540698301

15 Oct 2025 07:44PM UTC coverage: 92.615% (-0.2%) from 92.799%
18540698301

Pull #431

github

web-flow
Merge 606a077aa into 193f1a9ef
Pull Request #431: Product kernel explainer

185 of 210 new or added lines in 13 files covered. (88.1%)

4 existing lines in 2 files now uncovered.

5167 of 5579 relevant lines covered (92.62%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.44
/src/shapiq/explainer/utils.py
1
"""This module contains utility functions for the explainer module."""
2

3
from __future__ import annotations
1✔
4

5
import re
1✔
6
from typing import TYPE_CHECKING, Literal, cast, get_args
1✔
7

8
if TYPE_CHECKING:
1✔
9
    from collections.abc import Callable
×
10
    from typing import Any
×
11

12
    import numpy as np
×
13

14
    from shapiq.explainer.base import Explainer
×
15
    from shapiq.game import Game
×
16
    from shapiq.typing import Model
×
17

18

19
# Message shown when a model is explained without an explicit ``class_index``; the
# explainer then falls back to class index 1 for classifiers.
WARNING_NO_CLASS_INDEX = (
    "No class_index provided. Explaining the 2nd '1' class for classification models. "
    "Please provide the class_index to explain a different class. "
    "Disregard this warning for regression models."
)

# Names of the explainer families: the keys returned by ``get_explainers`` and the model
# types that ``get_predict_function_and_model_type`` may return.
ExplainerTypes = Literal[
    "tabular",
    "tree",
    "tabpfn",
    "game",
    "product_kernel",
]
27

28

29
def get_explainers() -> dict[ExplainerTypes, type[Explainer]]:
    """Return a mapping from explainer type name to the explainer class implementing it.

    The explainer modules are imported inside the function rather than at module level
    (presumably to avoid circular imports with the explainer package — TODO confirm).

    Returns:
        A dictionary of all available explainer classes, keyed by their ``ExplainerTypes`` name.

    """
    from shapiq.explainer.agnostic import AgnosticExplainer
    from shapiq.explainer.product_kernel.explainer import ProductKernelExplainer
    from shapiq.explainer.tabpfn import TabPFNExplainer
    from shapiq.explainer.tabular import TabularExplainer
    from shapiq.explainer.tree.explainer import TreeExplainer

    explainers: dict[ExplainerTypes, type[Explainer]] = {
        "tabular": TabularExplainer,
        "tree": TreeExplainer,
        "tabpfn": TabPFNExplainer,
        "game": AgnosticExplainer,
        "product_kernel": ProductKernelExplainer,
    }
    return explainers
49

50

51
def get_predict_function_and_model_type(
    model: Model | Game | Callable[[np.ndarray], np.ndarray],
    model_class: str | None = None,
    class_index: int | None = None,
) -> tuple[
    Callable[..., np.ndarray] | RuntimeError,
    ExplainerTypes,
]:
    """Get the predict function and model type for a given model.

    The prediction function is used in the explainer to predict the model's output for a given data
    point. The function has the following signature: ``predict_function(model, data)``.

    Model classes are matched by their fully qualified name. When ``model_class`` is inferred via
    :func:`print_class`, that name contains the module in which the class is *defined* (e.g.
    ``"sklearn.svm._classes.SVR"``), so supported classes are listed under both their public and
    their defining module path.

    Args:
        model: The model to explain. Can be any model object or callable function. We try to infer
            the model type from the model object.

        model_class: The class of the model as a string. If not provided, it will be inferred from
            the model object.

        class_index: The class index of the model to explain. Defaults to ``None``, which will set
            the class index to ``1`` per default for classification models and is ignored for
            regression models.

    Returns:
        A tuple of the predict function and the model type. For games, the first element is a
        ``RuntimeError`` instance (returned, not raised) because games cannot be used for
        prediction. Otherwise it is a wrapper ``f(model, data)`` that returns 1-D outputs
        unchanged, returns column ``0`` of single-column 2-D outputs, and returns column
        ``class_index`` of multi-column 2-D outputs.

    Raises:
        TypeError: If no prediction function can be determined for ``model``.
        ValueError: If the inferred model type is not one of ``ExplainerTypes``.

    """
    from shapiq.game import Game

    from .tree import TreeModel

    if model_class is None:
        model_class = print_class(model)

    _model_type = "tabular"  # default
    _predict_function: Any = None

    # NOTE(review): ``Game`` is imported from ``shapiq.game`` but the string check below uses
    # ``shapiq.games.base.Game`` — confirm that path is still valid for explicit ``model_class``.
    if isinstance(model, Game) or model_class == "shapiq.games.base.Game":
        _predict_function = RuntimeError("Games cannot be used for prediction.")
        return _predict_function, "game"

    if callable(model):
        _predict_function = predict_callable

    # sklearn
    if model_class in [
        "sklearn.tree.DecisionTreeRegressor",
        "sklearn.tree._classes.DecisionTreeRegressor",
        "sklearn.tree.DecisionTreeClassifier",
        "sklearn.tree._classes.DecisionTreeClassifier",
        "sklearn.ensemble.RandomForestClassifier",
        "sklearn.ensemble._forest.RandomForestClassifier",
        "sklearn.ensemble.ExtraTreesClassifier",
        "sklearn.ensemble._forest.ExtraTreesClassifier",
        "sklearn.ensemble.RandomForestRegressor",
        "sklearn.ensemble._forest.RandomForestRegressor",
        "sklearn.ensemble.ExtraTreesRegressor",
        "sklearn.ensemble._forest.ExtraTreesRegressor",
        "sklearn.ensemble.IsolationForest",
        "sklearn.ensemble._iforest.IsolationForest",
    ]:
        _model_type = "tree"

    # lightgbm
    if model_class in [
        "lightgbm.basic.Booster",
        "lightgbm.sklearn.LGBMRegressor",
        "lightgbm.sklearn.LGBMClassifier",
    ]:
        _model_type = "tree"

    # xgboost
    if model_class == "xgboost.core.Booster":
        _predict_function = predict_xgboost
    if model_class in [
        "xgboost.core.Booster",
        "xgboost.sklearn.XGBRegressor",
        "xgboost.sklearn.XGBClassifier",
    ]:
        _model_type = "tree"

    # pytorch
    if model_class in [
        "torch.nn.modules.container.Sequential",
        "torch.nn.modules.module.Module",
        "torch.nn.modules.container.ModuleList",
        "torch.nn.modules.container.ModuleDict",
    ]:
        _model_type = "tabular"
        _predict_function = predict_torch

    # tensorflow
    if model_class in [
        "tensorflow.python.keras.engine.sequential.Sequential",
        "tensorflow.python.keras.engine.training.Model",
        "tensorflow.python.keras.engine.functional.Functional",
        "keras.engine.sequential.Sequential",
        "keras.engine.training.Model",
        "keras.engine.functional.Functional",
        "keras.src.models.sequential.Sequential",
    ]:
        _model_type = "tabular"
        _predict_function = predict_tensorflow

    # tabpfn
    if model_class in [
        "tabpfn.classifier.TabPFNClassifier",
        "tabpfn.regressor.TabPFNRegressor",
    ]:
        _model_type = "tabpfn"

    # product kernel models (sklearn); ``print_class`` yields the private defining module,
    # so both spellings are needed for automatic inference to reach this branch
    if model_class in [
        "sklearn.svm.SVR",
        "sklearn.svm._classes.SVR",
        "sklearn.svm.SVC",
        "sklearn.svm._classes.SVC",
        "sklearn.gaussian_process.GaussianProcessRegressor",
        "sklearn.gaussian_process._gpr.GaussianProcessRegressor",
    ]:
        _model_type = "product_kernel"

    # default extraction (sklearn api)
    if _predict_function is None and hasattr(model, "predict_proba"):
        _predict_function = predict_proba
    elif _predict_function is None and hasattr(model, "predict"):
        _predict_function = predict
    # extraction for tree models
    elif isinstance(model, TreeModel):  # test scenario
        _predict_function = model.compute_empty_prediction
        _model_type = "tree"
    # a non-empty list of tree models; an empty list falls through to the TypeError below
    # instead of failing with an IndexError on ``model[0]``
    elif isinstance(model, list) and model and all(isinstance(m, TreeModel) for m in model):
        _predict_function = model[0].compute_empty_prediction
        _model_type = "tree"
    elif _predict_function is None:
        msg = (
            f"`model` is of unsupported type: {model_class}.\n"
            "Please, raise a new issue at https://github.com/mmschlk/shapiq/issues if you want this model type\n"
            "to be handled automatically by shapiq.Explainer. Otherwise, use one of the supported explainers:\n"
            f"{', '.join(print_classes_nicely(get_explainers()))}"
        )
        raise TypeError(msg)

    if class_index is None:
        class_index = 1

    def _predict_function_with_class_index(model: Model, data: np.ndarray) -> np.ndarray:
        """A wrapper prediction function to retrieve class_index predictions for classifiers.

        Note:
            Regression models are not affected by this function.

        Args:
            model: The model to predict with.
            data: The data to predict on.

        Returns:
            The model's prediction for the given data point as a vector.

        """
        # NOTE(review): assumes ``_predict_function`` accepts ``(model, data)``. For ``TreeModel``
        # inputs it is the bound ``compute_empty_prediction`` method, whose signature is not
        # visible here — confirm it is never called through this wrapper.
        predictions = _predict_function(model, data)  # pyright: ignore[reportCallIssue]
        if predictions.ndim == 1:
            return predictions
        if predictions.shape[1] == 1:
            # single-output 2-D prediction: flatten to a vector
            return predictions[:, 0]
        return predictions[:, class_index]

    # validate model type before returning
    if _model_type not in get_args(ExplainerTypes):
        msg = f"Model type {_model_type} is not supported."
        raise ValueError(msg)
    _model_type_literal = cast(ExplainerTypes, _model_type)

    return _predict_function_with_class_index, _model_type_literal
221

222

223
def predict_callable(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict by calling the model object directly.

    Args:
        model: A callable model (e.g. a plain function) mapping inputs to predictions.
        data: The input passed as the single positional argument.

    Returns:
        Whatever ``model(data)`` returns.

    """
    prediction = model(data)
    return prediction
226

227

228
def predict(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict through the model's ``predict`` method.

    Args:
        model: Any object exposing ``predict(data)`` (e.g. a scikit-learn estimator).
        data: The input forwarded unchanged to ``predict``.

    Returns:
        The value returned by ``model.predict(data)``.

    """
    predict_method = model.predict
    return predict_method(data)
231

232

233
def predict_proba(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict through the model's ``predict_proba`` method.

    Args:
        model: Any object exposing ``predict_proba(data)`` (e.g. a scikit-learn classifier).
        data: The input forwarded unchanged to ``predict_proba``.

    Returns:
        The value returned by ``model.predict_proba(data)``.

    """
    proba_method = model.predict_proba
    return proba_method(data)
236

237

238
def predict_xgboost(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a native XGBoost model.

    The input is wrapped in an ``xgboost.DMatrix`` before being passed to ``model.predict``.
    ``xgboost`` is imported lazily, so it is only required when this function is called.

    Args:
        model: An XGBoost model exposing ``predict`` that accepts a ``DMatrix``.
        data: The input to wrap in a ``DMatrix``.

    Returns:
        The value returned by ``model.predict``.

    """
    from xgboost import DMatrix

    predict_method = model.predict
    return predict_method(DMatrix(data))
243

244

245
def predict_tensorflow(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a TensorFlow/Keras model.

    ``verbose=0`` is passed to ``model.predict`` (in Keras this disables progress output).

    Args:
        model: A model exposing ``predict(data, verbose=...)``.
        data: The input forwarded unchanged to ``predict``.

    Returns:
        The value returned by ``model.predict``.

    """
    predictions = model.predict(data, verbose=0)
    return predictions
248

249

250
def predict_torch(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a PyTorch model.

    The NumPy input is converted with ``torch.from_numpy`` and cast with ``.float()``; the model
    output is detached from the autograd graph and converted back with ``.numpy()``.
    NOTE(review): ``.numpy()`` requires a CPU tensor, so this presumably assumes the model runs on
    the CPU — confirm.

    Args:
        model: A callable PyTorch module.
        data: The NumPy input array.

    Returns:
        The model output as a NumPy array.

    """
    import torch

    inputs = torch.from_numpy(data).float()
    outputs = model(inputs)
    return outputs.detach().numpy()
255

256

257
def print_classes_nicely(obj: list[Any] | dict[ExplainerTypes, Any]) -> list[str]:
    """Converts a collection of classes into *user-readable* class names.

    Each name is shortened to its top-level package and its class name, dropping the modules in
    between. Names without a module path (e.g. builtins such as ``int``) are returned unchanged.

    I/O examples:
        - ``[shapiq.explainer._base.Explainer]`` -> ``['shapiq.Explainer']``
        - ``{'tree': shapiq.explainer.tree.explainer.TreeExplainer}`` -> ``['shapiq.TreeExplainer']``
        - ``{'tree': shapiq.TreeExplainer}`` -> ``['shapiq.TreeExplainer']``
        - ``[int]`` -> ``['int']``

    Args:
        obj: The objects to convert, either as a list or as a dictionary (only the values are
            used). Each entry can be a class or an instance of a class.

    Returns:
        The user-readable class names as a list, in the iteration order of ``obj``. If ``obj`` is
        neither a list nor a dictionary, an empty list is returned.

    Raises:
        ValueError: If the class name of an entry cannot be determined (see :func:`print_class`).

    """
    if isinstance(obj, dict):
        entries = obj.values()
    elif isinstance(obj, list):
        entries = obj
    else:
        return []
    return [_shorten_class_name(print_class(entry)) for entry in entries]


def _shorten_class_name(name: str) -> str:
    """Keep only the first and the last component of a dotted class name."""
    parts = name.split(".")
    if len(parts) == 1:
        # no module path (e.g. builtins): avoid producing names like ``"int.int"``
        return name
    return f"{parts[0]}.{parts[-1]}"


# Extracts the dotted name from a class representation such as ``<class 'pkg.mod.Name'>``.
_CLASS_REPR_PATTERN = re.compile(r"(?<=<class ').*(?='>)")


def print_class(obj: object) -> str:
    """Converts a class or class type into a *user-readable* class name.

    I/O Examples:
        - ``sklearn.ensemble._forest.RandomForestRegressor`` -> ``'sklearn.ensemble._forest.RandomForestRegressor'``
        - ``type(sklearn.ensemble._forest.RandomForestRegressor)`` -> ``'sklearn.ensemble._forest.RandomForestRegressor'``
        - ``shapiq.explainer.tree.explainer.TreeExplainer`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
        - ``shapiq.TreeExplainer`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
        - ``type(shapiq.TreeExplainer)`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``

    Args:
        obj: The object to convert. Can be a class or an instance of a class; for an instance,
            the name of its class is returned.

    Returns:
        The user-readable class name, i.e. the name inside ``<class '...'>`` of the class's
        string representation.

    Raises:
        ValueError: If the class name cannot be determined.

    """
    cls = obj if isinstance(obj, type) else type(obj)
    match = _CLASS_REPR_PATTERN.search(str(cls))
    if not match:
        # TODO(advueu963): Might even want to just ignore it here # noqa: TD003
        # the message is built only on failure, so ``str(obj)`` (potentially expensive for
        # large model objects) is not evaluated on the common success path
        msg = f"Could not determine class name for object: {obj}"
        raise ValueError(msg)
    return match[0]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc