mmschlk / shapiq — build 18499875864 (push, via GitHub web-flow)

14 Oct 2025 02:26PM UTC — coverage: 92.799% (-0.7%) from 93.522%

Enhance type safety and fix bugs across the codebase (#430)

* First Pyright cleanup

* TypeChecked game

* fixed introduced bugs in game and interaction_values

* Pyright Save Sampling

* TypeSafe Approximator

* Typechecked Datasets

* Explainer folder typechecked

* GameTheory Typechecked

* Imputer Typechecked

* Plot Typechecked

* Added static typechecking to pre-commit

* Refactoring

* Add pyright change to CHANGELOG

* Activate code quality show diff

* changed uv sync in pre-commit hook

* made fixtures local import

* Introduced Generic TypeVar in Approximator, reducing ignores

* Introduced Generic Types for Explainer: Approximator, Imputer and ExactComputer can either exist or not, depending on the dynamic type

* Fixed a bug caused by the refactoring

* updated overrides

* tightened CoalitionMatrix to accept only bool arrays

* Remove Python reinstallation step in CI workflow

Removed the Windows-specific step that reinstalled Python to work around tkinter issues; the linked GitHub issue has since been resolved. Trying this first.

* Add Python reinstallation and Tkinter installation steps

Reinstall Python and install Tkinter for the Windows tests; the prior commit did not help.

* Fix command for installing Tkinter in workflow

* Update Windows workflow to install Tkinter via Chocolatey

* Remove Tkinter installation step from Windows workflow and adjust matplotlib usage for headless environments

* adapted some pyright types

* removed generics from explainer again

* tightened index type check

* made n_players None at assignment again

* moved comments

---------

Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
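The "Added static typechecking to pre-commit" step above typically corresponds to a hook entry like the following. This is an illustrative sketch only — the repo URL, hook id, and pinned `rev` are assumptions and may differ from shapiq's actual configuration:

```yaml
# hypothetical excerpt of .pre-commit-config.yaml
repos:
  - repo: https://github.com/RobertCraigie/pyright-python
    rev: v1.1.390   # pinned version is illustrative
    hooks:
      - id: pyright
```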

304 of 360 new or added lines in 51 files covered. (84.44%)

12 existing lines in 9 files now uncovered.

4987 of 5374 relevant lines covered (92.8%)

0.93 hits per line

Source File

/src/shapiq/explainer/base.py (72.82% of lines covered)
"""The base Explainer classes for the shapiq package."""

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Any

from tqdm.auto import tqdm

from .utils import (
    get_explainers,
    get_predict_function_and_model_type,
    print_class,
)
from .validation import validate_data_predict_function, validate_index_and_max_order

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    import numpy as np

    from shapiq.approximator.base import Approximator
    from shapiq.game import Game
    from shapiq.game_theory import ExactComputer
    from shapiq.imputer.base import Imputer
    from shapiq.interaction_values import InteractionValues
    from shapiq.typing import Model

    from .custom_types import ExplainerIndices


def generic_to_specific_explainer(
    generic_explainer: Explainer,
    explainer_cls: type[Explainer],
    model: Model | Game | Callable[[np.ndarray], np.ndarray],
    data: np.ndarray | None = None,
    class_index: int | None = None,
    index: ExplainerIndices = "k-SII",
    max_order: int = 2,
    **kwargs: Any,
) -> None:
    """Transform the base Explainer instance into a specific explainer subclass.

    This function modifies the class of the given object to the specified explainer class and
    initializes it with the provided parameters.

    Args:
        generic_explainer: The base Explainer instance to be transformed.
        explainer_cls: The specific explainer subclass to transform into.
        model: The model object to be explained.
        data: A background dataset to be used for imputation.
        class_index: The class index of the model to explain.
        index: The type of Shapley interaction index to use.
        max_order: The maximum interaction order to be computed.
        **kwargs: Additional keyword-only arguments passed to the specific explainer class.
    """
    generic_explainer.__class__ = explainer_cls
    explainer_cls.__init__(
        generic_explainer,
        model=model,
        data=data,
        class_index=class_index,
        index=index,
        max_order=max_order,
        **kwargs,
    )
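The `__class__` reassignment here is the core of shapiq's dispatch: a plain `Explainer` instance becomes an instance of the chosen subclass in place, and the subclass initializer then runs on the same object. A minimal, self-contained sketch of the same trick — `Base`, `Tabular`, and the registry are hypothetical names for illustration, not shapiq API:

```python
from __future__ import annotations


class Base:
    """Generic entry point that re-dispatches construction to a subclass."""

    registry: dict[str, type[Base]] = {}

    def __init__(self, kind: str) -> None:
        if self.__class__ is Base:
            # Same trick as generic_to_specific_explainer: swap the
            # instance's class, then run the subclass initializer on
            # the very same object.
            specific_cls = Base.registry[kind]
            self.__class__ = specific_cls
            specific_cls.__init__(self, kind)
            return
        self.kind = kind


class Tabular(Base):
    def describe(self) -> str:
        return f"Tabular explainer for kind={self.kind!r}"


Base.registry["tabular"] = Tabular

obj = Base("tabular")  # constructed as Base, dispatched to Tabular in place
```

Because the dispatch happens inside `__init__` rather than `__new__`, the swapped-in initializer must return the object fully set up; that is why the guard `if self.__class__ is Base` returns immediately after delegation.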
68

69

70
class Explainer:
1✔
71
    """The main Explainer class for a simpler user interface.
72

73
    shapiq.Explainer is a simplified interface for the ``shapiq`` package. It detects between
74
    :class:`~shapiq.explainer.tabular.TabularExplainer`,
75
    :class:`~shapiq.explainer.tree.TreeExplainer`,
76
    and :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. For a detailed description of the
77
    different explainers, see the respective classes.
78
    """
79

80
    model: Model | Game | Callable[[np.ndarray], np.ndarray]
1✔
81
    """The model to be explained, either as a Model instance or a callable function."""
1✔
82

83
    _index: ExplainerIndices
1✔
84
    _max_order: int
1✔
85

86
    def __init__(
1✔
87
        self,
88
        model: Model | Game | Callable[[np.ndarray], np.ndarray],
89
        data: np.ndarray | None = None,
90
        class_index: int | None = None,
91
        index: ExplainerIndices = "k-SII",
92
        max_order: int = 2,
93
        **kwargs: Any,
94
    ) -> None:
95
        """Initialize the Explainer class.
96

97
        Args:
98
            model: The model object to be explained.
99

100
            data: A background dataset to be used for imputation in
101
                :class:`~shapiq.explainer.tabular.TabularExplainer` or
102
                :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. This is a 2-dimensional
103
                NumPy array with shape ``(n_samples, n_features)``. Can be empty for the
104
                :class:`~shapiq.explainer.tree.TreeExplainer`, which does not require background
105
                data.
106

107
            class_index: The class index of the model to explain. Defaults to ``None``, which will
108
                set the class index to ``1`` per default for classification models and is ignored
109
                for regression models. Note, it is important to specify the class index for your
110
                classification model.
111

112
            index: The type of Shapley interaction index to use. Defaults to ``"k-SII"``, which
113
                computes the k-Shapley Interaction Index. If ``max_order`` is set to 1, this
114
                corresponds to the Shapley value (``index="SV"``). Options are:
115
                - ``"SV"``: Shapley value
116
                - ``"k-SII"``: k-Shapley Interaction Index
117
                - ``"FSII"``: Faithful Shapley Interaction Index
118
                - ``"FBII"``: Faithful Banzhaf Interaction Index (becomes ``BV`` for order 1)
119
                - ``"STII"``: Shapley Taylor Interaction Index
120
                - ``"SII"``: Shapley Interaction Index
121

122
            max_order: The maximum interaction order to be computed. Defaults to ``2``. Set to
123
                ``1`` for no interactions (single feature attribution).
124

125
            **kwargs: Additional keyword-only arguments passed to the specific explainer classes.
126

127
        """
128
        # If Explainer is instantiated directly, dynamically dispatch to the appropriate subclass
129
        if self.__class__ is Explainer:
1✔
130
            model_class = print_class(model)
1✔
131
            _, model_type = get_predict_function_and_model_type(model, model_class, class_index)
1✔
132
            explainer_classes = get_explainers()
1✔
133
            if model_type in explainer_classes:
1✔
134
                explainer_cls = explainer_classes[model_type]
1✔
135
                generic_to_specific_explainer(
1✔
136
                    self,
137
                    explainer_cls,
138
                    model=model,
139
                    data=data,
140
                    class_index=class_index,
141
                    index=index,
142
                    max_order=max_order,
143
                    **kwargs,
144
                )
145
                return
1✔
146
            msg = f"Model '{model_class}' with type '{model_type}' is not supported by shapiq.Explainer."
×
147
            raise TypeError(msg)
×
148

149
        # proceed with the base Explainer initialization
150
        self._model_class = print_class(model)
1✔
151
        self._shapiq_predict_function, self._model_type = get_predict_function_and_model_type(
1✔
152
            model, self._model_class, class_index
153
        )
154

155
        # validate the model and data
156
        self.model = model
1✔
157
        if data is not None:
1✔
158
            validate_data_predict_function(data, predict_function=self.predict, raise_error=False)
1✔
159
            self._data: np.ndarray = data
1✔
160

161
        # validate index and max_order and set them as attributes
162
        self._index, self._max_order = validate_index_and_max_order(index, max_order)
1✔
163

164
        # initialize private attributes
165
        self._imputer: Imputer | None = None
1✔
166
        self._approximator: Approximator | None = None
1✔
167
        self._exact_computer: ExactComputer | None = None
1✔
168

169
    @property
    def imputer(self) -> Imputer:
        """The imputer used by the explainer (raises if none is set)."""
        if self._imputer is None:
            msg = "The explainer does not have an imputer. Use a specific explainer class."
            raise NotImplementedError(msg)
        return self._imputer

    @property
    def exact_computer(self) -> ExactComputer:
        """The exact computer used by the explainer (raises if none is set)."""
        if self._exact_computer is None:
            msg = "The explainer does not have an exact computer. Use a specific explainer class."
            raise NotImplementedError(msg)
        return self._exact_computer

    @property
    def approximator(self) -> Approximator:
        """The approximator used by the explainer (raises if none is set)."""
        if self._approximator is None:
            msg = "The explainer does not have an approximator. Use a specific explainer class."
            raise NotImplementedError(msg)
        return self._approximator

    @property
    def index(self) -> ExplainerIndices:
        """The type of Shapley interaction index the explainer is using."""
        return self._index

    @property
    def max_order(self) -> int:
        """The maximum interaction order the explainer is using."""
        return self._max_order
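The component properties above share one pattern: a private optional attribute plus a public property that fails loudly instead of returning ``None``, so callers get a clear error rather than a downstream ``AttributeError``. A minimal sketch under hypothetical names — `Holder`, `engine`, and `WithEngine` are illustrative, not shapiq API:

```python
from __future__ import annotations


class Holder:
    """Hypothetical container mirroring the Explainer property pattern."""

    def __init__(self) -> None:
        # the base class leaves the component unset ...
        self._engine: str | None = None

    @property
    def engine(self) -> str:
        # ... and the public accessor raises instead of returning None
        if self._engine is None:
            msg = "Holder has no engine. Use a subclass that sets one."
            raise NotImplementedError(msg)
        return self._engine


class WithEngine(Holder):
    def __init__(self) -> None:
        super().__init__()
        self._engine = "sampling"
```

A side effect of this design is that internal code must test the private attribute (`self._engine is None`) rather than the property, since the property can no longer be compared against ``None``.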

    def explain(self, x: np.ndarray | None = None, **kwargs: Any) -> InteractionValues:
        """Explain a single prediction in terms of interaction values.

        Args:
            x: A numpy array of a data point to be explained.
            **kwargs: Additional keyword-only arguments passed to the specific explainer's
                ``explain_function`` method.

        Returns:
            The interaction values of the prediction.

        """
        return self.explain_function(x=x, **kwargs)

    def set_random_state(self, random_state: int | None = None) -> None:
        """Set the random state for the explainer and its components.

        Note:
            Setting the random state in the explainer will also overwrite the random state
            in the approximator and imputer, if they are set.

        Args:
            random_state: The random state to set. If ``None``, no random state is set.

        """
        if random_state is None:
            return

        # check the private attributes: the public properties raise when unset
        if self._approximator is not None:
            self._approximator.set_random_state(random_state=random_state)

        if self._imputer is not None:
            self._imputer.set_random_state(random_state=random_state)

    @abstractmethod
    def explain_function(
        self, x: np.ndarray | None, *args: Any, **kwargs: Any
    ) -> InteractionValues:
        """Explain a single prediction in terms of interaction values.

        Args:
            x: A numpy array of a data point to be explained.
            *args: Additional positional arguments passed to the explainer.
            **kwargs: Additional keyword-only arguments passed to the explainer.

        Returns:
            The interaction values of the prediction.

        """
        msg = "The method `explain_function` must be implemented in a subclass."
        raise NotImplementedError(msg)

    def explain_X(
        self,
        X: np.ndarray,
        *,
        n_jobs: int | None = None,
        random_state: int | None = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> list[InteractionValues]:
        """Explain multiple predictions at once.

        This method is a wrapper around the ``explain`` method. It allows explaining multiple
        predictions at once and is a convenience method that uses the ``joblib`` library to
        parallelize the computation of the interaction values.

        Args:
            X: A 2-dimensional matrix of inputs to be explained with shape
                ``(n_samples, n_features)``.

            n_jobs: Number of jobs for ``joblib.Parallel``. Defaults to ``None``, which uses
                no parallelization. If set to ``-1``, all available cores are used.

            random_state: The random state to re-initialize the Imputer and Approximator with.
                Defaults to ``None``.

            verbose: Whether to print a progress bar. Defaults to ``False``.

            **kwargs: Additional keyword-only arguments passed to the explainer's
                ``explain_function`` method.

        Returns:
            A list of interaction values for each prediction in the input matrix ``X``.

        """
        if len(X.shape) != 2:
            msg = "`X` must be a 2-dimensional matrix."
            raise TypeError(msg)

        self.set_random_state(random_state=random_state)

        if n_jobs:  # parallelization with joblib
            import joblib

            parallel = joblib.Parallel(n_jobs=n_jobs)
            ivs: list[InteractionValues] = list(
                parallel(  # type: ignore[assignment]
                    joblib.delayed(self.explain)(X[i, :], **kwargs) for i in range(X.shape[0])
                )
            )
        else:
            ivs = []
            pbar = tqdm(total=X.shape[0], desc="Explaining") if verbose else None
            for i in range(X.shape[0]):
                ivs.append(self.explain(X[i, :], **kwargs))
                if pbar is not None:
                    pbar.update(1)
        return ivs

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Provide a unified prediction interface for the explainer.

        Args:
            x: An instance/point/sample/observation to be explained.

        Returns:
            The model's prediction for the given data point as a vector.
        """
        if isinstance(self._shapiq_predict_function, RuntimeError):
            # a RuntimeError was stored instead of a predict function; raise it now
            raise self._shapiq_predict_function
        return self._shapiq_predict_function(self.model, x)
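``predict`` relies on the resolution step storing a ``RuntimeError`` instead of raising it immediately, so construction succeeds and the failure surfaces only when a prediction is actually requested. A self-contained sketch of that error-as-value pattern — ``resolve_predict_function``, ``Wrapper``, and ``Doubler`` are hypothetical, not shapiq API:

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any


def resolve_predict_function(model: Any) -> Callable[[Any, Any], Any] | RuntimeError:
    """Resolve a predict callable once, or return the failure as a value."""
    if hasattr(model, "predict"):
        return lambda m, x: m.predict(x)
    return RuntimeError(f"No predict function found for {type(model).__name__}.")


class Wrapper:
    def __init__(self, model: Any) -> None:
        self.model = model
        # construction never raises; a failed resolution is stored instead
        self._predict_function = resolve_predict_function(model)

    def predict(self, x: Any) -> Any:
        # raise the stored error only when a prediction is actually requested
        if isinstance(self._predict_function, RuntimeError):
            raise self._predict_function
        return self._predict_function(self.model, x)


class Doubler:
    def predict(self, x: list[int]) -> list[int]:
        return [2 * v for v in x]
```

Deferring the error this way lets a generic wrapper be built around models it can only partially handle, at the cost of a union-typed attribute that every caller must narrow with ``isinstance``.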

© 2026 Coveralls, Inc