• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 16400353732

20 Jul 2025 01:21PM UTC coverage: 93.901% (+16.3%) from 77.589%
16400353732

push

github

web-flow
🔨 moves `shapiq.benchmark` and `shapiq.games` out of `shapiq`  (#416)

* moves benchmark and games out of shapiq core

* make macos single process test run

* add comment in macos

* updated CHANGELOG.md

* update type imports

25 of 46 new or added lines in 14 files covered. (54.35%)

4 existing lines in 1 file now uncovered.

4927 of 5247 relevant lines covered (93.9%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.65
/src/shapiq/explainer/base.py
1
"""The base Explainer classes for the shapiq package."""
2

3
from __future__ import annotations
1✔
4

5
from abc import abstractmethod
1✔
6
from typing import TYPE_CHECKING
1✔
7

8
from tqdm.auto import tqdm
1✔
9

10
from .utils import (
1✔
11
    get_explainers,
12
    get_predict_function_and_model_type,
13
    print_class,
14
)
15
from .validation import validate_data_predict_function, validate_index_and_max_order
1✔
16

17
if TYPE_CHECKING:
1✔
18
    from collections.abc import Callable
×
19
    from typing import Any
×
20

21
    import numpy as np
×
22

23
    from shapiq.approximator.base import Approximator
×
NEW
24
    from shapiq.game import Game
×
25
    from shapiq.game_theory import ExactComputer
×
NEW
26
    from shapiq.imputer.base import Imputer
×
27
    from shapiq.interaction_values import InteractionValues
×
NEW
28
    from shapiq.typing import Model
×
29

30
    from .custom_types import ExplainerIndices
×
31

32

33
class Explainer:
    """The main Explainer class for a simpler user interface.

    shapiq.Explainer is a simplified interface for the ``shapiq`` package. It detects between
    :class:`~shapiq.explainer.tabular.TabularExplainer`,
    :class:`~shapiq.explainer.tree.TreeExplainer`,
    and :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. For a detailed description of the
    different explainers, see the respective classes.
    """

    approximator: Approximator | None
    """The approximator which may be used for the explanation."""

    exact_computer: ExactComputer | None
    """An exact computer which computes the :class:`~shapiq.interaction_values.InteractionValues`
    exactly (without the need for approximations). Note that this only works for small number of
    features as the number of coalitions grows exponentially with the number of features.
    """

    imputer: Imputer | None
    """An imputer which is used to impute missing values in computing the interaction values."""

    model: Model | Game | Callable[[np.ndarray], np.ndarray]
    """The model to be explained, either as a Model instance or a callable function."""

    def __init__(
        self,
        model: Model | Game | Callable[[np.ndarray], np.ndarray],
        data: np.ndarray | None = None,
        class_index: int | None = None,
        index: ExplainerIndices = "k-SII",
        max_order: int = 2,
        **kwargs: Any,
    ) -> None:
        """Initialize the Explainer class.

        Args:
            model: The model object to be explained.

            data: A background dataset to be used for imputation in
                :class:`~shapiq.explainer.tabular.TabularExplainer` or
                :class:`~shapiq.explainer.tabpfn.TabPFNExplainer`. This is a 2-dimensional
                NumPy array with shape ``(n_samples, n_features)``. Can be ``None`` for the
                :class:`~shapiq.explainer.tree.TreeExplainer`, which does not require background
                data.

            class_index: The class index of the model to explain. Defaults to ``None``, which will
                set the class index to ``1`` per default for classification models and is ignored
                for regression models. Note, it is important to specify the class index for your
                classification model.

            index: The type of Shapley interaction index to use. Defaults to ``"k-SII"``, which
                computes the k-Shapley Interaction Index. If ``max_order`` is set to 1, this
                corresponds to the Shapley value (``index="SV"``). Options are:
                - ``"SV"``: Shapley value
                - ``"k-SII"``: k-Shapley Interaction Index
                - ``"FSII"``: Faithful Shapley Interaction Index
                - ``"FBII"``: Faithful Banzhaf Interaction Index (becomes ``BV`` for order 1)
                - ``"STII"``: Shapley Taylor Interaction Index
                - ``"SII"``: Shapley Interaction Index

            max_order: The maximum interaction order to be computed. Defaults to ``2``. Set to
                ``1`` for no interactions (single feature attribution).

            **kwargs: Additional keyword-only arguments passed to the specific explainer classes.

        Raises:
            TypeError: If the model type cannot be mapped to any registered explainer class.

        """
        # If Explainer is instantiated directly, dynamically dispatch to the appropriate subclass
        # by rebinding ``self.__class__`` and re-running the subclass initializer on ``self``.
        if self.__class__ is Explainer:
            model_class = print_class(model)
            _, model_type = get_predict_function_and_model_type(model, model_class, class_index)
            explainer_classes = get_explainers()
            if model_type in explainer_classes:
                explainer_cls = explainer_classes[model_type]
                self.__class__ = explainer_cls
                explainer_cls.__init__(
                    self,
                    model=model,
                    data=data,
                    class_index=class_index,
                    index=index,
                    max_order=max_order,
                    **kwargs,
                )
                return  # avoid continuing in base Explainer
            msg = f"Model '{model_class}' with type '{model_type}' is not supported by shapiq.Explainer."
            raise TypeError(msg)

        # proceed with the base Explainer initialization
        self._model_class = print_class(model)
        self._shapiq_predict_function, self._model_type = get_predict_function_and_model_type(
            model, self._model_class, class_index
        )

        # validate the model and data (non-fatal: warn instead of raising on mismatch)
        self.model = model
        if data is not None:
            validate_data_predict_function(data, predict_function=self.predict, raise_error=False)
        self._data: np.ndarray | None = data

        # validate index and max_order and set them as attributes
        self._index, self._max_order = validate_index_and_max_order(index, max_order)

        # set the class attributes
        self.approximator = None
        self.exact_computer = None
        self.imputer = None

    @property
    def index(self) -> ExplainerIndices:
        """The type of Shapley interaction index the explainer is using."""
        return self._index

    @property
    def max_order(self) -> int:
        """The maximum interaction order the explainer is using."""
        return self._max_order

    def explain(self, x: np.ndarray | None = None, **kwargs: Any) -> InteractionValues:
        """Explain a single prediction in terms of interaction values.

        Args:
            x: A numpy array of a data point to be explained.
            **kwargs: Additional keyword-only arguments passed to the specific explainer's
                ``explain_function`` method.

        Returns:
            The interaction values of the prediction.

        """
        return self.explain_function(x=x, **kwargs)

    def set_random_state(self, random_state: int | None = None) -> None:
        """Set the random state for the explainer and its components.

        Note:
            Setting the random state in the explainer will also overwrite the random state
            in the approximator and imputer, if they are set.

        Args:
            random_state: The random state to set. If ``None``, no random state is set.

        """
        if random_state is None:
            return

        if self.approximator is not None:
            self.approximator.set_random_state(random_state=random_state)

        if self.imputer is not None:
            self.imputer.set_random_state(random_state=random_state)

    @abstractmethod
    def explain_function(self, x: np.ndarray, *args: Any, **kwargs: Any) -> InteractionValues:
        """Explain a single prediction in terms of interaction values.

        Args:
            x: A numpy array of a data point to be explained.
            *args: Additional positional arguments passed to the explainer.
            **kwargs: Additional keyword-only arguments passed to the explainer.

        Returns:
            The interaction values of the prediction.

        Raises:
            NotImplementedError: Always, unless overridden in a subclass.

        """
        # FIX: message previously referenced `explain` instead of `explain_function`.
        msg = "The method `explain_function` must be implemented in a subclass."
        raise NotImplementedError(msg)

    def explain_X(
        self,
        X: np.ndarray,
        *,
        n_jobs: int | None = None,
        random_state: int | None = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> list[InteractionValues]:
        """Explain multiple predictions at once.

        This method is a wrapper around the ``explain`` method. It allows to explain multiple
        predictions at once. It is a convenience method that uses the ``joblib`` library to
        parallelize the computation of the interaction values.

        Args:
            X: A 2-dimensional matrix of inputs to be explained with shape (n_samples, n_features).

            n_jobs: Number of jobs for ``joblib.Parallel``. Defaults to ``None``, which will
                use no parallelization. If set to ``-1``, all available cores will be used.

            random_state: The random state to re-initialize Imputer and Approximator with. Defaults
                to ``None``.

            verbose: Whether to print a progress bar. Defaults to ``False``.

            **kwargs: Additional keyword-only arguments passed to the explainer's
                ``explain_function`` method.

        Returns:
            A list of interaction values for each prediction in the input matrix ``X``.

        Raises:
            TypeError: If ``X`` is not a 2-dimensional matrix.

        """
        if X.ndim != 2:
            msg = "The `X` must be a 2-dimensional matrix."
            raise TypeError(msg)

        self.set_random_state(random_state=random_state)

        if n_jobs:  # parallelization with joblib
            import joblib

            parallel = joblib.Parallel(n_jobs=n_jobs)
            ivs = parallel(
                joblib.delayed(self.explain)(X[i, :], **kwargs) for i in range(X.shape[0])
            )
        else:
            ivs = []
            pbar = tqdm(total=X.shape[0], desc="Explaining") if verbose else None
            # FIX: close the progress bar even if an `explain` call raises, so the
            # tqdm display handle is always released.
            try:
                for i in range(X.shape[0]):
                    ivs.append(self.explain(X[i, :], **kwargs))
                    if pbar is not None:
                        pbar.update(1)
            finally:
                if pbar is not None:
                    pbar.close()
        return ivs

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Provides a unified prediction interface for the explainer.

        Args:
            x: An instance/point/sample/observation to be explained.

        Returns:
            The model's prediction for the given data point as a vector.
        """
        return self._shapiq_predict_function(self.model, x)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc