• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bethgelab / foolbox / 8139141456

04 Mar 2024 11:03AM UTC coverage: 37.923% (-60.6%) from 98.477%
8139141456

Pull #722

github

web-flow
Merge 5663238db into 17e0e9b31
Pull Request #722: Fix guide compilation

1344 of 3544 relevant lines covered (37.92%)

0.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

7.02
/foolbox/attacks/gen_attack_utils.py
1
from typing import Union, List, Tuple
1✔
2
import eagerpy as ep
1✔
3

4

5
def rescale_jax(x: ep.JAXTensor, target_shape: List[int]) -> ep.JAXTensor:
    """Bilinearly resize channel-last images stored in a JAX tensor.

    Args:
        x: Images to resize. Rows and columns are read from ``x.shape[1]``
            and ``x.shape[2]`` and the last axis is treated as channels
            (assumes a 4-D batch/rows/cols/channels layout -- TODO confirm
            against callers).
        target_shape: Desired output shape. Only entries 1 and 2 (output
            rows and columns) are used.

    Returns:
        A new ``ep.JAXTensor`` whose row and column axes have the sizes
        ``target_shape[1]`` and ``target_shape[2]``; all other axes keep
        the sizes of ``x``.
    """
    # img must be in channel_last format

    # modified according to https://github.com/google/jax/issues/862
    import jax.numpy as np

    img = x.raw

    resize_rates = (target_shape[1] / x.shape[1], target_shape[2] / x.shape[2])

    def interpolate_bilinear(
        im: np.ndarray, rows: np.ndarray, cols: np.ndarray
    ) -> np.ndarray:
        # based on http://stackoverflow.com/a/12729229
        # Integer neighbours surrounding every fractional sample position.
        col_lo = np.floor(cols).astype(int)
        col_hi = col_lo + 1
        row_lo = np.floor(rows).astype(int)
        row_hi = row_lo + 1

        # Neighbours that fall outside the image are clamped to the border.
        def cclip(cols: np.ndarray) -> np.ndarray:
            return np.clip(cols, 0, ncols - 1)  # type: ignore

        def rclip(rows: np.ndarray) -> np.ndarray:
            return np.clip(rows, 0, nrows - 1)  # type: ignore

        nrows, ncols = im.shape[-3:-1]

        # Pixel values at the four neighbours of every sample position.
        Ia = im[..., rclip(row_lo), cclip(col_lo), :]
        Ib = im[..., rclip(row_hi), cclip(col_lo), :]
        Ic = im[..., rclip(row_lo), cclip(col_hi), :]
        Id = im[..., rclip(row_hi), cclip(col_hi), :]

        # Bilinear weights; the trailing axis broadcasts over channels.
        wa = np.expand_dims((col_hi - cols) * (row_hi - rows), -1)
        wb = np.expand_dims((col_hi - cols) * (rows - row_lo), -1)
        wc = np.expand_dims((cols - col_lo) * (row_hi - rows), -1)
        wd = np.expand_dims((cols - col_lo) * (rows - row_lo), -1)

        return wa * Ia + wb * Ib + wc * Ic + wd * Id

    nrows, ncols = img.shape[-3:-1]
    deltas = (0.5 / resize_rates[0], 0.5 / resize_rates[1])

    # Take the number of output samples straight from target_shape.
    # Recomputing it as int(rate * size) can truncate one short because of
    # float rounding, e.g. (1 / 49) * 49 == 0.9999999999999999 -> 0.
    num_rows = int(target_shape[1])
    num_cols = int(target_shape[2])

    rows = np.linspace(deltas[0], nrows - deltas[0], num_rows)
    cols = np.linspace(deltas[1], ncols - deltas[1], num_cols)
    rows_grid, cols_grid = np.meshgrid(rows - 0.5, cols - 0.5, indexing="ij")

    img_resize_vec = interpolate_bilinear(img, rows_grid.flatten(), cols_grid.flatten())
    img_resize = np.reshape(
        img_resize_vec, img.shape[:-3] + (len(rows), len(cols)) + img.shape[-1:]
    )

    return ep.JAXTensor(img_resize)
×
57

58

59
def rescale_numpy(x: ep.NumPyTensor, target_shape: List[int]) -> ep.NumPyTensor:
    """Bilinearly resize channel-last images stored in a NumPy tensor.

    Args:
        x: Images to resize. Rows and columns are read from ``x.shape[1]``
            and ``x.shape[2]`` and the last axis is treated as channels
            (assumes a 4-D batch/rows/cols/channels layout -- TODO confirm
            against callers).
        target_shape: Desired output shape. Only entries 1 and 2 (output
            rows and columns) are used.

    Returns:
        A new ``ep.NumPyTensor`` whose row and column axes have the sizes
        ``target_shape[1]`` and ``target_shape[2]``; all other axes keep
        the sizes of ``x``.
    """
    import numpy as np

    img = x.raw

    resize_rates = (target_shape[1] / x.shape[1], target_shape[2] / x.shape[2])

    def interpolate_bilinear(
        im: np.ndarray, rows: np.ndarray, cols: np.ndarray
    ) -> np.ndarray:
        # based on http://stackoverflow.com/a/12729229
        # Integer neighbours surrounding every fractional sample position.
        col_lo = np.floor(cols).astype(int)
        col_hi = col_lo + 1
        row_lo = np.floor(rows).astype(int)
        row_hi = row_lo + 1

        # Neighbours that fall outside the image are clamped to the border.
        def cclip(cols: np.ndarray) -> np.ndarray:
            return np.clip(cols, 0, ncols - 1)

        def rclip(rows: np.ndarray) -> np.ndarray:
            return np.clip(rows, 0, nrows - 1)

        nrows, ncols = im.shape[-3:-1]

        # Pixel values at the four neighbours of every sample position.
        Ia = im[..., rclip(row_lo), cclip(col_lo), :]
        Ib = im[..., rclip(row_hi), cclip(col_lo), :]
        Ic = im[..., rclip(row_lo), cclip(col_hi), :]
        Id = im[..., rclip(row_hi), cclip(col_hi), :]

        # Bilinear weights; the trailing axis broadcasts over channels.
        wa = np.expand_dims((col_hi - cols) * (row_hi - rows), -1)
        wb = np.expand_dims((col_hi - cols) * (rows - row_lo), -1)
        wc = np.expand_dims((cols - col_lo) * (row_hi - rows), -1)
        wd = np.expand_dims((cols - col_lo) * (rows - row_lo), -1)

        return wa * Ia + wb * Ib + wc * Ic + wd * Id  # type: ignore

    nrows, ncols = img.shape[-3:-1]
    deltas = (0.5 / resize_rates[0], 0.5 / resize_rates[1])

    # Take the number of output samples straight from target_shape.
    # Recomputing it as int32(rate * size) can truncate one short because of
    # float rounding, e.g. (1 / 49) * 49 == 0.9999999999999999 -> 0.
    num_rows = int(target_shape[1])
    num_cols = int(target_shape[2])

    rows = np.linspace(deltas[0], nrows - deltas[0], num_rows)
    cols = np.linspace(deltas[1], ncols - deltas[1], num_cols)
    rows_grid, cols_grid = np.meshgrid(rows - 0.5, cols - 0.5, indexing="ij")

    img_resize_vec = interpolate_bilinear(img, rows_grid.flatten(), cols_grid.flatten())
    img_resize = img_resize_vec.reshape(
        img.shape[:-3] + (len(rows), len(cols)) + img.shape[-1:]
    )

    return ep.NumPyTensor(img_resize)
×
108

109

110
def rescale_tensorflow(
    x: ep.TensorFlowTensor, target_shape: List[int]
) -> ep.TensorFlowTensor:
    """Resize images in a TensorFlow tensor via ``tensorflow.image.resize``.

    Args:
        x: Tensor whose raw value is handed to ``tensorflow.image.resize``.
        target_shape: Desired output shape. Its first and last entries are
            dropped and the remaining entries are passed as the new size.

    Returns:
        The resized images wrapped in a new ``ep.TensorFlowTensor``.
    """
    import tensorflow

    # Everything between the first and last entry is the size to resize to.
    new_size = target_shape[1:-1]
    resized = tensorflow.image.resize(x.raw, size=new_size)

    return ep.TensorFlowTensor(resized)
×
120

121

122
def rescale_pytorch(x: ep.PyTorchTensor, target_shape: List[int]) -> ep.PyTorchTensor:
    """Bilinearly resize images in a PyTorch tensor.

    Args:
        x: Tensor whose raw value is handed to
            ``torch.nn.functional.interpolate``.
        target_shape: Desired output shape. Entries from index 2 onwards are
            passed as the new size.

    Returns:
        The resized images wrapped in a new ``ep.PyTorchTensor``.
    """
    import torch

    # The first two entries of target_shape are skipped; the rest is the size.
    new_size = target_shape[2:]
    resized = torch.nn.functional.interpolate(
        x.raw, size=new_size, mode="bilinear", align_corners=False
    )

    return ep.PyTorchTensor(resized)
×
132

133

134
def swap_axes(x: ep.TensorType, dim0: int, dim1: int) -> ep.TensorType:
    """Return ``x`` with the axes ``dim0`` and ``dim1`` exchanged.

    Args:
        x: Tensor to transpose.
        dim0: First axis of the pair; must be smaller than ``x.ndim``.
        dim1: Second axis of the pair; must be smaller than ``x.ndim``.

    Returns:
        The result of ``ep.transpose`` with the two axes swapped and every
        other axis left in place.
    """
    ndim = x.ndim
    assert dim0 < ndim
    assert dim1 < ndim

    # Start from the identity permutation and exchange the two entries.
    permutation = list(range(ndim))
    permutation[dim0], permutation[dim1] = dim1, dim0

    return ep.transpose(x, tuple(permutation))
×
143

144

145
def rescale_images(
    x: ep.TensorType, target_shape: Union[Tuple[int, ...], List[int]], channel_axis: int
) -> ep.TensorType:
    """Resize images to ``target_shape`` with the backend matching ``x``.

    The channel axis is temporarily moved to the position the backend
    helper expects (axis 1 for PyTorch, the last axis for TensorFlow, NumPy
    and JAX), the images are resized, and the axis is moved back.

    Args:
        x: Images to resize.
        target_shape: Desired shape of the result, given in the same axis
            order as ``x``.
        channel_axis: Index of the channel axis of ``x``. Negative values
            count from the end, so ``-1`` is the last axis.

    Returns:
        The resized images in the original axis order. A tensor that is
        none of the four handled eagerpy types is returned unchanged.
    """
    # Work on a copy so the caller's sequence is never mutated.
    target_shape = list(target_shape)

    if channel_axis < 0:
        # Standard negative-index normalisation: -1 -> ndim - 1.
        channel_axis = x.ndim + channel_axis

    if isinstance(x, ep.PyTorchTensor):
        if channel_axis != 1:
            x = swap_axes(x, channel_axis, 1)  # type: ignore

            # Keep target_shape in step with the swapped axis order.
            target_shape[channel_axis], target_shape[1] = (
                target_shape[1],
                target_shape[channel_axis],
            )

        x = rescale_pytorch(x, target_shape)  # type: ignore

        if channel_axis != 1:
            # Restore the caller's axis order.
            x = swap_axes(x, channel_axis, 1)  # type: ignore

    elif isinstance(x, ep.TensorFlowTensor):
        if channel_axis != x.ndim - 1:
            x = swap_axes(x, channel_axis, x.ndim - 1)  # type: ignore

            # Keep target_shape in step with the swapped axis order.
            target_shape[channel_axis], target_shape[x.ndim - 1] = (
                target_shape[x.ndim - 1],
                target_shape[channel_axis],
            )

        x = rescale_tensorflow(x, target_shape)  # type: ignore

        if channel_axis != x.ndim - 1:
            # Restore the caller's axis order.
            x = swap_axes(x, channel_axis, x.ndim - 1)  # type: ignore

    elif isinstance(x, ep.NumPyTensor):
        if channel_axis != x.ndim - 1:
            x = swap_axes(x, channel_axis, x.ndim - 1)  # type: ignore

            # Keep target_shape in step with the swapped axis order.
            target_shape[channel_axis], target_shape[x.ndim - 1] = (
                target_shape[x.ndim - 1],
                target_shape[channel_axis],
            )

        x = rescale_numpy(x, target_shape)  # type: ignore

        if channel_axis != x.ndim - 1:
            # Restore the caller's axis order.
            x = swap_axes(x, channel_axis, x.ndim - 1)  # type: ignore

    elif isinstance(x, ep.JAXTensor):
        if channel_axis != x.ndim - 1:
            x = swap_axes(x, channel_axis, x.ndim - 1)  # type: ignore

            # Keep target_shape in step with the swapped axis order.
            target_shape[channel_axis], target_shape[x.ndim - 1] = (
                target_shape[x.ndim - 1],
                target_shape[channel_axis],
            )

        x = rescale_jax(x, target_shape)  # type: ignore

        if channel_axis != x.ndim - 1:
            # Restore the caller's axis order.
            x = swap_axes(x, channel_axis, x.ndim - 1)  # type: ignore

    return x
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc