• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 16350591025

17 Jul 2025 04:26PM UTC coverage: 93.901% (+16.3%) from 77.589%
16350591025

Pull #416

github

web-flow
Merge 158d774a2 into 82c8f7562
Pull Request #416: moves benchmark and games out of shapiq core

25 of 44 new or added lines in 14 files covered. (56.82%)

4 existing lines in 1 file now uncovered.

4927 of 5247 relevant lines covered (93.9%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.1
/src/shapiq/approximator/sparse/base.py
1
"""Base Sparse approximator for fourier-based interaction computation."""
2

3
from __future__ import annotations
1✔
4

5
import copy
1✔
6
from typing import TYPE_CHECKING, Literal, get_args
1✔
7

8
import numpy as np
1✔
9
from sparse_transform.qsft.qsft import transform as sparse_fourier_transform
1✔
10
from sparse_transform.qsft.signals.input_signal_subsampled import (
1✔
11
    SubsampledSignal as SubsampledSignalFourier,
12
)
13
from sparse_transform.qsft.utils.general import fourier_to_mobius as fourier_to_moebius
1✔
14
from sparse_transform.qsft.utils.query import get_bch_decoder
1✔
15

16
from shapiq.approximator.base import Approximator
1✔
17
from shapiq.game_theory.moebius_converter import MoebiusConverter, ValidMoebiusConverterIndices
1✔
18
from shapiq.interaction_values import InteractionValues, finalize_computed_interactions
1✔
19

20
if TYPE_CHECKING:
1✔
21
    from collections.abc import Callable
×
22
    from typing import Any
×
23

NEW
24
    from shapiq.game import Game
×
25

26
# The sparse approximators obtain every index by converting a Moebius transform, so the set of
# supported indices is exactly the set accepted by the MoebiusConverter.
ValidSparseIndices = ValidMoebiusConverterIndices
1✔
27

28

29
class Sparse(Approximator):
    """Approximator interface using sparse transformation techniques.

    This class implements a sparse approximation method for computing various interaction indices
    using sparse Fourier transforms. It efficiently estimates interaction values with a limited
    sample budget by leveraging sparsity in the Fourier domain. The notion of sparse approximation
    is described in [Kan25]_.

    See Also:
        - :class:`~shapiq.approximator.sparse.SPEX` for a specific implementation of the
            sparse approximation using Fourier transforms described in [Kan25]_.

    Attributes:
        transform_type: Type of transform used (currently only ``"fourier"`` is supported).

        decoder_type: Type of decoder used, either ``"soft"`` or ``"hard"``.

        degree_parameter: A parameter that controls the maximum degree of the interactions to
            extract during execution of the algorithm. Note that this is a soft limit, and in
            practice, the algorithm may extract interactions of any degree. We typically find
            that there is little value going beyond ``5``. Defaults to ``5``. Note that
            increasing this parameter will need more ``budget`` in the :meth:`approximate`
            method. The value may be lowered by :meth:`approximate` when the budget is too small
            for the configured degree.

        query_args: Parameters for querying the signal.

        decoder_args: Parameters for decoding the transform.

    Raises:
        ValueError: If transform_type is not "fourier" or if decoder_type is not "soft" or "hard".

    References:
        .. [Kan25] Kang, J.S., Butler, L., Agarwal, A., Erginbas, Y.E., Pedarsani, R., Ramchandran, K., Yu, Bin (2025). SPEX: Scaling Feature Interaction Explanations for LLMs https://arxiv.org/abs/2502.13870

    """

    valid_indices: tuple[ValidSparseIndices, ...] = tuple(get_args(ValidSparseIndices))
    """The valid indices for the SPEX approximator."""

    def __init__(
        self,
        n: int,
        index: ValidSparseIndices,
        *,
        max_order: int | None = None,
        top_order: bool = False,
        random_state: int | None = None,
        transform_type: Literal["fourier"] = "fourier",
        decoder_type: Literal["soft", "hard"] = "soft",
        degree_parameter: int = 5,
    ) -> None:
        """Initialize the Sparse approximator.

        Args:
            n: Number of players (features).

            index: The Interaction index to use. All indices supported by shapiq's
                :class:`~shapiq.game_theory.moebius_converter.MoebiusConverter` are supported.

            max_order: Maximum interaction order to consider. Defaults to ``None``, which means
                that all orders up to ``n`` will be considered.

            top_order: If ``True``, only reports interactions of exactly order ``max_order``.
                Otherwise, reports all interactions up to order ``max_order``. Defaults to
                ``False``.

            random_state: Seed for random number generator. Defaults to ``None``.

            transform_type: Type of transform to use. Currently only "fourier" is supported.
                The comparison is case-insensitive.

            decoder_type: Type of decoder to use, either "soft" or "hard". Defaults to "soft".
                The comparison is case-insensitive; an explicit ``None`` is treated as "hard".

            degree_parameter: A parameter that controls the maximum degree of the interactions to
                extract during execution of the algorithm. Note that this is a soft limit, and in
                practice, the algorithm may extract interactions of any degree. We typically find
                that there is little value going beyond ``5``. Defaults to ``5``. Note that
                increasing this parameter will need more ``budget`` in the :meth:`approximate`
                method.

        Raises:
            ValueError: If ``transform_type`` is not "fourier" or ``decoder_type`` is neither
                "soft" nor "hard".

        """
        normalized_transform = transform_type.lower()
        if normalized_transform != "fourier":
            msg = "transform_type must be 'fourier'"
            raise ValueError(msg)
        self.transform_type = normalized_transform
        self.degree_parameter = degree_parameter
        self.decoder_type = "hard" if decoder_type is None else decoder_type.lower()
        if self.decoder_type not in ("soft", "hard"):
            msg = "decoder_type must be 'soft' or 'hard'"
            raise ValueError(msg)
        # The sampling parameters for the Fourier transform
        self.query_args = {
            "query_method": "complex",
            "num_subsample": 3,
            "delays_method_source": "joint-coded",
            "subsampling_method": "qsft",
            "delays_method_channel": "identity-siso",
            "num_repeat": 1,
            "t": self.degree_parameter,
        }
        # The decoding parameters; the channel reconstruction depends on the decoder type
        self.decoder_args = {
            "num_subsample": 3,
            "num_repeat": 1,
            "reconstruct_method_source": "coded",
            "peeling_method": "multi-detect",
            "reconstruct_method_channel": "identity-siso"
            if self.decoder_type == "soft"
            else "identity",
            "regress": "lasso",
            "res_energy_cutoff": 0.9,
            "source_decoder": get_bch_decoder(n, self.degree_parameter, self.decoder_type),
        }
        super().__init__(
            n=n,
            max_order=n if max_order is None else max_order,
            index=index,
            top_order=top_order,
            random_state=random_state,
            initialize_dict=False,  # Important for performance
        )

    def approximate(
        self,
        budget: int,
        game: Game | Callable[[np.ndarray], np.ndarray],
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Approximates the interaction values using a sparse transform approach.

        Args:
            budget: The budget for the approximation.
            game: The game function that returns the values for the coalitions.
            **kwargs: Additional keyword arguments (not used).

        Returns:
            The approximated Shapley interaction values.

        Raises:
            ValueError: If the budget is too low to compute the transform.
        """
        # Find the max value of b that fits within the given sample budget and get the used budget
        used_budget = self._set_transform_budget(budget)
        signal = SubsampledSignalFourier(
            func=lambda inputs: game(inputs.astype(bool)),
            n=self.n,
            q=2,
            query_args=self.query_args,
        )
        # Extract the coefficients of the original transform; keys become tuples of the
        # non-zero positions and only the real part of each coefficient is kept
        initial_transform = {
            tuple(np.nonzero(key)[0]): np.real(value)
            for key, value in sparse_fourier_transform(signal, **self.decoder_args).items()
        }
        # If we are using the fourier transform, we need to convert it to a Moebius transform
        moebius_transform = fourier_to_moebius(initial_transform)
        # Convert the Moebius transform to the desired index
        result = self._process_moebius(moebius_transform=moebius_transform)
        # Filter the output as needed
        if self.top_order:
            result = self._filter_order(result)
        # The lookup maps an interaction to its *position* in ``result``; the baseline is the
        # value stored at the empty interaction's position, not the position itself.
        empty_position = self.interaction_lookup.get(())
        baseline_value = 0.0 if empty_position is None else float(result[empty_position])
        output = InteractionValues(
            values=result,
            index=self.approximation_index,
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup=copy.deepcopy(self.interaction_lookup),
            estimated=True,
            estimation_budget=used_budget,
            baseline_value=baseline_value,
        )
        # finalize the interactions
        return finalize_computed_interactions(output, target_index=self.index)

    def _filter_order(self, result: np.ndarray) -> np.ndarray:
        """Filters the interactions to keep only those of the maximum order.

        This method is used when top_order=True to filter out all interactions that are not
        of exactly the maximum order (self.max_order).

        Args:
            result: Array of interaction values, addressed by the positions stored in the
                current interaction lookup.

        Returns:
            Filtered array containing only interaction values of the maximum order.
            The method also updates the internal _interaction_lookup dictionary.
        """
        filtered_interactions: dict[tuple, int] = {}
        filtered_results = []
        for interaction, position in self.interaction_lookup.items():
            if len(interaction) != self.max_order:
                continue
            # New positions are assigned densely in the order the interactions are kept
            filtered_interactions[interaction] = len(filtered_results)
            # Read the value through the stored position rather than the iteration order
            filtered_results.append(result[position])
        self._interaction_lookup = filtered_interactions
        return np.array(filtered_results)

    def _process_moebius(self, moebius_transform: dict[tuple, float]) -> np.ndarray:
        """Convert the Moebius transform into the desired index.

        Args:
            moebius_transform: The Moebius transform to process as a dict mapping tuples to float
                values.

        Returns:
            np.ndarray: The converted interaction values based on the specified index.
            The function also updates the internal _interaction_lookup dictionary.
        """
        # Values and lookup are both derived from the dict's iteration order, so position ``i``
        # in the array belongs to the ``i``-th key.
        moebius_interactions = InteractionValues(
            values=np.array(list(moebius_transform.values())),
            index="Moebius",
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup={key: i for i, key in enumerate(moebius_transform)},
            estimated=True,
            baseline_value=moebius_transform.get((), 0.0),
        )
        autoconverter = MoebiusConverter(moebius_coefficients=moebius_interactions)
        converted_interaction_values = autoconverter(index=self.index, order=self.max_order)
        self._interaction_lookup = converted_interaction_values.interaction_lookup
        return converted_interaction_values.values

    def _set_transform_budget(self, budget: int) -> int:
        """Sets the appropriate transform budget parameters based on the given sample budget.

        This method calculates the maximum possible 'b' parameter (number of bits to subsample)
        that fits within the provided budget, then configures the query and decoder arguments
        accordingly. If 'b' is not larger than 2 for the current ``degree_parameter``, the degree
        is lowered one step at a time (down to 2) until a sufficient 'b' is found. The actual
        number of samples that will be used is returned.

        On success a lowered degree is kept on the instance and the source decoder is rebuilt
        for it. On failure the degree and the ``"t"`` query argument are restored, so the
        instance stays consistent with its existing decoder.

        Args:
            budget: The maximum number of samples allowed for the approximation.

        Returns:
            int: The actual number of samples that will be used, which is less than or equal to the
                budget.

        Raises:
            ValueError: If the budget is too low to compute the transform with acceptable parameters.
        """
        initial_degree = self.degree_parameter
        b = SubsampledSignalFourier.get_b_for_sample_budget(
            budget, self.n, self.degree_parameter, 2, self.query_args
        )

        # Trade interaction degree for subsampling bits until 'b' is sufficient
        while b <= 2 and self.degree_parameter > 2:
            self.degree_parameter -= 1
            self.query_args["t"] = self.degree_parameter
            # Recalculate 'b' with the updated 't'
            b = SubsampledSignalFourier.get_b_for_sample_budget(
                budget, self.n, self.degree_parameter, 2, self.query_args
            )

        # If 'b' is still too low, undo the degree search and raise an error
        if b <= 2:
            self.degree_parameter = initial_degree
            self.query_args["t"] = initial_degree
            msg = (
                "Insufficient budget to compute the transform. Increase the budget or use a "
                "different approximator."
            )
            raise ValueError(msg)

        # The source decoder depends on the degree, so rebuild it whenever the degree changed
        if self.degree_parameter != initial_degree:
            self.decoder_args["source_decoder"] = get_bch_decoder(
                self.n, self.degree_parameter, self.decoder_type
            )

        # Compute the used budget for the final configuration
        used_budget = SubsampledSignalFourier.get_number_of_samples(
            self.n, b, self.degree_parameter, 2, self.query_args
        )
        # Store the final 'b' value
        self.query_args["b"] = b
        self.decoder_args["b"] = b
        return used_budget
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc