bethgelab / foolbox, build 8139141456 (github, committer: web-flow)
04 Mar 2024 11:03AM UTC. Coverage: 37.923% (down 60.6 percentage points from 98.477%)
Pull Request #722: Fix guide compilation (merge 5663238db into 17e0e9b31)
1344 of 3544 relevant lines covered (37.92%), 0.38 hits per line
Source File: /foolbox/attacks/ead.py (25.42% covered)
from typing import Union, Tuple, Any, Optional
from typing_extensions import Literal

import math

import eagerpy as ep

from ..models import Model

from ..criteria import Misclassification, TargetedMisclassification

from ..distances import l1

from ..devutils import atleast_kd, flatten

from .base import MinimizationAttack
from .base import get_criterion
from .base import T
from .base import raise_if_kwargs
from .base import verify_input_bounds


class EADAttack(MinimizationAttack):
    """Implementation of the EAD Attack with EN Decision Rule. [#Chen18]_

    Args:
        binary_search_steps : Number of steps to perform in the binary search
            over the const c.
        steps : Number of optimization steps within each binary search step.
        initial_stepsize : Initial stepsize to update the examples.
        confidence : Confidence required for an example to be marked as adversarial.
            Controls the gap between example and decision boundary.
        initial_const : Initial value of the const c with which the binary search starts.
        regularization : Controls the L1 regularization.
        decision_rule : Rule according to which the best adversarial examples are selected.
            They either minimize the L1 or ElasticNet distance.
        abort_early : Stop inner search as soon as an adversarial example has been found.
            Does not affect the binary search over the const c.

    References:
        .. [#Chen18] Pin-Yu Chen, Yash Sharma, Huan Zhang, Jinfeng Yi, Cho-Jui Hsieh,
        "EAD: Elastic-Net Attacks to Deep Neural Networks via Adversarial Examples",
        https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/viewPaper/16893
    """

    distance = l1

    def __init__(
        self,
        binary_search_steps: int = 9,
        steps: int = 10000,
        initial_stepsize: float = 1e-2,
        confidence: float = 0.0,
        initial_const: float = 1e-3,
        regularization: float = 1e-2,
        decision_rule: Union[Literal["EN"], Literal["L1"]] = "EN",
        abort_early: bool = True,
    ):
        if decision_rule not in ("EN", "L1"):
            raise ValueError("invalid decision rule")

        self.binary_search_steps = binary_search_steps
        self.steps = steps
        self.confidence = confidence
        self.initial_stepsize = initial_stepsize
        self.regularization = regularization
        self.initial_const = initial_const
        self.abort_early = abort_early
        self.decision_rule = decision_rule

    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Misclassification, TargetedMisclassification, T],
        *,
        early_stop: Optional[float] = None,
        **kwargs: Any,
    ) -> T:
        raise_if_kwargs(kwargs)
        x, restore_type = ep.astensor_(inputs)
        criterion_ = get_criterion(criterion)
        del inputs, criterion, kwargs

        verify_input_bounds(x, model)

        N = len(x)

        if isinstance(criterion_, Misclassification):
            targeted = False
            classes = criterion_.labels
            change_classes_logits = self.confidence
        elif isinstance(criterion_, TargetedMisclassification):
            targeted = True
            classes = criterion_.target_classes
            change_classes_logits = -self.confidence
        else:
            raise ValueError("unsupported criterion")

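        # shift the logits of the label / target classes by +/- confidence so
        # that inputs only count as adversarial once the required margin is met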
        def is_adversarial(perturbed: ep.Tensor, logits: ep.Tensor) -> ep.Tensor:
            if change_classes_logits != 0:
                logits += ep.onehot_like(logits, classes, value=change_classes_logits)
            return criterion_(perturbed, logits)

        if classes.shape != (N,):
            name = "target_classes" if targeted else "labels"
            raise ValueError(
                f"expected {name} to have shape ({N},), got {classes.shape}"
            )

        min_, max_ = model.bounds
        rows = range(N)

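        # EAD minimizes c * hinge(logits) + ||y_k - x||_2^2 + beta * ||y_k - x||_1.
        # Only the smooth part of the objective is computed (and differentiated)
        # here; the non-smooth beta * L1 term is handled implicitly by the
        # shrinkage-thresholding projection applied after every gradient step.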
        def loss_fun(y_k: ep.Tensor, consts: ep.Tensor) -> Tuple[ep.Tensor, ep.Tensor]:
            assert y_k.shape == x.shape
            assert consts.shape == (N,)

            logits = model(y_k)

            if targeted:
                c_minimize = _best_other_classes(logits, classes)
                c_maximize = classes
            else:
                c_minimize = classes
                c_maximize = _best_other_classes(logits, classes)

            is_adv_loss = logits[rows, c_minimize] - logits[rows, c_maximize]
            assert is_adv_loss.shape == (N,)

            is_adv_loss = is_adv_loss + self.confidence
            is_adv_loss = ep.maximum(0, is_adv_loss)
            is_adv_loss = is_adv_loss * consts

            squared_norms = flatten(y_k - x).square().sum(axis=-1)
            loss = is_adv_loss.sum() + squared_norms.sum()
            return loss, logits

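        # differentiate loss_fun w.r.t. its first argument; with has_aux=True,
        # calls return (loss, logits, gradient)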
        loss_aux_and_grad = ep.value_and_grad_fn(x, loss_fun, has_aux=True)

        consts = self.initial_const * ep.ones(x, (N,))
        lower_bounds = ep.zeros(x, (N,))
        upper_bounds = ep.inf * ep.ones(x, (N,))

        best_advs = ep.zeros_like(x)
        best_advs_norms = ep.ones(x, (N,)) * ep.inf

        # the binary search searches for the smallest consts that produce adversarials
        for binary_search_step in range(self.binary_search_steps):
            if (
                binary_search_step == self.binary_search_steps - 1
                and self.binary_search_steps >= 10
            ):
                # in the last binary search step, repeat the search once
                consts = ep.minimum(upper_bounds, 1e10)

            # reset the search variables and find the delta that minimizes the loss
            x_k = x
            y_k = x

            found_advs = ep.full(
                x, (N,), value=False
            ).bool()  # found adv with the current consts
            loss_at_previous_check = ep.inf

            for iteration in range(self.steps):
                # square-root learning rate decay
                stepsize = self.initial_stepsize * (1.0 - iteration / self.steps) ** 0.5

                loss, logits, gradient = loss_aux_and_grad(y_k, consts)

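                # ISTA step: gradient descent on the smooth loss followed by the
                # shrinkage-thresholding projection (which accounts for the L1 term),
                # then FISTA momentum: y_{k+1} = x_{k+1} + k/(k+3) * (x_{k+1} - x_k)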
                x_k_old = x_k
                x_k = _project_shrinkage_thresholding(
                    y_k - stepsize * gradient, x, self.regularization, min_, max_
                )
                y_k = x_k + iteration / (iteration + 3.0) * (x_k - x_k_old)

                if self.abort_early and iteration % (math.ceil(self.steps / 10)) == 0:
                    # after each tenth of the iterations, check progress
                    if not loss.item() <= 0.9999 * loss_at_previous_check:
                        break  # stop optimization if there has been no progress
                    loss_at_previous_check = loss.item()

                found_advs_iter = is_adversarial(x_k, model(x_k))

                best_advs, best_advs_norms = _apply_decision_rule(
                    self.decision_rule,
                    self.regularization,
                    best_advs,
                    best_advs_norms,
                    x_k,
                    x,
                    found_advs_iter,
                )

                found_advs = ep.logical_or(found_advs, found_advs_iter)

            upper_bounds = ep.where(found_advs, consts, upper_bounds)
            lower_bounds = ep.where(found_advs, lower_bounds, consts)

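            # while no adversarial has been found for a const (upper bound still
            # infinite), grow it exponentially; otherwise bisect between the bounds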
            consts_exponential_search = consts * 10
            consts_binary_search = (lower_bounds + upper_bounds) / 2
            consts = ep.where(
                ep.isinf(upper_bounds), consts_exponential_search, consts_binary_search
            )

        return restore_type(best_advs)


def _best_other_classes(logits: ep.Tensor, exclude: ep.Tensor) -> ep.Tensor:
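    # mask out the excluded class by subtracting inf from its logit, then argmax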
    other_logits = logits - ep.onehot_like(logits, exclude, value=ep.inf)
    return other_logits.argmax(axis=-1)


def _apply_decision_rule(
    decision_rule: Union[Literal["EN"], Literal["L1"]],
    beta: float,
    best_advs: ep.Tensor,
    best_advs_norms: ep.Tensor,
    x_k: ep.Tensor,
    x: ep.Tensor,
    found_advs: ep.Tensor,
) -> Tuple[ep.Tensor, ep.Tensor]:
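    # elastic-net distance: beta * ||x_k - x||_1 + ||x_k - x||_2^2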
    if decision_rule == "EN":
        norms = beta * flatten(x_k - x).abs().sum(axis=-1) + flatten(
            x_k - x
        ).square().sum(axis=-1)
    else:
        # decision rule = L1
        norms = flatten(x_k - x).abs().sum(axis=-1)

    new_best = ep.logical_and(norms < best_advs_norms, found_advs)
    new_best_kd = atleast_kd(new_best, best_advs.ndim)
    best_advs = ep.where(new_best_kd, x_k, best_advs)
    best_advs_norms = ep.where(new_best, norms, best_advs_norms)

    return best_advs, best_advs_norms


def _project_shrinkage_thresholding(
    z: ep.Tensor, x0: ep.Tensor, regularization: float, min_: float, max_: float
) -> ep.Tensor:
    """Performs the element-wise projected shrinkage-thresholding operation."""

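    # element-wise ISTA operator around the original input x0, with beta given
    # by the `regularization` argument:
    #   z - beta  if z - x0 >  beta  (shrink the perturbation from above)
    #   z + beta  if z - x0 < -beta  (shrink the perturbation from below)
    #   x0        otherwise          (perturbations smaller than beta are zeroed)
    # the result is kept inside the model's input bounds [min_, max_]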
    upper_mask = z - x0 > regularization
    lower_mask = z - x0 < -regularization

    projection = ep.where(upper_mask, ep.minimum(z - regularization, max_), x0)
    projection = ep.where(lower_mask, ep.maximum(z + regularization, min_), projection)

    return projection
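
For reference, a minimal usage sketch for this attack through foolbox's public API. This is an illustration, not part of the file above: the toy linear `net`, the input shapes, and the reduced step counts are placeholder assumptions, while the `PyTorchModel` wrapper and the `attack(fmodel, images, labels, epsilons=...)` call convention follow foolbox's documented interface.

import torch
import foolbox as fb

# placeholder classifier (assumption: any eval-mode PyTorch model whose
# inputs live in [0, 1] works here)
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10)).eval()
fmodel = fb.PyTorchModel(net, bounds=(0, 1))

images = torch.rand(8, 3, 32, 32)    # batch of inputs in [0, 1]
labels = net(images).argmax(dim=-1)  # the model's own predictions as labels

# EAD is a minimization attack: it searches for the smallest elastic-net
# perturbation, so it can be called without a fixed epsilon (epsilons=None);
# the small step counts here are only to keep the sketch fast
attack = fb.attacks.EADAttack(binary_search_steps=3, steps=100, decision_rule="EN")
raw, clipped, is_adv = attack(fmodel, images, labels, epsilons=None)
print(is_adv.float().mean().item())  # fraction of inputs successfully attacked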