• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bethgelab / foolbox / 8139224398

04 Mar 2024 11:11AM UTC coverage: 98.477%. Remained the same
8139224398

push

github

web-flow
Fix issue in guide compilation (#723)

* Fix terser error

3492 of 3546 relevant lines covered (98.48%)

7.22 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.95
/foolbox/attacks/pointwise.py
1
from typing import Union, Any, Optional, Tuple, Callable, List
10✔
2
import eagerpy as ep
10✔
3
import numpy as np
10✔
4
import logging
10✔
5

6
from ..criteria import Criterion
10✔
7

8
from .base import FlexibleDistanceMinimizationAttack
10✔
9
from .saltandpepper import SaltAndPepperNoiseAttack
10✔
10

11
from ..devutils import flatten
10✔
12
from .base import Model
10✔
13
from .base import MinimizationAttack
10✔
14
from .base import get_is_adversarial
10✔
15
from .base import get_criterion
10✔
16
from .base import T
10✔
17
from .base import raise_if_kwargs
10✔
18
from .base import verify_input_bounds
10✔
19

20

21
class PointwiseAttack(FlexibleDistanceMinimizationAttack):
    """Starts with an adversarial and performs a binary search between
    the adversarial and the original for each dimension of the input
    individually. [#Sch18]_

    Args:
        init_attack: Attack used to find initial adversarials when :meth:`run`
            is called without ``starting_points``. If ``None``, a
            :class:`SaltAndPepperNoiseAttack` is used.
        l2_binary_search: If True, once no further dimension can be reset to
            its original value, each remaining perturbed dimension is
            additionally binary-searched towards its original value.

    References:
        .. [#Sch18] Lukas Schott, Jonas Rauber, Matthias Bethge, Wieland Brendel,
               "Towards the first adversarially robust neural network model on MNIST",
               https://arxiv.org/abs/1805.09190
    """

    def __init__(
        self,
        init_attack: Optional[MinimizationAttack] = None,
        l2_binary_search: bool = True,
    ):
        self.init_attack = init_attack
        self.l2_binary_search = l2_binary_search

    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, Any] = None,
        *,
        starting_points: Optional[ep.Tensor] = None,
        early_stop: Optional[float] = None,
        **kwargs: Any,
    ) -> T:
        """Move adversarial starting points as close as possible to ``inputs``.

        Args:
            model: The model to attack.
            inputs: Clean inputs; the first axis is treated as the batch axis.
            criterion: Criterion (or any value accepted by ``get_criterion``)
                deciding which perturbed inputs count as adversarial.
            starting_points: Adversarial inputs to start from. If ``None``,
                they are produced by ``self.init_attack`` (or by a
                :class:`SaltAndPepperNoiseAttack` if that is ``None``).
            early_stop: Accepted for interface compatibility; not used by
                this attack.
            **kwargs: Must be empty; unexpected keyword arguments are rejected
                via ``raise_if_kwargs``.

        Returns:
            Adversarial inputs with the same shape and tensor type as
            ``inputs``.

        Raises:
            AssertionError: If any starting point is not adversarial.
        """
        raise_if_kwargs(kwargs)
        del kwargs

        x, restore_type = ep.astensor_(inputs)
        del inputs

        verify_input_bounds(x, model)

        criterion_ = get_criterion(criterion)
        del criterion
        is_adversarial = get_is_adversarial(criterion_, model)

        if starting_points is None:
            init_attack: MinimizationAttack
            if self.init_attack is None:
                init_attack = SaltAndPepperNoiseAttack()
                logging.info(
                    f"Neither starting_points nor init_attack given. Falling"
                    f" back to {init_attack!r} for initialization."
                )
            else:
                init_attack = self.init_attack
            # TODO: use call and support all types of attacks (once early_stop is
            # possible in __call__)
            starting_points = init_attack.run(model, x, criterion_)

        x_adv = ep.astensor(starting_points)
        # Explicit check instead of `assert` so it is not stripped under
        # `python -O`; AssertionError is kept for backward compatibility.
        if not is_adversarial(x_adv).all():
            raise AssertionError(
                "PointwiseAttack requires all starting points (given or produced"
                " by init_attack) to be adversarial"
            )

        original_shape = x.shape
        N = len(x)

        x_flat = flatten(x)
        x_adv_flat = flatten(x_adv)

        # Phase 1: sweep over every still-perturbed dimension in random order
        # and reset it to its clean value whenever the sample stays
        # adversarial. Repeat until a full sweep resets nothing for any sample.
        # `found_index_to_manipulate[n]` records whether sample n had at least
        # one successful reset during the current sweep.
        found_index_to_manipulate = ep.from_numpy(x, np.ones(N, dtype=bool))
        while ep.any(found_index_to_manipulate):
            untouched_indices = self._shuffled_differing_indices(
                x_flat, x_adv_flat, 1e-8
            )
            # the candidate lists do not change during a sweep
            n_steps = max((len(it) for it in untouched_indices), default=0)

            found_index_to_manipulate = ep.from_numpy(x, np.zeros(N, dtype=bool))

            # Samples can have different numbers of perturbed dimensions, so
            # at step i only samples with more than i candidates are updated;
            # samples that are already finished are masked out.
            for i in range(n_steps):
                relevant_mask, relevant_mask_index, relevant_indices = (
                    self._select_step(untouched_indices, i)
                )

                old_values = x_adv_flat[relevant_mask_index, relevant_indices]
                new_values = x_flat[relevant_mask_index, relevant_indices]
                x_adv_flat = ep.index_update(
                    x_adv_flat, (relevant_mask_index, relevant_indices), new_values
                )

                # check if still adversarial
                is_adv = is_adversarial(x_adv_flat.reshape(original_shape))
                found_index_to_manipulate = ep.index_update(
                    found_index_to_manipulate,
                    relevant_mask_index,
                    ep.logical_or(found_index_to_manipulate, is_adv)[relevant_mask],
                )

                # if not, undo the reset for that sample
                x_adv_flat = ep.index_update(
                    x_adv_flat,
                    (relevant_mask_index, relevant_indices),
                    ep.where(is_adv[relevant_mask], new_values, old_values),
                )

        if self.l2_binary_search:
            # Phase 2: for every dimension that still differs, try the clean
            # value first; if that breaks adversariality, binary-search between
            # the current value and the clean value. Repeat full sweeps until
            # no sample improves.
            while True:
                untouched_indices = self._shuffled_differing_indices(
                    x_flat, x_adv_flat, 1e-12
                )
                n_steps = max((len(it) for it in untouched_indices), default=0)

                # whether this sweep through all values improved each sample
                improved = ep.from_numpy(x, np.zeros(N, dtype=bool))

                logging.info("Starting new loop through all values")

                for i in range(n_steps):
                    relevant_mask, relevant_mask_index, relevant_indices = (
                        self._select_step(untouched_indices, i)
                    )

                    old_values = x_adv_flat[relevant_mask_index, relevant_indices]
                    new_values = x_flat[relevant_mask_index, relevant_indices]

                    x_adv_flat = ep.index_update(
                        x_adv_flat, (relevant_mask_index, relevant_indices), new_values
                    )

                    # check if still adversarial
                    is_adv = is_adversarial(x_adv_flat.reshape(original_shape))

                    improved = ep.index_update(
                        improved,
                        relevant_mask_index,
                        ep.logical_or(improved, is_adv)[relevant_mask],
                    )

                    if not ep.all(is_adv):
                        # binary search for samples that became non-adversarial;
                        # results for samples that stayed adversarial are
                        # discarded by the `where` below
                        updated_new_values = self._binary_search(
                            x_adv_flat,
                            relevant_mask,
                            relevant_mask_index,
                            relevant_indices,
                            old_values,
                            new_values,
                            (-1, *original_shape[1:]),
                            is_adversarial,
                        )
                        x_adv_flat = ep.index_update(
                            x_adv_flat,
                            (relevant_mask_index, relevant_indices),
                            ep.where(
                                is_adv[relevant_mask], new_values, updated_new_values
                            ),
                        )

                        # a moved value counts as an improvement
                        improved = ep.index_update(
                            improved,
                            relevant_mask_index,
                            ep.logical_or(
                                old_values != updated_new_values,
                                improved[relevant_mask],
                            ),
                        )

                if not ep.any(improved):
                    # no improvement for any of the indices
                    break

        x_adv = x_adv_flat.reshape(original_shape)

        return restore_type(x_adv)

    @staticmethod
    def _shuffled_differing_indices(
        x_flat: ep.Tensor, x_adv_flat: ep.Tensor, threshold: float
    ) -> List[List[int]]:
        """Return, for each sample, the flat indices at which ``x_adv_flat``
        differs from ``x_flat`` by more than ``threshold``, in random order
        (drawn with ``np.random.permutation``)."""
        diff_mask = (ep.abs(x_flat - x_adv_flat) > threshold).numpy()
        return [np.random.permutation(row.nonzero()[0]).tolist() for row in diff_mask]

    @staticmethod
    def _select_step(
        untouched_indices: List[List[int]], i: int
    ) -> Tuple[
        np.ndarray[Any, np.dtype[np.bool_]],
        np.ndarray[Any, np.dtype[np.int_]],
        List[int],
    ]:
        """For sweep step ``i``, return the batch mask of samples that still
        have an ``i``-th candidate, the batch positions of those samples, and
        the flat index each of them tries next."""
        relevant_mask: np.ndarray[Any, np.dtype[np.bool_]] = np.array(
            [len(it) > i for it in untouched_indices], dtype=bool
        )
        relevant_mask_index: np.ndarray[Any, np.dtype[np.int_]] = np.flatnonzero(
            relevant_mask
        )
        relevant_indices = [it[i] for it in untouched_indices if len(it) > i]
        return relevant_mask, relevant_mask_index, relevant_indices

    def _binary_search(
        self,
        x_adv_flat: ep.Tensor,
        mask: Union[ep.Tensor, List[bool], np.ndarray[Any, np.dtype[np.bool_]]],
        mask_indices: Union[ep.Tensor, np.ndarray[Any, np.dtype[np.int_]]],
        indices: Union[ep.Tensor, List[int]],
        adv_values: ep.Tensor,
        non_adv_values: ep.Tensor,
        original_shape: Tuple,
        is_adversarial: Callable,
        steps: int = 10,
    ) -> ep.Tensor:
        """Bisect one dimension per selected sample between two endpoint values.

        Args:
            x_adv_flat: Flattened batch; only entries ``(mask_indices,
                indices)`` are varied during the search.
            mask: Boolean mask over the batch selecting the searched samples.
            mask_indices: Batch positions of the searched samples.
            indices: Flat dimension index searched for each selected sample.
            adv_values: Endpoint treated as adversarial (the search never
                returns anything that did not test adversarial).
            non_adv_values: Endpoint treated as non-adversarial.
            original_shape: Shape ``x_adv_flat`` is reshaped to before calling
                ``is_adversarial``.
            is_adversarial: Callable returning a boolean tensor over the batch.
            steps: Number of bisection steps.

        Returns:
            For each selected sample, the midpoint closest to
            ``non_adv_values`` that tested adversarial, or the original
            ``adv_values`` if no midpoint did.
        """
        for _ in range(steps):
            next_values = (adv_values + non_adv_values) / 2
            x_adv_flat = ep.index_update(
                x_adv_flat, (mask_indices, indices), next_values
            )
            is_adv = is_adversarial(x_adv_flat.reshape(original_shape))[mask]

            # shrink the interval towards whichever side the midpoint falls on
            adv_values = ep.where(is_adv, next_values, adv_values)
            non_adv_values = ep.where(is_adv, non_adv_values, next_values)

        return adv_values
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc