• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mmschlk / shapiq / 18754124332

23 Oct 2025 03:48PM UTC coverage: 92.865% (-0.2%) from 93.032%
18754124332

Pull #431

github

web-flow
Merge e4a9a83cb into 830c6bc23
Pull Request #431: Product kernel explainer

186 of 210 new or added lines in 13 files covered. (88.57%)

2 existing lines in 1 file now uncovered.

5375 of 5788 relevant lines covered (92.86%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

71.43
/src/shapiq/explainer/product_kernel/conversion.py
1
"""Functions for converting scikit-learn models to a format used by shapiq."""
2

3
from __future__ import annotations
1✔
4

5
from typing import TYPE_CHECKING
1✔
6

7
import numpy as np
1✔
8

9
from shapiq.explainer.product_kernel.base import ProductKernelModel
1✔
10

11
if TYPE_CHECKING:
1✔
NEW
12
    from sklearn.gaussian_process import GaussianProcessRegressor
×
NEW
13
    from sklearn.svm import SVC, SVR
×
14

15

16
def convert_svm(model: SVC | SVR) -> ProductKernelModel:
1✔
17
    """Converts a scikit-learn SVM model to the product kernel format used by shapiq.
18

19
    Args:
20
        model: The scikit-learn SVM model to convert. Can be either a binary support vector classifier (SVC) or a support vector regressor (SVR).
21

22
    Returns:
23
        ProductKernelModel: The converted model in the product kernel format.
24

25
    """
26
    X_train = model.support_vectors_
1✔
27
    n, d = X_train.shape
1✔
28

29
    if hasattr(model, "kernel"):
1✔
30
        kernel_type = model.kernel  # pyright: ignore[reportAttributeAccessIssue]
1✔
31
        if kernel_type != "rbf":
1✔
32
            msg = "Currently only RBF kernel is supported for SVM models."
1✔
33
            raise ValueError(msg)
1✔
34
    else:
NEW
35
        msg = "Kernel type not found in the model. Ensure the model is a valid SVC or SVR."
×
NEW
36
        raise ValueError(msg)
×
37

38
    return ProductKernelModel(
1✔
39
        alpha=model.dual_coef_.flatten(),  # pyright: ignore[reportAttributeAccessIssue]
40
        X_train=X_train,
41
        n=n,
42
        d=d,
43
        gamma=model._gamma,  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue] # noqa: SLF001
44
        kernel_type=kernel_type,
45
        intercept=model.intercept_[0],
46
    )
47

48

49
def convert_gp_reg(model: GaussianProcessRegressor) -> ProductKernelModel:
1✔
50
    """Converts a scikit-learn Gaussian Process Regression model to the product kernel format used by shapiq.
51

52
    Args:
53
        model: The scikit-learn Gaussian Process Regression model to convert.
54

55
    Returns:
56
        ProductKernelModel: The converted model in the product kernel format.
57

58
    """
59
    X_train = np.array(model.X_train_)
1✔
60
    n, d = X_train.shape
1✔
61

62
    if hasattr(model, "kernel"):
1✔
63
        kernel_type = model.kernel_.__class__.__name__.lower()  # Get the kernel type as a string
1✔
64
        if kernel_type != "rbf":
1✔
NEW
65
            msg = "Currently only RBF kernel is supported for Gaussian Process Regression models."
×
NEW
66
            raise ValueError(msg)
×
67
    else:
NEW
68
        msg = "Kernel type not found in the model. Ensure the model is a valid Gaussian Process Regressor."
×
NEW
69
        raise ValueError(msg)
×
70

71
    alphas = np.array(model.alpha_).flatten()
1✔
72
    parameters = (
1✔
73
        model.kernel_.get_params()  # pyright: ignore[reportAttributeAccessIssue]
74
    )
75
    if "length_scale" in parameters:
1✔
76
        length_scale = parameters["length_scale"]
1✔
77
    else:
NEW
78
        msg = "Length scale parameter not found in the kernel."
×
NEW
79
        raise ValueError(msg)
×
80

81
    return ProductKernelModel(
1✔
82
        alpha=alphas,
83
        X_train=X_train,
84
        n=n,
85
        d=d,
86
        gamma=(2 * (length_scale**2)) ** -1,
87
        kernel_type=kernel_type,
88
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc