
mmschlk / shapiq / build 18449718359

12 Oct 2025 09:25PM UTC coverage: 93.247% (-0.6%) from 93.845%

Pull Request #430: Enhance type safety and fix bugs across the codebase
Merge a979c7a28 into dede390c9 (via github / web-flow)

278 of 326 new or added lines in 46 files covered. (85.28%)
13 existing lines in 10 files now uncovered.
4985 of 5346 relevant lines covered (93.25%)
0.93 hits per line

Source File: /src/shapiq/explainer/base.py (80.9% covered)
"""The base Explainer classes for the shapiq package."""

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from tqdm.auto import tqdm

from shapiq.approximator.base import Approximator
from shapiq.game_theory import ExactComputer
from shapiq.imputer.base import Imputer

from .utils import (
    get_explainers,
    get_predict_function_and_model_type,
    print_class,
)
from .validation import validate_data_predict_function, validate_index_and_max_order

if TYPE_CHECKING:
    from collections.abc import Callable

    import numpy as np

    from shapiq.game import Game
    from shapiq.interaction_values import InteractionValues
    from shapiq.typing import Model

    from .custom_types import ExplainerIndices


def generic_to_specific_explainer(
    generic_explainer: Explainer,
    explainer_cls: type[Explainer],
    model: Model | Game | Callable[[np.ndarray], np.ndarray],
    data: np.ndarray | None = None,
    class_index: int | None = None,
    index: ExplainerIndices = "k-SII",
    max_order: int = 2,
    **kwargs: Any,
) -> None:
    """Transform the base Explainer instance into a specific explainer subclass.

    This function modifies the class of the given object to the specified explainer class and
    initializes it with the provided parameters.

    Args:
        generic_explainer: The base Explainer instance to be transformed.
        explainer_cls: The specific explainer subclass to transform into.
        model: The model object to be explained.
        data: A background dataset to be used for imputation.
        class_index: The class index of the model to explain.
        index: The type of Shapley interaction index to use.
        max_order: The maximum interaction order to be computed.
        **kwargs: Additional keyword-only arguments passed to the specific explainer class.
    """
    generic_explainer.__class__ = explainer_cls
    explainer_cls.__init__(
        generic_explainer,
        model=model,
        data=data,
        class_index=class_index,
        index=index,
        max_order=max_order,
        **kwargs,
    )

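# Illustrative sketch, not part of the module: the in-place ``__class__`` swap above
# is what lets a direct ``Explainer(...)`` call yield an instance of a specific
# subclass. ``my_model`` and ``X_background`` are hypothetical placeholders:
#
#     explainer = Explainer(model=my_model, data=X_background)
#     # type(explainer) is now, e.g., TabularExplainer, because Explainer.__init__
#     # called generic_to_specific_explainer, which reassigned __class__ and then
#     # re-ran __init__ on the subclass.
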
# Type variables for the generic Explainer class.
TApprox = TypeVar("TApprox", Approximator, None)
TImputer = TypeVar("TImputer", Imputer, None)
TExact = TypeVar("TExact", ExactComputer, None)

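# Illustrative sketch: a concrete subclass binds the three type parameters, e.g.
# (hypothetical signature; the actual subclasses live in their own modules):
#
#     class TabularExplainer(Explainer[Approximator, Imputer, ExactComputer]): ...
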
class Explainer(Generic[TApprox, TImputer, TExact]):
    """The main Explainer class for a simpler user interface.

    ``shapiq.Explainer`` is a simplified interface for the ``shapiq`` package. Based on the
    model type, it automatically dispatches to
    :class:`~shapiq.explainer.tabular.TabularExplainer`,
    :class:`~shapiq.explainer.tree.TreeExplainer`,
    or :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. For a detailed description of the
    different explainers, see the respective classes.
    """

    approximator: TApprox
    """The approximator which may be used for the explanation (or ``None`` in the base class)."""

    exact_computer: TExact
    """An exact computer which computes the :class:`~shapiq.interaction_values.InteractionValues`
    exactly (or ``None`` in the base class). Note that this only works for a small number of
    features, as the number of coalitions grows exponentially with the number of features.
    """

    imputer: TImputer
    """An imputer which is used to impute missing values when computing the interaction values
    (or ``None`` in the base class)."""

    model: Model | Game | Callable[[np.ndarray], np.ndarray]
    """The model to be explained, either as a Model instance or a callable function."""

    def __init__(
        self,
        model: Model | Game | Callable[[np.ndarray], np.ndarray],
        data: np.ndarray | None = None,
        class_index: int | None = None,
        index: ExplainerIndices = "k-SII",
        max_order: int = 2,
        **kwargs: Any,
    ) -> None:
        """Initialize the Explainer class.

        Args:
            model: The model object to be explained.

            data: A background dataset to be used for imputation in
                :class:`~shapiq.explainer.tabular.TabularExplainer` or
                :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. This is a 2-dimensional
                NumPy array with shape ``(n_samples, n_features)``. Can be empty for the
                :class:`~shapiq.explainer.tree.TreeExplainer`, which does not require background
                data.

            class_index: The class index of the model to explain. Defaults to ``None``, which
                sets the class index to ``1`` for classification models and is ignored for
                regression models. Note that it is important to specify the class index for your
                classification model.

            index: The type of Shapley interaction index to use. Defaults to ``"k-SII"``, which
                computes the k-Shapley Interaction Index. If ``max_order`` is set to 1, this
                corresponds to the Shapley value (``index="SV"``). Options are:

                - ``"SV"``: Shapley value
                - ``"k-SII"``: k-Shapley Interaction Index
                - ``"FSII"``: Faithful Shapley Interaction Index
                - ``"FBII"``: Faithful Banzhaf Interaction Index (becomes ``BV`` for order 1)
                - ``"STII"``: Shapley Taylor Interaction Index
                - ``"SII"``: Shapley Interaction Index

            max_order: The maximum interaction order to be computed. Defaults to ``2``. Set to
                ``1`` for no interactions (single feature attribution).

            **kwargs: Additional keyword-only arguments passed to the specific explainer classes.

        """
        # If Explainer is instantiated directly, dynamically dispatch to the appropriate subclass.
        if self.__class__ is Explainer:
            model_class = print_class(model)
            _, model_type = get_predict_function_and_model_type(model, model_class, class_index)
            explainer_classes = get_explainers()
            if model_type in explainer_classes:
                explainer_cls = explainer_classes[model_type]
                generic_to_specific_explainer(
                    self,
                    explainer_cls,
                    model=model,
                    data=data,
                    class_index=class_index,
                    index=index,
                    max_order=max_order,
                    **kwargs,
                )
                return
            msg = f"Model '{model_class}' with type '{model_type}' is not supported by shapiq.Explainer."
            raise TypeError(msg)

        # proceed with the base Explainer initialization
        self._model_class = print_class(model)
        self._shapiq_predict_function, self._model_type = get_predict_function_and_model_type(
            model, self._model_class, class_index
        )

        # validate the model and data
        self.model = model
        if data is not None:
            validate_data_predict_function(data, predict_function=self.predict, raise_error=False)
            self._data: np.ndarray = data

        # validate index and max_order and set them as attributes
        self._index, self._max_order = validate_index_and_max_order(index, max_order)

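    # Illustrative usage sketch (``model`` and ``X`` are hypothetical placeholders):
    #
    #     explainer = Explainer(model=model, data=X, index="k-SII", max_order=2)
    #     interaction_values = explainer.explain(X[0])
    #
    # With ``index="SV"`` and ``max_order=1``, the result reduces to plain Shapley values.
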
    @property
    def index(self) -> ExplainerIndices:
        """The type of Shapley interaction index the explainer is using."""
        return self._index  # type: ignore[return-type]

    @property
    def max_order(self) -> int:
        """The maximum interaction order the explainer is using."""
        return self._max_order

    def explain(self, x: np.ndarray | None = None, **kwargs: Any) -> InteractionValues:
        """Explain a single prediction in terms of interaction values.

        Args:
            x: A numpy array of a data point to be explained.
            **kwargs: Additional keyword-only arguments passed to the specific explainer's
                ``explain_function`` method.

        Returns:
            The interaction values of the prediction.

        """
        return self.explain_function(x=x, **kwargs)

    def set_random_state(self, random_state: int | None = None) -> None:
        """Set the random state for the explainer and its components.

        Note:
            Setting the random state in the explainer will also overwrite the random state
            in the approximator and imputer, if they are set.

        Args:
            random_state: The random state to set. If ``None``, no random state is set.

        """
        if random_state is None:
            return

        if self.approximator is not None:
            self.approximator.set_random_state(random_state=random_state)

        if self.imputer is not None:
            self.imputer.set_random_state(random_state=random_state)

    @abstractmethod
    def explain_function(
        self, x: np.ndarray | None, *args: Any, **kwargs: Any
    ) -> InteractionValues:
        """Explain a single prediction in terms of interaction values.

        Args:
            x: A numpy array of a data point to be explained.
            *args: Additional positional arguments passed to the explainer.
            **kwargs: Additional keyword-only arguments passed to the explainer.

        Returns:
            The interaction values of the prediction.

        """
        msg = "The method `explain_function` must be implemented in a subclass."
        raise NotImplementedError(msg)

    def explain_X(
        self,
        X: np.ndarray,
        *,
        n_jobs: int | None = None,
        random_state: int | None = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> list[InteractionValues]:
        """Explain multiple predictions at once.

        This method is a wrapper around the ``explain`` method that explains multiple
        predictions in one call. It is a convenience method that uses the ``joblib`` library
        to parallelize the computation of the interaction values.

        Args:
            X: A 2-dimensional matrix of inputs to be explained with shape
                ``(n_samples, n_features)``.

            n_jobs: Number of jobs for ``joblib.Parallel``. Defaults to ``None``, which uses
                no parallelization. If set to ``-1``, all available cores will be used.

            random_state: The random state to re-initialize the Imputer and Approximator with.
                Defaults to ``None``.

            verbose: Whether to print a progress bar. Defaults to ``False``.

            **kwargs: Additional keyword-only arguments passed to the explainer's
                ``explain_function`` method.

        Returns:
            A list of interaction values for each prediction in the input matrix ``X``.

        """
        if len(X.shape) != 2:
            msg = "The `X` must be a 2-dimensional matrix."
            raise TypeError(msg)

        self.set_random_state(random_state=random_state)

        if n_jobs:  # parallelization with joblib
            import joblib

            parallel = joblib.Parallel(n_jobs=n_jobs)
            ivs: list[InteractionValues] = list(
                parallel(  # type: ignore[assignment]
                    joblib.delayed(self.explain)(X[i, :], **kwargs) for i in range(X.shape[0])
                )
            )
        else:
            ivs: list[InteractionValues] = []
            pbar = tqdm(total=X.shape[0], desc="Explaining") if verbose else None
            for i in range(X.shape[0]):
                ivs.append(self.explain(X[i, :], **kwargs))
                if pbar is not None:
                    pbar.update(1)
        return ivs

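    # Illustrative usage sketch (hypothetical ``explainer`` and ``X``):
    #
    #     ivs = explainer.explain_X(X, n_jobs=-1, random_state=42, verbose=True)
    #     # len(ivs) == X.shape[0]; each entry is an InteractionValues object.
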
    def predict(self, x: np.ndarray) -> np.ndarray:
        """Provide a unified prediction interface for the explainer.

        Args:
            x: An instance/point/sample/observation to be explained.

        Returns:
            The model's prediction for the given data point as a vector.
        """
        if isinstance(self._shapiq_predict_function, RuntimeError):
            # an error stored during initialization is raised lazily on first use
            raise self._shapiq_predict_function
        return self._shapiq_predict_function(self.model, x)
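
    # Illustrative usage sketch (hypothetical ``explainer`` and ``x``):
    #
    #     y = explainer.predict(x)  # one call signature regardless of model type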