• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18540698301

15 Oct 2025 07:44PM UTC coverage: 92.615% (-0.2%) from 92.799%
18540698301

Pull #431

github

web-flow
Merge 606a077aa into 193f1a9ef
Pull Request #431: Product kernel explainer

185 of 210 new or added lines in 13 files covered. (88.1%)

4 existing lines in 2 files now uncovered.

5167 of 5579 relevant lines covered (92.62%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.05
/src/shapiq/approximator/sparse/base.py
1
"""Base Sparse approximator for fourier-based interaction computation."""
2

3
from __future__ import annotations
1✔
4

5
import copy
1✔
6
from typing import TYPE_CHECKING, Literal, cast, get_args
1✔
7

8
import numpy as np
1✔
9
from sparse_transform.qsft.qsft import transform as sparse_fourier_transform
1✔
10
from sparse_transform.qsft.signals.input_signal_subsampled import (
1✔
11
    SubsampledSignal as SubsampledSignalFourier,
12
)
13
from sparse_transform.qsft.utils.general import fourier_to_mobius as fourier_to_moebius
1✔
14
from sparse_transform.qsft.utils.query import get_bch_decoder
1✔
15

16
from shapiq.approximator.base import Approximator
1✔
17
from shapiq.game_theory.moebius_converter import (
1✔
18
    MoebiusConverter,
19
    ValidMoebiusConverterIndices,
20
)
21
from shapiq.interaction_values import InteractionValues
1✔
22

23
if TYPE_CHECKING:
1✔
24
    from collections.abc import Callable
×
UNCOV
25
    from typing import Any
×
26

UNCOV
27
    from shapiq.game import Game
×
28

29
# The sparse approximators support exactly the indices that ``MoebiusConverter`` can convert to:
# every estimate is produced by converting an estimated Moebius transform into the target index.
ValidSparseIndices = ValidMoebiusConverterIndices
1✔
30

31

32
class Sparse(Approximator[ValidSparseIndices]):
    """Approximator interface using sparse transformation techniques.

    This class implements a sparse approximation method for computing various interaction indices
    using sparse Fourier transforms. It efficiently estimates interaction values with a limited
    sample budget by leveraging sparsity in the Fourier domain. The notion of sparse approximation
    is described in [Kan25]_.

    See Also:
        - :class:`~shapiq.approximator.sparse.SPEX` for a specific implementation of the
            sparse approximation using Fourier transforms described in [Kan25]_.

    Attributes:
        transform_type: Type of transform used (currently only ``"fourier"`` is supported).

        decoder_type: Type of decoder used, either ``"soft"`` or ``"hard"``.

        degree_parameter: A parameter that controls the maximum degree of the interactions to
            extract during execution of the algorithm. Note that this is a soft limit, and in
            practice, the algorithm may extract interactions of any degree. We typically find
            that there is little value going beyond ``5``. Defaults to ``5``. Note that
            increasing this parameter will need more ``budget`` in the :meth:`approximate`
            method. If a call to :meth:`approximate` succeeds only after lowering this value
            to fit the budget, the lowered value is kept for subsequent calls.

        query_args: Parameters for querying the signal.

        decoder_args: Parameters for decoding the transform.

    Raises:
        ValueError: If transform_type is not "fourier" or if decoder_type is not "soft" or "hard".

    References:
        .. [Kan25] Kang, J.S., Butler, L., Agarwal, A., Erginbas, Y.E., Pedarsani, R., Ramchandran, K., Yu, B. (2025). SPEX: Scaling Feature Interaction Explanations for LLMs https://arxiv.org/abs/2502.13870

    """

    valid_indices: tuple[ValidSparseIndices, ...] = tuple(get_args(ValidSparseIndices))  # type: ignore[assignment]
    """The valid indices for the Sparse approximator."""

    def __init__(
        self,
        n: int,
        index: ValidSparseIndices,
        *,
        max_order: int | None = None,
        top_order: bool = False,
        random_state: int | None = None,
        transform_type: Literal["fourier"] = "fourier",
        decoder_type: Literal["soft", "hard"] = "soft",
        degree_parameter: int = 5,
    ) -> None:
        """Initialize the Sparse approximator.

        Args:
            n: Number of players (features).

            index: The Interaction index to use. All indices supported by shapiq's
                :class:`~shapiq.game_theory.moebius_converter.MoebiusConverter` are supported.

            max_order: Maximum interaction order to consider. Defaults to ``None``, which means
                that all orders up to ``n`` will be considered.

            top_order: If ``True``, only reports interactions of exactly order ``max_order``.
                Otherwise, reports all interactions up to order ``max_order``. Defaults to
                ``False``.

            random_state: Seed for random number generator. Defaults to ``None``.

            transform_type: Type of transform to use. Currently only "fourier" is supported.

            decoder_type: Type of decoder to use, either "soft" or "hard". Defaults to "soft".
                ``None`` is accepted and treated as "hard".

            degree_parameter: A parameter that controls the maximum degree of the interactions to
                extract during execution of the algorithm. Note that this is a soft limit, and in
                practice, the algorithm may extract interactions of any degree. We typically find
                that there is little value going beyond ``5``. Defaults to ``5``. Note that
                increasing this parameter will need more ``budget`` in the :meth:`approximate`
                method.

        Raises:
            ValueError: If ``transform_type`` is not "fourier" or ``decoder_type`` is not
                "soft" or "hard".

        """
        if transform_type.lower() not in ["fourier"]:
            msg = "transform_type must be 'fourier'"
            raise ValueError(msg)
        self.transform_type = transform_type.lower()
        self.degree_parameter = degree_parameter
        self.decoder_type = "hard" if decoder_type is None else decoder_type.lower()
        if self.decoder_type not in ["soft", "hard"]:
            msg = "decoder_type must be 'soft' or 'hard'"
            raise ValueError(msg)
        # The sampling parameters for the Fourier transform. "t" must stay in sync with
        # ``degree_parameter`` and with the BCH decoder stored in ``decoder_args``.
        self.query_args = {
            "query_method": "complex",
            "num_subsample": 3,
            "delays_method_source": "joint-coded",
            "subsampling_method": "qsft",
            "delays_method_channel": "identity-siso",
            "num_repeat": 1,
            "t": self.degree_parameter,
        }
        self.decoder_args = {
            "num_subsample": 3,
            "num_repeat": 1,
            "reconstruct_method_source": "coded",
            "peeling_method": "multi-detect",
            "reconstruct_method_channel": (
                "identity-siso" if self.decoder_type == "soft" else "identity"
            ),
            "regress": "lasso",
            "res_energy_cutoff": 0.9,
            "source_decoder": get_bch_decoder(n, self.degree_parameter, self.decoder_type),
        }
        super().__init__(
            n=n,
            max_order=n if max_order is None else max_order,
            index=index,
            top_order=top_order,
            random_state=random_state,
            initialize_dict=False,  # Important for performance
        )

    def approximate(
        self,
        budget: int,
        game: Game | Callable[[np.ndarray], np.ndarray],
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Approximates the interaction values using a sparse transform approach.

        Args:
            budget: The budget for the approximation.
            game: The game function that returns the values for the coalitions.
            **kwargs: Additional keyword arguments (not used).

        Returns:
            The approximated interaction values for the configured index.

        Raises:
            ValueError: If the budget is too low to compute the transform, even after lowering
                ``degree_parameter`` down to ``2``.
        """
        # Find the max value of b that fits within the given sample budget and get the used budget
        used_budget = self._set_transform_budget(budget)
        signal = SubsampledSignalFourier(
            func=lambda inputs: game(inputs.astype(bool)),
            n=self.n,
            q=2,
            query_args=self.query_args,
        )
        # Each transform key is mapped to the tuple of positions where it is non-zero (the
        # players of the interaction). Only the real part of each coefficient is kept.
        initial_transform = {
            tuple(np.nonzero(key)[0]): np.real(value)
            for key, value in sparse_fourier_transform(signal, **self.decoder_args).items()
        }
        # Convert the Fourier coefficients into Moebius coefficients
        moebius_transform = fourier_to_moebius(initial_transform)
        # Convert the Moebius transform to the desired index
        result = self._process_moebius(moebius_transform=moebius_transform)
        # Filter the output as needed
        if self.top_order:
            result = self._filter_order(result)
        # The Moebius coefficient of the empty set equals the value of the empty coalition, so
        # it is the estimated baseline. ``interaction_lookup`` maps interactions to *positions*
        # in ``result`` and must not be used as a source of values here.
        baseline_value = float(moebius_transform.get((), 0.0))
        # finalize the interactions
        return InteractionValues(
            values=result,
            index=self.approximation_index,
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup=copy.deepcopy(self.interaction_lookup),
            estimated=True,
            estimation_budget=used_budget,
            baseline_value=baseline_value,
            target_index=self.index,
        )

    def _filter_order(self, result: np.ndarray) -> np.ndarray:
        """Filters the interactions to keep only those of the maximum order.

        This method is used when top_order=True to filter out all interactions that are not
        of exactly the maximum order (self.max_order).

        Args:
            result: Array of interaction values, indexed by the positions stored in
                ``self.interaction_lookup``.

        Returns:
            Filtered array containing only interaction values of the maximum order.
            The method also replaces the internal ``_interaction_lookup`` dictionary with one
            that indexes into the returned array.
        """
        filtered_lookup: dict[tuple, int] = {}
        filtered_values: list = []
        for interaction, position in self.interaction_lookup.items():
            if len(interaction) != self.max_order:
                continue
            # Read through the stored position rather than relying on dict iteration order
            # matching the layout of ``result``.
            filtered_lookup[interaction] = len(filtered_values)
            filtered_values.append(result[position])
        self._interaction_lookup = filtered_lookup
        return np.array(filtered_values)

    def _process_moebius(self, moebius_transform: dict[tuple, float]) -> np.ndarray:
        """Convert the Moebius transform into the desired index.

        Args:
            moebius_transform: The Moebius transform to process as a dict mapping tuples to float
                values.

        Returns:
            np.ndarray: The converted interaction values based on the specified index.
            The function also updates the internal _interaction_lookup dictionary.
        """
        # Values and lookup positions are both derived from the dict's insertion order.
        moebius_interactions = InteractionValues(
            values=np.array(list(moebius_transform.values())),
            index="Moebius",
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup={key: i for i, key in enumerate(moebius_transform)},
            estimated=True,
            baseline_value=moebius_transform.get((), 0.0),
        )
        autoconverter = MoebiusConverter(moebius_coefficients=moebius_interactions)
        converted_interaction_values = autoconverter(
            index=cast(ValidMoebiusConverterIndices, self.index), order=self.max_order
        )
        self._interaction_lookup = converted_interaction_values.interaction_lookup
        return converted_interaction_values.values

    def _set_transform_budget(self, budget: int) -> int:
        """Sets the appropriate transform budget parameters based on the given sample budget.

        This method calculates the maximum possible 'b' parameter (number of bits to subsample)
        that fits within the provided budget, then configures the query and decoder arguments
        accordingly. The actual number of samples that will be used is returned.

        If ``b`` would be ``<= 2``, ``degree_parameter`` (and ``query_args["t"]``) is lowered
        step by step, down to ``2``, until a larger ``b`` fits. On success, the lowered degree and
        a matching BCH decoder are kept. On failure, the original degree is restored, so the
        query arguments and the decoder stay consistent with each other.

        Args:
            budget: The maximum number of samples allowed for the approximation.

        Returns:
            int: The actual number of samples that will be used, which is less than or equal to the
                budget.

        Raises:
            ValueError: If the budget is too low to compute the transform with acceptable parameters.
        """
        b = SubsampledSignalFourier.get_b_for_sample_budget(
            budget, self.n, self.degree_parameter, 2, self.query_args
        )
        used_budget = SubsampledSignalFourier.get_number_of_samples(
            self.n, b, self.degree_parameter, 2, self.query_args
        )

        if b <= 2:
            original_degree = self.degree_parameter
            while self.degree_parameter > 2:
                self.degree_parameter -= 1
                self.query_args["t"] = self.degree_parameter

                # Recalculate 'b' with the updated 't'
                b = SubsampledSignalFourier.get_b_for_sample_budget(
                    budget, self.n, self.degree_parameter, 2, self.query_args
                )

                # Compute the used budget
                used_budget = SubsampledSignalFourier.get_number_of_samples(
                    self.n, b, self.degree_parameter, 2, self.query_args
                )

                # Break if 'b' is now sufficient; the decoder must match the new degree
                if b > 2:
                    self.decoder_args["source_decoder"] = get_bch_decoder(
                        self.n, self.degree_parameter, self.decoder_type
                    )
                    break

            # If 'b' is still too low, undo the degree reduction (the decoder was never
            # rebuilt) and raise an error
            if b <= 2:
                self.degree_parameter = original_degree
                self.query_args["t"] = original_degree
                msg = (
                    "Insufficient budget to compute the transform. Increase the budget or use a "
                    "different approximator."
                )
                raise ValueError(msg)
        # Store the final 'b' value
        self.query_args["b"] = b
        self.decoder_args["b"] = b
        return used_budget
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc