• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18777985242

24 Oct 2025 11:10AM UTC coverage: 93.119% (+0.09%) from 93.032%
18777985242

Pull #442

github

web-flow
Merge 1f371daec into 830c6bc23
Pull Request #442: ProxySPEX Approximator

94 of 94 new or added lines in 4 files covered. (100.0%)

2 existing lines in 1 file now uncovered.

5278 of 5668 relevant lines covered (93.12%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.67
/src/shapiq/approximator/sparse/base.py
1
"""Base Sparse approximator for fourier-based interaction computation."""
2

3
from __future__ import annotations
1✔
4

5
import copy
1✔
6
import math
1✔
7
from collections import defaultdict
1✔
8
from typing import TYPE_CHECKING, Literal, cast, get_args
1✔
9

10
import numpy as np
1✔
11
import pandas as pd
1✔
12
from sklearn.linear_model import RidgeCV
1✔
13
from sklearn.model_selection import GridSearchCV
1✔
14
from sparse_transform.qsft.qsft import transform as sparse_fourier_transform
1✔
15
from sparse_transform.qsft.signals.input_signal_subsampled import (
1✔
16
    SubsampledSignal as SubsampledSignalFourier,
17
)
18
from sparse_transform.qsft.utils.general import fourier_to_mobius as fourier_to_moebius
1✔
19
from sparse_transform.qsft.utils.query import get_bch_decoder
1✔
20

21
from shapiq.approximator.base import Approximator
1✔
22
from shapiq.approximator.sampling import CoalitionSampler
1✔
23
from shapiq.game_theory.moebius_converter import MoebiusConverter, ValidMoebiusConverterIndices
1✔
24
from shapiq.interaction_values import InteractionValues
1✔
25

26
if TYPE_CHECKING:
1✔
27
    from collections.abc import Callable
×
28
    from typing import Any
×
29

30
    from shapiq.game import Game
×
31

32
ValidSparseIndices = ValidMoebiusConverterIndices
1✔
33

34

35
class Sparse(Approximator[ValidSparseIndices]):
1✔
36
    """Approximator interface using sparse transformation techniques.
37

38
    This class implements a sparse approximation method for computing various interaction indices
39
    using sparse Fourier transforms. It efficiently estimates interaction values with a limited
40
    sample budget by leveraging sparsity in the Fourier domain. The notion of sparse approximation
41
    is described in [Kan25]_ and further improved in [But25]_.
42

43
    See Also:
44
        - :class:`~shapiq.approximator.sparse.SPEX` for a specific implementation of the
45
            sparse approximation using Fourier transforms described in [Kan25]_.
46
        - :class:`~shapiq.approximator.sparse.ProxySPEX` for a specific implementation of the
47
            sparse approximation using Fourier transforms described in [But25]_.
48

49
    Attributes:
50
        transform_type: Type of transform used (currently only ``"fourier"`` is supported).
51

52
        degree_parameter: A parameter that controls the maximum degree of the interactions to
53
            extract during execution of the algorithm. Note that this is a soft limit, and in
54
            practice, the algorithm may extract interactions of any degree. We typically find
55
            that there is little value going beyond ``5``. Defaults to ``5``. Note that
56
            increasing this parameter will need more ``budget`` in the :meth:`approximate`
57
            method.
58

59
        query_args: Parameters for querying the signal.
60

61
        decoder_args: Parameters for decoding the transform.
62

63
    Raises:
64
        ValueError: If transform_type is not "fourier" or if decoder_type is not "soft", "hard", or "proxyspex".
65

66
    References:
67
        .. [Kan25] Kang, J.S., Butler, L., Agarwal. A., Erginbas, Y.E., Pedarsani, R., Ramchandran, K., Yu, Bin (2025). SPEX: Scaling Feature Interaction Explanations for LLMs https://arxiv.org/abs/2502.13870
68
        .. [But25] Butler, L., Kang, J.S., Agarwal. A., Erginbas, Y.E., Yu, Bin, Ramchandran, K. (2025). ProxySPEX: Inference-Efficient Interpretability via Sparse Feature Interactions in LLMs https://arxiv.org/pdf/2505.17495
69
    """
70

71
    valid_indices: tuple[ValidSparseIndices, ...] = tuple(get_args(ValidSparseIndices))  # type: ignore[assignment]
1✔
72
    """The valid indices for the SPEX approximator."""
1✔
73

74
    def __init__(
        self,
        n: int,
        index: ValidSparseIndices,
        *,
        max_order: int | None = None,
        top_order: bool = False,
        random_state: int | None = None,
        transform_type: Literal["fourier"] = "fourier",
        decoder_type: Literal["soft", "hard", "proxyspex"] | None = "proxyspex",
        degree_parameter: int = 5,
    ) -> None:
        """Configure the transform, the decoder and (for ProxySPEX) the coalition sampler.

        Args:
            n: Number of players (features).

            index: The interaction index to compute. The accepted values are the ones accepted
                by :class:`~shapiq.game_theory.moebius_converter.MoebiusConverter`.

            max_order: Maximum interaction order to consider. ``None`` (the default) is
                replaced by ``n``.

            top_order: If ``True``, only interactions of exactly order ``max_order`` are
                reported; otherwise all interactions up to ``max_order``. Defaults to ``False``.

            random_state: Seed for the random number generator. Defaults to ``None``.

            transform_type: Type of transform to use (case-insensitive). Only ``"fourier"``
                is supported.

            decoder_type: Decoder to use (case-insensitive): ``"soft"``, ``"hard"`` or
                ``"proxyspex"``. ``None`` is treated as ``"proxyspex"``, which is also the
                default.

            degree_parameter: Soft limit on the degree of the interactions extracted by the
                algorithm; interactions of any degree may still be returned. Defaults to
                ``5``. Larger values need more ``budget`` in :meth:`approximate`.

        Raises:
            ValueError: If ``transform_type`` or ``decoder_type`` is not one of the supported
                values.
            ImportError: If ``decoder_type`` is ``"proxyspex"`` and ``lightgbm`` is not
                installed.
        """
        normalized_transform = transform_type.lower()
        if normalized_transform != "fourier":
            msg = "transform_type must be 'fourier'"
            raise ValueError(msg)
        self.transform_type = normalized_transform
        self.degree_parameter = degree_parameter

        self.decoder_type = decoder_type.lower() if decoder_type is not None else "proxyspex"
        if self.decoder_type not in ("soft", "hard", "proxyspex"):
            msg = "decoder_type must be 'soft', 'hard', or 'proxyspex'"
            raise ValueError(msg)

        uses_proxyspex = self.decoder_type == "proxyspex"
        if uses_proxyspex:
            # Fail early, at construction time, when the optional dependency is missing.
            try:
                import lightgbm  # noqa: F401
            except ImportError as err:
                msg = (
                    "The 'lightgbm' package is required when decoder_type is 'proxyspex' but it is "
                    "not installed. Please see the installation instructions at "
                    "https://github.com/microsoft/LightGBM/tree/master/python-package."
                )
                raise ImportError(msg) from err

        # The sampling parameters for the Fourier transform
        self.query_args = {
            "query_method": "complex",
            "num_subsample": 3,
            "delays_method_source": "joint-coded",
            "subsampling_method": "qsft",
            "delays_method_channel": "identity-siso",
            "num_repeat": 1,
            "t": self.degree_parameter,
        }

        if uses_proxyspex:
            # Hyper-parameter grid searched over when fitting the LightGBM proxy model.
            self.decoder_args = {
                "max_depth": [3, 5],
                "max_iter": [500, 1000],
                "learning_rate": [0.01, 0.1],
            }
            # One weight per coalition size, proportional to the number of coalitions of that size.
            size_weights = np.array([math.comb(n, size) for size in range(n + 1)], dtype=float)
            self._uniform_sampler = CoalitionSampler(
                n_players=n,
                sampling_weights=size_weights,
                pairing_trick=True,
                random_state=random_state,
            )
        else:
            channel_method = "identity-siso" if self.decoder_type == "soft" else "identity"
            self.decoder_args = {
                "num_subsample": 3,
                "num_repeat": 1,
                "reconstruct_method_source": "coded",
                "peeling_method": "multi-detect",
                "reconstruct_method_channel": channel_method,
                "regress": "lasso",
                "res_energy_cutoff": 0.9,
                "source_decoder": get_bch_decoder(n, self.degree_parameter, self.decoder_type),
            }

        super().__init__(
            n=n,
            max_order=n if max_order is None else max_order,
            index=index,
            top_order=top_order,
            random_state=random_state,
            initialize_dict=False,  # Important for performance
        )
178

179
    def approximate(
        self,
        budget: int,
        game: Game | Callable[[np.ndarray], np.ndarray],
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Approximates the interaction values using a sparse transform approach.

        With the ``"proxyspex"`` decoder, ``budget`` coalitions are drawn from the internal
        sampler, a LightGBM regressor is fitted to the game values (hyper-parameters chosen by
        5-fold grid search over ``self.decoder_args``), the fitted trees are converted into
        Fourier coefficients and those are refined by :meth:`_refine`. With the ``"soft"`` and
        ``"hard"`` decoders, the sparse Fourier transform of the game is computed directly.
        In both cases the Fourier coefficients are converted to a Moebius transform and then
        to the requested index.

        Args:
            budget: The budget for the approximation.
            game: The game function that returns the values for the coalitions.
            **kwargs: Additional keyword arguments (not used).

        Returns:
            The approximated interaction values.

        Raises:
            ValueError: For the ``"soft"``/``"hard"`` decoders, if the budget is too low to
                compute the transform (raised by :meth:`_set_transform_budget`).
        """
        if self.decoder_type == "proxyspex":
            import lightgbm as lgb

            # The proxy model is trained on exactly `budget` game evaluations.
            used_budget = budget

            # Take the budget amount of uniform samples
            self._uniform_sampler.sample(budget)
            coalitions = self._uniform_sampler.coalitions_matrix

            train_X = pd.DataFrame(
                coalitions,
                columns=np.array([f"f{i}" for i in range(self.n)]),
            )
            train_y = game(coalitions)

            base_model = lgb.LGBMRegressor(verbose=-1, n_jobs=1, random_state=self._random_state)

            # Set up GridSearchCV with cross-validation
            grid_search = GridSearchCV(
                estimator=base_model,
                param_grid=self.decoder_args,
                scoring="r2",
                cv=5,
                verbose=0,
                n_jobs=1,
            )

            # Fit the model on the training data
            grid_search.fit(train_X, train_y)
            best_model = grid_search.best_estimator_

            initial_transform = self._refine(
                self._lgboost_to_fourier(best_model.booster_.dump_model()),
                coalitions,
                train_y,
            )
        else:
            # Find the max value of b that fits within the given sample budget and get the used budget
            used_budget = self._set_transform_budget(budget)
            signal = SubsampledSignalFourier(
                func=lambda inputs: game(inputs.astype(bool)),
                n=self.n,
                q=2,
                query_args=self.query_args,
            )
            # Extract the coefficients of the original transform, keyed by the tuple of
            # players whose entry in the frequency vector is non-zero.
            initial_transform = {
                tuple(np.nonzero(key)[0]): np.real(value)
                for key, value in sparse_fourier_transform(signal, **self.decoder_args).items()
            }

        # If we are using the fourier transform, we need to convert it to a Moebius transform
        moebius_transform = fourier_to_moebius(initial_transform)
        # Convert the Moebius transform to the desired index
        result = self._process_moebius(moebius_transform=moebius_transform)
        # Filter the output as needed
        if self.top_order:
            result = self._filter_order(result)

        # `interaction_lookup` maps an interaction to its *position* in `result`, so the
        # baseline has to be read from `result` at that position rather than taken from the
        # lookup itself. Without an empty interaction (e.g. after top-order filtering) the
        # baseline falls back to 0.0.
        empty_position = self.interaction_lookup.get(())
        baseline_value = 0.0 if empty_position is None else float(result[empty_position])

        # finalize the interactions
        return InteractionValues(
            values=result,
            index=self.approximation_index,
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup=copy.deepcopy(self.interaction_lookup),
            estimated=True,
            estimation_budget=used_budget,
            baseline_value=baseline_value,
            target_index=self.index,
        )
265

266
    def _filter_order(self, result: np.ndarray) -> np.ndarray:
1✔
267
        """Filters the interactions to keep only those of the maximum order.
268

269
        This method is used when top_order=True to filter out all interactions that are not
270
        of exactly the maximum order (self.max_order).
271

272
        Args:
273
            result: Array of interaction values.
274

275
        Returns:
276
            Filtered array containing only interaction values of the maximum order.
277
            The method also updates the internal _interaction_lookup dictionary.
278
        """
279
        filtered_interactions = {}
1✔
280
        filtered_results = []
1✔
281
        i = 0
1✔
282
        for j, key in enumerate(self.interaction_lookup):
1✔
283
            if len(key) == self.max_order:
1✔
284
                filtered_interactions[key] = i
1✔
285
                filtered_results.append(result[j])
1✔
286
                i += 1
1✔
287
        self._interaction_lookup = filtered_interactions
1✔
288
        return np.array(filtered_results)
1✔
289

290
    def _process_moebius(self, moebius_transform: dict[tuple, float]) -> np.ndarray:
        """Convert a Moebius transform into values of the index ``self.index``.

        Args:
            moebius_transform: The Moebius transform as a dict mapping interaction tuples to
                float values.

        Returns:
            np.ndarray: The interaction values for ``self.index`` up to order
            ``self.max_order``. As a side effect, ``self._interaction_lookup`` is replaced by
            the lookup of the converted interaction values.
        """
        interactions = list(moebius_transform)
        moebius_values = InteractionValues(
            values=np.array(list(moebius_transform.values())),
            index="Moebius",
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            # Position of each interaction in `values` (dict order is preserved).
            interaction_lookup=dict(zip(interactions, range(len(interactions)))),
            estimated=True,
            baseline_value=moebius_transform.get((), 0.0),
        )
        converter = MoebiusConverter(moebius_coefficients=moebius_values)
        target_index = cast(ValidMoebiusConverterIndices, self.index)
        converted = converter(index=target_index, order=self.max_order)
        self._interaction_lookup = converted.interaction_lookup
        return converted.values  # noqa: PD011
317

318
    def _set_transform_budget(self, budget: int) -> int:
        """Sets the appropriate transform budget parameters based on the given sample budget.

        This method calculates the maximum possible 'b' parameter (number of bits to subsample)
        that fits within the provided budget, then configures the query and decoder arguments
        accordingly. If 'b' would be ``2`` or less, ``self.degree_parameter`` is lowered one
        step at a time (but not below ``2``) until 'b' exceeds ``2``; the lowered degree is
        kept on the instance and the source decoder is rebuilt for it.

        Args:
            budget: The maximum number of samples allowed for the approximation.

        Returns:
            int: The number of samples reported by
                ``SubsampledSignalFourier.get_number_of_samples`` for the selected parameters.

        Raises:
            ValueError: If the budget is too low to compute the transform with acceptable
                parameters. In that case ``self.degree_parameter`` and ``query_args["t"]`` are
                restored to the values they had on entry, so they stay consistent with the
                already configured source decoder.
        """
        # Snapshot the state mutated by the search below so a failed search can be undone.
        initial_degree = self.degree_parameter
        initial_t = self.query_args.get("t", initial_degree)

        b = SubsampledSignalFourier.get_b_for_sample_budget(
            budget, self.n, self.degree_parameter, 2, self.query_args
        )
        used_budget = SubsampledSignalFourier.get_number_of_samples(
            self.n, b, self.degree_parameter, 2, self.query_args
        )

        if b <= 2:
            while self.degree_parameter > 2:
                self.degree_parameter -= 1
                self.query_args["t"] = self.degree_parameter

                # Recalculate 'b' with the updated 't'
                b = SubsampledSignalFourier.get_b_for_sample_budget(
                    budget, self.n, self.degree_parameter, 2, self.query_args
                )

                # Compute the used budget
                used_budget = SubsampledSignalFourier.get_number_of_samples(
                    self.n, b, self.degree_parameter, 2, self.query_args
                )

                # Break if 'b' is now sufficient
                if b > 2:
                    self.decoder_args["source_decoder"] = get_bch_decoder(
                        self.n, self.degree_parameter, self.decoder_type
                    )
                    break

            # If 'b' is still too low, raise an error
            if b <= 2:
                # Undo the degree reduction: the source decoder was not rebuilt, so leaving
                # the lowered degree in place would make a later call (with a larger budget)
                # run with a decoder that does not match `degree_parameter`.
                self.degree_parameter = initial_degree
                self.query_args["t"] = initial_t
                msg = (
                    "Insufficient budget to compute the transform. Increase the budget or use a "
                    "different approximator."
                )
                raise ValueError(msg)
        # Store the final 'b' value
        self.query_args["b"] = b
        self.decoder_args["b"] = b
        return used_budget
375

376
    def _lgboost_to_fourier(self, model_dict: dict[str, Any]) -> dict[tuple[int, ...], float]:
1✔
377
        """Extracts the aggregated Fourier coefficients from an LGBoost model dictionary.
378

379
        This method iterates over all trees in the LightGBM ensemble, computes the
380
        Fourier coefficients for each individual tree using the `_lgboost_tree_to_fourier`
381
        helper method, and then sums these coefficients to get the final Fourier
382
        representation of the complete model.
383

384
        Args:
385
        model_dict: A dictionary representing the trained LGBoost model, as
386
            produced by `model.booster_.dump_model()`.
387

388
        Returns:
389
            A dictionary that maps interaction tuples (representing Fourier frequencies)
390
            to their aggregated Fourier coefficients.
391
        """
392
        aggregated_coeffs = defaultdict(float)
1✔
393

394
        for tree_info in model_dict["tree_info"]:
1✔
395
            tree_coeffs = self._lgboost_tree_to_fourier(tree_info)
1✔
396
            for interaction, value in tree_coeffs.items():
1✔
397
                aggregated_coeffs[interaction] += value
1✔
398

399
        # Convert defaultdict to a standard dict, removing zero-valued coefficients
400
        return {k: v for k, v in aggregated_coeffs.items() if v != 0.0}
1✔
401

402
    def _lgboost_tree_to_fourier(self, tree_info: dict[str, Any]) -> dict[tuple[int, ...], float]:
1✔
403
        """Recursively strips the Fourier coefficients from a single LGBoost tree.
404

405
        This method traverses a tree's structure, as provided by LightGBM's `dump_model`
406
        method, and computes the Fourier representation of the piecewise-constant
407
        function that the tree defines. The logic is adapted from the work by Gorji et al. (2024).
408

409
        Args:
410
            tree_info: A dictionary representing a single decision tree from an LGBM model.
411

412
        Returns:
413
            A dictionary mapping interaction tuples to their corresponding coefficients for
414
            the single tree.
415

416
        References:
417
            Gorji, Ali, Andisheh Amrollahi, and Andreas Krause.
418
            "SHAP values via sparse Fourier representation"
419
            arXiv preprint arXiv:2410.06300 (2024).
420
        """
421

422
        def _combine_coeffs(
1✔
423
            left_coeffs: dict[tuple[int, ...], float],
424
            right_coeffs: dict[tuple[int, ...], float],
425
            feature_idx: int,
426
        ) -> dict[tuple[int, ...], float]:
427
            """Combines Fourier coefficients from the left and right children of a split node."""
428
            combined_coeffs = {}
1✔
429
            all_interactions = set(left_coeffs.keys()) | set(right_coeffs.keys())
1✔
430

431
            for interaction in all_interactions:
1✔
432
                left_val = left_coeffs.get(interaction, 0.0)
1✔
433
                right_val = right_coeffs.get(interaction, 0.0)
1✔
434
                combined_coeffs[interaction] = (left_val + right_val) / 2
1✔
435

436
                new_interaction = tuple(sorted(set(interaction) | {feature_idx}))
1✔
437
                combined_coeffs[new_interaction] = (left_val - right_val) / 2
1✔
438
            return combined_coeffs
1✔
439

440
        def _dfs_traverse(node: dict[str, Any]) -> dict[tuple[int, ...], float]:
1✔
441
            """Performs a depth-first traversal of the tree to compute coefficients."""
442
            # Base case: if the node is a leaf, its function is a constant.
443
            if "leaf_value" in node:
1✔
444
                # The only non-zero coefficient is for the empty interaction (the bias term).
445
                return {(): node["leaf_value"]}
1✔
446
            # Recursive step: if the node is a split node.
447
            left_coeffs = _dfs_traverse(node["left_child"])
1✔
448
            right_coeffs = _dfs_traverse(node["right_child"])
1✔
449
            feature_idx = node["split_feature"]
1✔
450
            return _combine_coeffs(left_coeffs, right_coeffs, feature_idx)
1✔
451

452
        return _dfs_traverse(tree_info["tree_structure"])
1✔
453

454
    def _refine(
1✔
455
        self,
456
        four_dict: dict[tuple[int, ...], float],
457
        train_X: np.ndarray,
458
        train_y: np.ndarray,
459
    ) -> dict[tuple[int, ...], float]:
460
        """Refines the estimated Fourier coefficients using a Ridge regression model.
461

462
        This method takes an initial set of estimated Fourier coefficients and refines them to
463
        better fit the observed game values. It first identifies the most significant
464
        coefficients by keeping those that contribute to 95% of the total "energy" (sum of
465
        squared Fourier coefficients, excluding the baseline). Then, it constructs a new feature matrix
466
        based on the Fourier basis functions corresponding to these significant interactions.
467
        Finally, it fits a `RidgeCV` model to re-estimate the values of these coefficients,
468
        effectively fine-tuning them against the training data.
469

470
        Args:
471
            four_dict: A dictionary mapping interaction tuples to their initial estimated
472
                Fourier coefficient values.
473
            train_X: The training data matrix where rows are coalitions (binary vectors) and
474
                columns are players.
475
            train_y: The corresponding game values for each coalition in `train_X`.
476

477
        Returns:
478
            A dictionary containing the refined Fourier coefficients for the most significant
479
            interactions.
480
        """
481
        n = train_X.shape[1]
1✔
482
        four_items = list(four_dict.items())
1✔
483
        list_keys = [item[0] for item in four_items]
1✔
484
        four_coefs = np.array([item[1] for item in four_items])
1✔
485

486
        nfc_idx = list_keys.index(()) if () in list_keys else None
1✔
487

488
        four_coefs_for_energy = np.copy(four_coefs)
1✔
489
        if nfc_idx is not None:
1✔
490
            four_coefs_for_energy[nfc_idx] = 0
1✔
491
        four_coefs_sq = four_coefs_for_energy**2
1✔
492
        tot_energy = np.sum(four_coefs_sq)
1✔
493
        sorted_four_coefs_sq = np.sort(four_coefs_sq)[::-1]
1✔
494
        cumulative_energy_ratio = np.cumsum(sorted_four_coefs_sq / tot_energy)
1✔
495
        thresh_idx_95 = np.argmin(cumulative_energy_ratio < 0.95) + 1
1✔
496
        thresh = np.sqrt(sorted_four_coefs_sq[thresh_idx_95])
1✔
497

498
        four_dict_trunc = {
1✔
499
            tuple(int(i in k) for i in range(n)): v for k, v in four_dict.items() if abs(v) > thresh
500
        }
501
        support = np.array(list(four_dict_trunc.keys()))
1✔
502

503
        X = np.real(np.exp(train_X @ (1j * np.pi * support.T)))
1✔
504
        reg = RidgeCV(alphas=np.logspace(-6, 6, 100), fit_intercept=False).fit(X, train_y)
1✔
505

506
        regression_coefs = dict(
1✔
507
            zip([tuple(s.astype(int)) for s in support], reg.coef_, strict=False)
508
        )
509
        return {tuple(i for i, x in enumerate(k) if x): v for k, v in regression_coefs.items()}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc