• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bethgelab / foolbox / 8139141456

04 Mar 2024 11:03AM UTC coverage: 37.923% (-60.6%) from 98.477%
8139141456

Pull #722

github

web-flow
Merge 5663238db into 17e0e9b31
Pull Request #722: Fix guide compilation

1344 of 3544 relevant lines covered (37.92%)

0.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

26.23
/foolbox/attacks/binarization.py
1
from typing import Union, Optional, Any
1✔
2
from typing_extensions import Literal
1✔
3
import eagerpy as ep
1✔
4
import numpy as np
1✔
5

6
from ..models import Model
1✔
7

8
from ..criteria import Criterion
1✔
9

10
from ..distances import Distance
1✔
11

12
from .base import FlexibleDistanceMinimizationAttack
1✔
13
from .base import T
1✔
14
from .base import get_is_adversarial
1✔
15
from .base import get_criterion
1✔
16
from .base import raise_if_kwargs
1✔
17
from .base import verify_input_bounds
1✔
18

19

20
class BinarizationRefinementAttack(FlexibleDistanceMinimizationAttack):
    """For models that preprocess their inputs by binarizing the
    inputs, this attack can improve adversarials found by other
    attacks. It does this by utilizing information about the
    binarization and mapping values to the corresponding value in
    the clean input or to the right side of the threshold.

    Args:
        distance : Distance measure, forwarded to the base attack.
        threshold : The threshold used by the model's binarization. If None,
            defaults to the midpoint of the model bounds,
            ``(model.bounds[0] + model.bounds[1]) / 2``.
        included_in : Whether the threshold value itself belongs to the lower or
            upper interval.
    """

    def __init__(
        self,
        *,
        distance: Optional[Distance] = None,
        threshold: Optional[float] = None,
        included_in: Union[Literal["lower"], Literal["upper"]] = "upper",
    ):
        super().__init__(distance=distance)
        self.threshold = threshold
        self.included_in = included_in

    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, T],
        *,
        early_stop: Optional[float] = None,
        starting_points: Optional[T] = None,
        **kwargs: Any,
    ) -> T:
        """Refine ``starting_points`` using knowledge of the binarization.

        Each feature of the result is either the clean value (when clean and
        starting point fall on the same side of the threshold) or the value
        closest to the threshold on the starting point's side.

        Args:
            model: The attacked model; its ``bounds`` are used to derive the
                threshold when none was given at construction time.
            inputs: The clean inputs.
            criterion: A Criterion, or a tensor that ``get_criterion`` turns
                into one.
            early_stop: Accepted for interface compatibility; not used here.
            starting_points: Adversarial examples to refine (required).

        Returns:
            The refined adversarials, converted back with the ``restore_type``
            obtained from ``ep.astensors_``.

        Raises:
            ValueError: If ``starting_points`` is missing, ``inputs`` and
                ``starting_points`` have different dtypes, the dtype is not
                float16/32/64, ``included_in`` is not 'lower' or 'upper', or
                the refined points differ from ``starting_points`` in
                adversariality (the threshold does not match the model).
        """
        raise_if_kwargs(kwargs)
        if starting_points is None:
            raise ValueError("BinarizationRefinementAttack requires starting_points")
        (o, x), restore_type = ep.astensors_(inputs, starting_points)
        del inputs, starting_points, kwargs

        verify_input_bounds(x, model)

        criterion = get_criterion(criterion)
        is_adversarial = get_is_adversarial(criterion, model)

        if self.threshold is None:
            # Default: midpoint of the model's input range.
            min_, max_ = model.bounds
            threshold = (min_ + max_) / 2.0
        else:
            threshold = self.threshold

        # Explicit check instead of `assert`, which `python -O` strips.
        if o.dtype != x.dtype:
            raise ValueError(
                "expected inputs and starting_points to have the same dtype,"
                f" found '{o.dtype}' and '{x.dtype}'"
            )

        # A one-element slice (possibly empty) still carries the dtype, so an
        # empty batch does not raise IndexError here as indexing [0] would.
        nptype = o.reshape(-1)[:1].numpy().dtype.type
        if nptype not in [np.float16, np.float32, np.float64]:
            raise ValueError(  # pragma: no cover
                f"expected dtype to be float16, float32 or float64, found '{nptype}'"
            )

        threshold = nptype(threshold)
        offset = nptype(1.0)

        # lower_ is the largest value on the lower side of the threshold and
        # upper_ the smallest on the upper side; they are adjacent floats in
        # the input dtype, with the threshold itself on the side chosen by
        # `included_in`.
        if self.included_in == "lower":
            lower_ = threshold
            upper_ = np.nextafter(threshold, threshold + offset)
        elif self.included_in == "upper":
            lower_ = np.nextafter(threshold, threshold - offset)
            upper_ = threshold
        else:
            raise ValueError(
                f"expected included_in to be 'lower' or 'upper', found '{self.included_in}'"
            )

        assert lower_ < upper_

        # NaN marks entries not yet assigned; because lower_ and upper_ are
        # adjacent floats, every non-NaN entry falls into one of the four
        # cases below.
        p = ep.full_like(o, ep.nan)

        lower = ep.ones_like(o) * lower_
        upper = ep.ones_like(o) * upper_

        # Clean and starting point both on the lower side: keep clean value.
        indices = ep.logical_and(o <= lower, x <= lower)
        p = ep.where(indices, o, p)

        # Clean low, starting point high: smallest value on the upper side.
        indices = ep.logical_and(o <= lower, x >= upper)
        p = ep.where(indices, upper, p)

        # Clean high, starting point low: largest value on the lower side.
        indices = ep.logical_and(o >= upper, x <= lower)
        p = ep.where(indices, lower, p)

        # Clean and starting point both on the upper side: keep clean value.
        indices = ep.logical_and(o >= upper, x >= upper)
        p = ep.where(indices, o, p)

        assert not ep.any(ep.isnan(p))

        # If refinement changed any sample's adversarial status, the assumed
        # threshold does not reflect the model's actual binarization.
        is_adv1 = is_adversarial(x)
        is_adv2 = is_adversarial(p)
        if (is_adv1 != is_adv2).any():
            raise ValueError(
                "The specified threshold does not match what is done by the model."
            )
        return restore_type(p)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc