• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 17613931038

10 Sep 2025 12:36PM UTC coverage: 93.646% (-0.2%) from 93.845%
17613931038

Pull #431

github

web-flow
Merge 0344967e6 into dede390c9
Pull Request #431: Product kernel explainer

180 of 203 new or added lines in 14 files covered. (88.67%)

4 existing lines in 2 files now uncovered.

5099 of 5445 relevant lines covered (93.65%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.18
/src/shapiq/explainer/product_kernel/explainer.py
1
"""Implementation of the ProductKernelExplainer class."""
2

3
from __future__ import annotations
1✔
4

5
from typing import TYPE_CHECKING, Any
1✔
6

7
from shapiq import InteractionValues
1✔
8
from shapiq.explainer.base import Explainer
1✔
9
from shapiq.game_theory import get_computation_index
1✔
10

11
from .product_kernel import ProductKernelComputer, ProductKernelSHAPIQIndices
1✔
12
from .validation import validate_pk_model
1✔
13

14
if TYPE_CHECKING:
1✔
NEW
15
    import numpy as np
×
16

NEW
17
    from shapiq.utils.custom_types import Model
×
18

NEW
19
    from .base import ProductKernelModel
×
20

21

22
class ProductKernelExplainer(Explainer):
    """Explainer for models that are built on product kernels.

    The explainer accepts several kinds of product kernel-based models and covers both
    regression and classification models.

    References:
        -- [pkex-shapley] Majid Mohammadi, Siu Lun Chau, and Krikamol Muandet. (2025). Computing Exact Shapley Values in Polynomial Time for Product-Kernel Methods. https://arxiv.org/abs/2505.16516

    Attributes:
        model: The product kernel model to explain. Can be a dictionary, a ProductKernelModel, or a list of ProductKernelModels.
        max_order: The maximum interaction order to be computed. Defaults to ``1``.
        min_order: The minimum interaction order to be computed. Defaults to ``0``.
        index: The type of interaction to be computed. Currently, only ``"SV"`` is supported.
        class_index: The class index of the model to explain. Defaults to ``None``, which will set the class index to ``1`` per default for classification models and is ignored for regression models.
        converted_model: The result of passing ``model`` through ``validate_pk_model``.
        explainer: The ``ProductKernelComputer`` created for ``converted_model``.
        empty_prediction: The baseline value obtained from ``_compute_baseline_value``.
    """

    def __init__(
        self,
        model: dict
        | ProductKernelModel
        | list[ProductKernelModel]
        | Model,  # TODO (IsaH57): check if list of models is needed (Issue #425)
        *,
        min_order: int = 0,
        max_order: int = 1,
        index: ProductKernelSHAPIQIndices = "SV",
        class_index: int | None = None,
        **kwargs: Any,  # noqa: ARG002
    ) -> None:
        """Set up the explainer for a product kernel-based model.

        Args:
            model: The product kernel-based model that should be explained.

            min_order: Lowest interaction order to compute. Defaults to ``0``.

            max_order: Highest interaction order to compute. Order ``1`` corresponds to the
                Shapley value. Defaults to ``1``.

            index: The interaction index to compute. Currently, only ``"SV"`` is supported.

            class_index: Class of the model to explain. With the default ``None`` the class
                index is set to ``1`` for classification models; it is ignored for regression
                models.

            **kwargs: Accepted for compatibility and ignored.

        Raises:
            ValueError: If ``max_order`` is greater than ``1``.
        """
        # Reject unsupported orders before any base-class or model set-up happens.
        if max_order > 1:
            msg = "ProductKernelExplainer currently only supports max_order=1."
            raise ValueError(msg)

        super().__init__(model, index=index, max_order=max_order)

        # Interaction order bounds.
        self._min_order: int = min_order
        self._max_order: int = max_order

        # Requested index and the index that is actually used for the computation.
        self._index: ProductKernelSHAPIQIndices = index
        self._base_index: str = get_computation_index(self._index)

        self._class_label: int | None = class_index

        # Validate the model; the validated result is what all later computations use.
        self.converted_model = validate_pk_model(model, class_label=class_index)

        self.explainer = ProductKernelComputer(
            model=self.converted_model,
            max_order=max_order,
            index=index,
        )

        # Needs ``converted_model`` and therefore has to come last.
        self.empty_prediction = self._compute_baseline_value()

    def explain_function(
        self,
        x: np.ndarray,
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Explain a single instance with one Shapley value per feature.

        Args:
           x: The instance (1D array) to explain.
           **kwargs: Accepted for compatibility and ignored.

        Returns:
           The interaction values of the instance, holding one entry ``(j,)`` for every
           feature index ``j``.
        """
        pk_model = self.converted_model
        n_players = pk_model.d

        # Kernel vectors between the training data and the instance x.
        kernel_vectors = self.explainer.compute_kernel_vectors(pk_model.X_train, x)

        # One single-feature coalition per player, in increasing feature order.
        shapley_values = {
            (feature,): self.explainer.compute_shapley_value(kernel_vectors, feature)
            for feature in range(n_players)
        }

        return InteractionValues(
            values=shapley_values,
            index=self._base_index,
            min_order=self._min_order,
            max_order=self.max_order,
            n_players=n_players,
            estimated=False,
            baseline_value=self.empty_prediction,
            target_index=self._index,
        )

    def _compute_baseline_value(self) -> float:
        """Return the baseline value of the explainer.

        Returns:
            The sum of the converted model's ``alpha`` values plus its ``intercept``.

        """
        pk_model = self.converted_model
        return pk_model.alpha.sum() + pk_model.intercept
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc