• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bethgelab / foolbox / 8139141456

04 Mar 2024 11:03AM UTC coverage: 37.923% (-60.6%) from 98.477%
8139141456

Pull #722

github

web-flow
Merge 5663238db into 17e0e9b31
Pull Request #722: Fix guide compilation

1344 of 3544 relevant lines covered (37.92%)

0.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

34.78
/foolbox/attacks/spatial_attack.py
1
from typing import Union, Any, Tuple, Generator
1✔
2
import eagerpy as ep
1✔
3
import numpy as np
1✔
4

5
from ..devutils import atleast_kd
1✔
6

7
from ..criteria import Criterion
1✔
8

9
from .base import Model
1✔
10
from .base import T
1✔
11
from .base import get_is_adversarial
1✔
12
from .base import get_criterion
1✔
13
from .base import Attack
1✔
14
from .spatial_attack_transformations import rotate_and_shift
1✔
15
from .base import raise_if_kwargs
1✔
16
from .base import verify_input_bounds
1✔
17

18

19
class SpatialAttack(Attack):
    """Adversarially chosen rotations and translations. [#Engs]
    This implementation is based on the reference implementation by
    Madry et al.: https://github.com/MadryLab/adversarial_spatial

    Every input is searched over (rotation, x-shift, y-shift) triples, either
    on a regular grid (``grid_search=True``) or by uniform random sampling
    (``grid_search=False``). For each input, the first transformed version
    the criterion reports as adversarial is kept. Inputs for which no
    candidate succeeds are returned unchanged.

    Args:
        max_translation: Largest absolute shift tried along each translation
            axis; candidates lie in ``[-max_translation, max_translation]``
            (units as expected by ``rotate_and_shift``, presumably pixels —
            confirm).
        max_rotation: Largest absolute rotation tried; candidates lie in
            ``[-max_rotation, max_rotation]`` (units as expected by
            ``rotate_and_shift``, presumably degrees — confirm).
        num_translations: Number of evenly spaced values per translation
            axis used by the grid search.
        num_rotations: Number of evenly spaced rotation values used by the
            grid search.
        grid_search: If True, try every combination of the
            ``num_rotations * num_translations ** 2`` grid values; otherwise
            draw ``random_steps`` random triples.
        random_steps: Number of random triples drawn when ``grid_search`` is
            False.

    References:
    .. [#Engs] Logan Engstrom*, Brandon Tran*, Dimitris Tsipras*,
           Ludwig Schmidt, Aleksander Mądry: "A Rotation and a
           Translation Suffice: Fooling CNNs with Simple Transformations",
           http://arxiv.org/abs/1712.02779
    """

    def __init__(
        self,
        max_translation: float = 3,
        max_rotation: float = 30,
        num_translations: int = 5,
        num_rotations: int = 5,
        grid_search: bool = True,
        random_steps: int = 100,
    ):

        self.max_trans = max_translation
        self.max_rot = max_rotation

        self.grid_search = grid_search

        # only used when grid_search is True
        self.num_trans = num_translations
        self.num_rots = num_rotations

        # only used when grid_search is False
        self.random_steps = random_steps

    def __call__(  # type: ignore
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        **kwargs: Any,
    ) -> Tuple[T, T, T]:
        """Run the attack and report which inputs became adversarial.

        Args:
            model: The model to attack.
            inputs: A batch of images; must be a 4-D tensor (one batch, one
                channel and two spatial dimensions).
            criterion: A ``Criterion`` or a tensor, converted via
                ``get_criterion``.
            **kwargs: Accepted for API compatibility; ignored here.

        Returns:
            ``(raw, clipped, success)``. ``raw`` and ``clipped`` are the same
            transformed inputs, returned twice to match the attack API.
            ``success`` is the adversarial check evaluated on them.

        Raises:
            NotImplementedError: If ``inputs`` is not 4-dimensional.
        """
        x, restore_type = ep.astensor_(inputs)
        del inputs
        criterion = get_criterion(criterion)

        is_adversarial = get_is_adversarial(criterion, model)

        if x.ndim != 4:
            raise NotImplementedError(
                "only implemented for inputs with two spatial dimensions (and one channel and one batch dimension)"
            )

        xp = self.run(model, x, criterion)
        success = is_adversarial(xp)

        xp_ = restore_type(xp)
        return xp_, xp_, restore_type(success)  # twice to match API

    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, T],
        **kwargs: Any,
    ) -> T:
        """Search for an adversarial rotation/translation of every input.

        Candidates are tried in order. For each input, the first candidate
        the criterion marks as adversarial is kept. Inputs that are already
        adversarial, or for which no candidate succeeds, are returned
        unchanged. The search stops early once every input is adversarial.

        Args:
            model: The model to attack.
            inputs: The batch of inputs to transform.
            criterion: A ``Criterion`` or a tensor, converted via
                ``get_criterion``.
            **kwargs: Passed to ``raise_if_kwargs``; expected to be empty.

        Returns:
            The (possibly) transformed inputs, restored to the tensor type of
            ``inputs``.
        """
        raise_if_kwargs(kwargs)

        x, restore_type = ep.astensor_(inputs)
        del inputs, kwargs

        verify_input_bounds(x, model)

        criterion = get_criterion(criterion)
        is_adversarial = get_is_adversarial(criterion, model)

        # Inputs that are already adversarial count as found and stay untouched.
        found = is_adversarial(x)
        results = x

        candidates = (
            self._grid_search_params()
            if self.grid_search
            else self._random_search_params()
        )
        for dphi, dx, dy in candidates:
            # TODO: reduce the batch size to the ones that haven't been successful

            x_p = rotate_and_shift(x, translation=(dx, dy), rotation=dphi)
            is_adv = is_adversarial(x_p)
            # Only overwrite inputs that become adversarial for the first time,
            # so each input keeps its earliest successful transformation.
            new_adv = ep.logical_and(is_adv, found.logical_not())

            results = ep.where(atleast_kd(new_adv, x_p.ndim), x_p, results)
            found = ep.logical_or(new_adv, found)
            if found.all():
                break  # all images in batch are adversarial
        return restore_type(results)

    def _grid_search_params(
        self,
    ) -> Generator[Tuple[float, float, float], None, None]:
        """Yield every (rotation, dx, dy) triple of the evenly spaced grid."""
        dphis = np.linspace(-self.max_rot, self.max_rot, self.num_rots)
        dxs = np.linspace(-self.max_trans, self.max_trans, self.num_trans)
        dys = np.linspace(-self.max_trans, self.max_trans, self.num_trans)
        for dphi in dphis:
            for dx in dxs:
                for dy in dys:
                    yield dphi, dx, dy

    def _random_search_params(
        self,
    ) -> Generator[Tuple[float, float, float], None, None]:
        """Yield ``random_steps`` triples drawn uniformly from the ranges."""
        # Draws from NumPy's global random state (np.random), so results are
        # only reproducible if that state is seeded by the caller.
        dphis = np.random.uniform(-self.max_rot, self.max_rot, self.random_steps)
        dxs = np.random.uniform(-self.max_trans, self.max_trans, self.random_steps)
        dys = np.random.uniform(-self.max_trans, self.max_trans, self.random_steps)
        yield from zip(dphis, dxs, dys)

    def repeat(self, times: int) -> Attack:
        """Return a random-search copy that draws ``times`` as many samples.

        Args:
            times: Multiplier for ``random_steps``; must be at least 1.

        Returns:
            A new ``SpatialAttack`` with the same ranges and
            ``random_steps * times`` random steps.

        Raises:
            ValueError: If grid search is enabled (the grid is deterministic,
                so repeating it cannot find anything new), or if
                ``times`` < 1.
        """
        if self.grid_search:
            raise ValueError(
                "repeat is not supported if attack is deterministic"
            )  # attack is deterministic
        if times < 1:
            # times == 0 would silently yield an attack that tries no
            # candidates, and a negative count would only fail later inside
            # np.random.uniform with an unrelated error.
            raise ValueError(f"expected times >= 1, got {times}")
        random_steps = self.random_steps * times
        return SpatialAttack(
            max_translation=self.max_trans,
            max_rotation=self.max_rot,
            num_translations=self.num_trans,
            num_rotations=self.num_rots,
            grid_search=self.grid_search,
            random_steps=random_steps,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc