• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 17895082095

21 Sep 2025 02:45PM UTC coverage: 93.618% (-0.2%) from 93.845%
17895082095

Pull #431

github

web-flow
Merge b4b2bdc5c into dede390c9
Pull Request #431: Product kernel explainer

186 of 211 new or added lines in 14 files covered. (88.15%)

4 existing lines in 2 files now uncovered.

5105 of 5453 relevant lines covered (93.62%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.0
/src/shapiq/explainer/utils.py
1
"""This module contains utility functions for the explainer module."""
2

3
from __future__ import annotations
1✔
4

5
import re
1✔
6
from typing import TYPE_CHECKING, Literal, cast, get_args
1✔
7

8
if TYPE_CHECKING:
1✔
9
    from collections.abc import Callable
×
10
    from typing import Any
×
11

12
    import numpy as np
×
13

14
    from shapiq.explainer.base import Explainer
×
15
    from shapiq.game import Game
×
16
    from shapiq.typing import Model
×
17

18
# Message text for the case where no ``class_index`` is provided; the prediction wrapper in
# ``get_predict_function_and_model_type`` then falls back to class index ``1``.
WARNING_NO_CLASS_INDEX = "".join(
    (
        "No class_index provided. ",
        "Explaining the 2nd '1' class for classification models. ",
        "Please provide the class_index to explain a different class. ",
        "Disregard this warning for regression models.",
    )
)

# Identifiers of every explainer type that can be selected automatically.
ExplainerTypes = Literal[
    "tabular",
    "tree",
    "tabpfn",
    "game",
    "product_kernel",
]
1✔
26

27

28
def get_explainers() -> dict[ExplainerTypes, type[Explainer]]:
    """Map each explainer type identifier to its explainer class.

    The explainer modules are imported inside the function body rather than at module level
    (presumably to avoid import cycles with the explainer package — TODO confirm).

    Returns:
        A dictionary from explainer type (``"tabular"``, ``"tree"``, ``"tabpfn"``, ``"game"``,
        ``"product_kernel"``) to the corresponding explainer class.

    """
    import shapiq.explainer.agnostic as agnostic_module
    import shapiq.explainer.product_kernel.explainer as product_kernel_module
    import shapiq.explainer.tabpfn as tabpfn_module
    import shapiq.explainer.tabular as tabular_module
    import shapiq.explainer.tree.explainer as tree_module

    registry: dict[ExplainerTypes, type[Explainer]] = {}
    registry["tabular"] = tabular_module.TabularExplainer
    registry["tree"] = tree_module.TreeExplainer
    registry["tabpfn"] = tabpfn_module.TabPFNExplainer
    registry["game"] = agnostic_module.AgnosticExplainer
    registry["product_kernel"] = product_kernel_module.ProductKernelExplainer
    return registry
    }
48

49

50
def get_predict_function_and_model_type(
    model: Model | Game | Callable[[np.ndarray], np.ndarray],
    model_class: str | None = None,
    class_index: int | None = None,
) -> tuple[
    Callable[[Model, np.ndarray], np.ndarray] | RuntimeError,
    ExplainerTypes,
]:
    """Get the predict function and model type for a given model.

    The prediction function is used in the explainer to predict the model's output for a given data
    point. The function has the following signature: ``predict_function(model, data)``.

    Args:
        model: The model to explain. Can be any model object or callable function. We try to infer
            the model type from the model object.

        model_class: The class of the model as a string. If not provided, it will be inferred from
            the model object via ``print_class``.

        class_index: The class index of the model to explain. Defaults to ``None``, which will set
            the class index to ``1`` per default for classification models and is ignored for
            regression models.

    Returns:
        A tuple of the predict function and the model type. For games, the first element is a
        ``RuntimeError`` instance (returned, not raised) since games cannot be used for
        prediction.

    Raises:
        TypeError: If no prediction function can be determined for ``model``.
        ValueError: If the inferred model type is not one of ``ExplainerTypes``.

    """
    from shapiq.game import Game

    from .tree import TreeModel

    if model_class is None:
        model_class = print_class(model)

    _model_type = "tabular"  # default
    _predict_function = None

    # Games are not predictive models. The legacy string is kept for callers that pass
    # ``model_class`` explicitly; ``print_class(Game)`` matches the class actually imported above.
    if isinstance(model, Game) or model_class in ("shapiq.games.base.Game", print_class(Game)):
        _predict_function = RuntimeError("Games cannot be used for prediction.")
        return _predict_function, "game"

    if callable(model):
        _predict_function = predict_callable

    # sklearn
    if model_class in [
        "sklearn.tree.DecisionTreeRegressor",
        "sklearn.tree._classes.DecisionTreeRegressor",
        "sklearn.tree.DecisionTreeClassifier",
        "sklearn.tree._classes.DecisionTreeClassifier",
        "sklearn.ensemble.RandomForestClassifier",
        "sklearn.ensemble._forest.RandomForestClassifier",
        "sklearn.ensemble.ExtraTreesClassifier",
        "sklearn.ensemble._forest.ExtraTreesClassifier",
        "sklearn.ensemble.RandomForestRegressor",
        "sklearn.ensemble._forest.RandomForestRegressor",
        "sklearn.ensemble.ExtraTreesRegressor",
        "sklearn.ensemble._forest.ExtraTreesRegressor",
        "sklearn.ensemble.IsolationForest",
        "sklearn.ensemble._iforest.IsolationForest",
    ]:
        _model_type = "tree"

    # lightgbm
    if model_class in [
        "lightgbm.basic.Booster",
        "lightgbm.sklearn.LGBMRegressor",
        "lightgbm.sklearn.LGBMClassifier",
    ]:
        _model_type = "tree"

    # xgboost: the native Booster is routed through ``predict_xgboost`` (wraps data in a DMatrix)
    if model_class == "xgboost.core.Booster":
        _predict_function = predict_xgboost
    if model_class in [
        "xgboost.core.Booster",
        "xgboost.sklearn.XGBRegressor",
        "xgboost.sklearn.XGBClassifier",
    ]:
        _model_type = "tree"

    # pytorch
    if model_class in [
        "torch.nn.modules.container.Sequential",
        "torch.nn.modules.module.Module",
        "torch.nn.modules.container.ModuleList",
        "torch.nn.modules.container.ModuleDict",
    ]:
        _model_type = "tabular"
        _predict_function = predict_torch

    # tensorflow
    if model_class in [
        "tensorflow.python.keras.engine.sequential.Sequential",
        "tensorflow.python.keras.engine.training.Model",
        "tensorflow.python.keras.engine.functional.Functional",
        "keras.engine.sequential.Sequential",
        "keras.engine.training.Model",
        "keras.engine.functional.Functional",
        "keras.src.models.sequential.Sequential",
    ]:
        _model_type = "tabular"
        _predict_function = predict_tensorflow

    # tabpfn
    if model_class in [
        "tabpfn.classifier.TabPFNClassifier",
        "tabpfn.regressor.TabPFNRegressor",
    ]:
        _model_type = "tabpfn"

    # product-kernel models. An inferred ``model_class`` comes from ``print_class`` and therefore
    # names the defining (private) module, so both the public and the private paths are listed,
    # mirroring the sklearn tree list above.
    if model_class in [
        "sklearn.svm.SVR",
        "sklearn.svm._classes.SVR",
        "sklearn.svm.SVC",
        "sklearn.svm._classes.SVC",
        "sklearn.gaussian_process.GaussianProcessRegressor",
        "sklearn.gaussian_process._gpr.GaussianProcessRegressor",
    ]:
        _model_type = "product_kernel"

    # default extraction (sklearn api)
    if _predict_function is None and hasattr(model, "predict_proba"):
        _predict_function = predict_proba
    elif _predict_function is None and hasattr(model, "predict"):
        _predict_function = predict
    # extraction for tree models
    elif isinstance(model, TreeModel):  # test scenario
        _predict_function = model.compute_empty_prediction
        _model_type = "tree"
    # ``model`` must be non-empty: ``all`` is vacuously True for ``[]`` and ``model[0]`` would
    # raise an IndexError instead of the informative TypeError below.
    elif isinstance(model, list) and model and all(isinstance(m, TreeModel) for m in model):
        _predict_function = model[0].compute_empty_prediction
        _model_type = "tree"
    elif _predict_function is None:
        msg = (
            f"`model` is of unsupported type: {model_class}.\n"
            "Please, raise a new issue at https://github.com/mmschlk/shapiq/issues if you want this model type\n"
            "to be handled automatically by shapiq.Explainer. Otherwise, use one of the supported explainers:\n"
            f"{', '.join(print_classes_nicely(get_explainers()))}"
        )
        raise TypeError(msg)

    if class_index is None:
        class_index = 1

    def _predict_function_with_class_index(model: Model, data: np.ndarray) -> np.ndarray:
        """A wrapper prediction function to retrieve class_index predictions for classifiers.

        Note:
            Regression models are not affected by this function.

        Args:
            model: The model to predict with.
            data: The data to predict on.

        Returns:
            The model's prediction for the given data point as a vector.

        """
        predictions = _predict_function(model, data)
        # 1-D output (e.g. regression) is passed through unchanged.
        if predictions.ndim == 1:
            return predictions
        # A single output column is flattened rather than indexed with ``class_index``.
        if predictions.shape[1] == 1:
            return predictions[:, 0]
        return predictions[:, class_index]

    # validate model type before returning
    if _model_type not in get_args(ExplainerTypes):
        msg = f"Model type {_model_type} is not supported."
        raise ValueError(msg)
    _model_type_literal = cast(ExplainerTypes, _model_type)

    return _predict_function_with_class_index, _model_type_literal
1✔
220

221

222
def predict_callable(model: Model, data: np.ndarray) -> np.ndarray:
    """Invoke ``model`` directly on ``data`` and return its output unchanged."""
    output = model(data)
    return output
1✔
225

226

227
def predict(model: Model, data: np.ndarray) -> np.ndarray:
    """Return the output of ``model.predict(data)`` (scikit-learn style estimator API)."""
    predict_method = model.predict
    return predict_method(data)
1✔
230

231

232
def predict_proba(model: Model, data: np.ndarray) -> np.ndarray:
    """Return the output of ``model.predict_proba(data)`` (scikit-learn style classifier API)."""
    proba_method = model.predict_proba
    return proba_method(data)
1✔
235

236

237
def predict_xgboost(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a native XGBoost model.

    The raw input is wrapped in an ``xgboost.DMatrix`` before ``model.predict`` is called.
    ``xgboost`` is imported only when this function runs.
    """
    import xgboost

    dmatrix = xgboost.DMatrix(data)
    return model.predict(dmatrix)
×
242

243

244
def predict_tensorflow(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a TensorFlow/Keras model, passing ``verbose=0`` to ``model.predict``."""
    predictions = model.predict(data, verbose=0)  # verbose=0: silence Keras progress output
    return predictions
×
247

248

249
def predict_torch(model: Model, data: np.ndarray) -> np.ndarray:
    """Predict with a PyTorch model.

    ``data`` is converted to a float32 tensor, passed through ``model``, and the result is
    detached from the autograd graph and converted back to a NumPy array.
    """
    import torch

    inputs = torch.from_numpy(data).float()
    outputs = model(inputs)
    return outputs.detach().numpy()
1✔
254

255

256
def print_classes_nicely(obj: list[Any] | dict[str, Any]) -> list[str] | None:
    """Converts a collection of classes into *user-readable* class names.

    Each full class path from ``print_class`` is shortened to ``<top-level package>.<class name>``.
    Names without a module prefix (e.g. builtins such as ``int``) are returned unchanged.

    I/O examples:
        - ``[shapiq.explainer._base.Explainer]`` -> ``['shapiq.Explainer']``
        - ``{'tree': shapiq.explainer.tree.explainer.TreeExplainer}`` -> ``['shapiq.TreeExplainer']``
        - ``{'tree': shapiq.TreeExplainer}`` -> ``['shapiq.TreeExplainer']``

    Args:
        obj: The classes (or class instances) to convert, given either as a list or as a
            dictionary whose values are converted.

    Returns:
        The user-readable class names as a list, in iteration order. If ``obj`` is neither a list
        nor a dictionary, returns ``None``.

    """
    if isinstance(obj, dict):
        classes = list(obj.values())
    elif isinstance(obj, list):
        classes = obj
    else:
        return None

    def _shorten(cls: object) -> str:
        # Keep only the top-level package and the class name; a dot-less name is returned as-is
        # instead of being doubled (e.g. ``int`` rather than ``int.int``).
        parts = print_class(cls).split(".")
        return parts[0] if len(parts) == 1 else f"{parts[0]}.{parts[-1]}"

    return [_shorten(cls) for cls in classes]
×
278

279

280
def print_class(obj: object) -> str:
    """Converts a class or class type into a *user-readable* class name.

    I/O Examples:
        - ``sklearn.ensemble._forest.RandomForestRegressor`` -> ``'sklearn.ensemble._forest.RandomForestRegressor'``
        - ``type(sklearn.ensemble._forest.RandomForestRegressor)`` -> ``'sklearn.ensemble._forest.RandomForestRegressor'``
        - ``shapiq.explainer.tree.explainer.TreeExplainer`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
        - ``shapiq.TreeExplainer`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``
        - ``type(shapiq.TreeExplainer)`` -> ``'shapiq.explainer.tree.explainer.TreeExplainer'``

    Args:
        obj: The object to convert. Can be a class or an instance of a class.

    Returns:
        The user-readable class name, i.e. the dotted path shown in the class's
        ``<class '...'>`` representation (builtins such as ``int`` carry no module prefix).

    """
    cls = obj if isinstance(obj, type) else type(obj)
    match = re.search("(?<=<class ').*(?='>)", str(cls))
    if match is not None:
        return match[0]
    # Classes whose metaclass overrides ``__repr__`` (e.g. ``enum.Enum`` subclasses render as
    # ``<enum 'Color'>``) do not match the pattern; indexing ``None`` would raise a TypeError.
    # Fall back to the same ``module.qualname`` form that ``type.__repr__`` produces.
    module = getattr(cls, "__module__", None)
    qualname = getattr(cls, "__qualname__", cls.__name__)
    if not isinstance(module, str) or module == "builtins":
        return qualname
    return f"{module}.{qualname}"
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc