google / scaaml / 18830501663

24 Oct 2025 10:44AM UTC coverage: 86.433% (-0.3%) from 86.689%

Build 18830501663 · push · github · web-flow

Cleanup dependencies (#428)

Another attempt

---------

Co-authored-by: Karel Král <karelkral@google.com>

55 of 67 new or added lines in 6 files covered. (82.09%)

4 existing lines in 1 file now uncovered.

3077 of 3560 relevant lines covered (86.43%)

0.86 hits per line

Source File: /scaaml/models/gpam.py (83.7% of lines covered)

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is the GPAM model version which can be imported. For the archived
version see /papers/2024/GPAM/gpam_ecc_cm1.py.

GPAM model, see https://github.com/google/scaaml/tree/main/papers/2024/GPAM

@article{bursztein2023generic,
  title={Generalized Power Attacks against Crypto Hardware using Long-Range
  Deep Learning},
  author={Bursztein, Elie and Invernizzi, Luca and Kr{\'a}l, Karel and Moghimi,
  Daniel and Picod, Jean-Michel and Zhang, Marina},
  journal={arXiv preprint arXiv:2306.07249},
  year={2023}
}
"""

from collections import defaultdict
from typing import Any, Union

# NetworkX is an optional dependency.
try:
    import networkx as nx
except ImportError:
    nx = None  # type: ignore[assignment]
import tensorflow as tf
import keras
from tensorflow.keras import layers
from tensorflow import Tensor


@keras.saving.register_keras_serializable()
class Rescale(layers.Layer):  # type: ignore[type-arg]
    """Rescale input to the interval [-1, 1].
    """

    def __init__(self, trace_min: float, trace_delta: float,
                 **kwargs: Any) -> None:
        """Information for trace rescaling.

        Args:

          trace_min (float): Minimum over all traces.

          trace_delta (float): Maximum over all traces minus `trace_min`.
        """
        super().__init__(**kwargs)
        self.trace_min: float = trace_min
        self.trace_delta: float = trace_delta

    def call(self, inputs: Tensor, **kwargs: Any) -> Tensor:
        """Rescale to the interval [-1, 1]."""
        del kwargs  # unused
        x = inputs
        x = 2 * ((x - self.trace_min) / self.trace_delta) - 1
        return x

    def get_config(self) -> dict[str, Any]:
        """Return the config to allow saving and loading of the model.
        """
        config = super().get_config()
        config.update({
            "trace_min": self.trace_min,
            "trace_delta": self.trace_delta,
        })
        return config

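# --- Illustrative sketch (not from the scaaml repository): how the Rescale
# --- layer maps raw samples into [-1, 1], assuming `trace_min` and
# --- `trace_delta` were computed over the whole training set.
#
#     rescale = Rescale(trace_min=-2.0, trace_delta=4.0)
#     rescale(tf.constant([-2.0, 0.0, 2.0]))  # -> [-1.0, 0.0, 1.0]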


@keras.saving.register_keras_serializable()
class ScaledNorm(layers.Layer):  # type: ignore[type-arg]
    """ScaledNorm layer.

    Transformers without Tears: Improving the Normalization of Self-Attention
    Toan Q. Nguyen, Julian Salazar
    https://arxiv.org/abs/1910.05895
    """

    def __init__(self,
                 begin_axis: int = -1,
                 epsilon: float = 1e-5,
                 **kwargs: Any) -> None:
        """Initialize a ScaledNorm Layer.

        Args:

            begin_axis (int): Axis along which to apply norm. Defaults to -1.

            epsilon (float): Norm epsilon value. Defaults to 1e-5.
        """
        super().__init__(**kwargs)
        self._begin_axis = begin_axis
        self._epsilon = epsilon
        self._scale = self.add_weight(
            name="norm_scale",
            shape=(),
            initializer=tf.constant_initializer(value=1.0),
            trainable=True,
        )

    def call(self, inputs: Tensor) -> Tensor:
        """Return the output of this layer.
        """
        x = inputs
        axes = list(range(len(x.shape)))[self._begin_axis:]
        mean_square = tf.reduce_mean(tf.math.square(x), axes, keepdims=True)
        x = x * tf.math.rsqrt(mean_square + self._epsilon)
        return x * self._scale

    def get_config(self) -> dict[str, Any]:
        """Return the config to allow saving and loading of the model.
        """
        config = super().get_config()
        config.update({
            "begin_axis": self._begin_axis,
            "epsilon": self._epsilon
        })
        return config


def clone_initializer(initializer: tf.keras.initializers.Initializer) -> Any:
    """Clone an initializer (if an initializer is reused the generated
    weights are the same).
    """
    if isinstance(initializer, tf.keras.initializers.Initializer):
        return initializer.__class__.from_config(initializer.get_config())
    return initializer  # type: ignore[unreachable]


def rope(
    x: Tensor,
    axis: Union[list[int], int],
) -> Tensor:
    """RoPE positional encoding.

      Implementation of the Rotary Position Embedding proposed in
      https://arxiv.org/abs/2104.09864.

      Args:
          x: input tensor.
          axis: axis to add the positional encodings.

      Returns:
          The input tensor with RoPE encodings.
    """
    shape = x.shape.as_list()

    if isinstance(axis, int):
        axis = [axis]

    if isinstance(shape, (list, tuple)):
        spatial_shape = [shape[i] for i in axis]
        total_len = 1
        for i in spatial_shape:
            total_len *= i  # type: ignore[operator]
        position = tf.reshape(
            tf.cast(tf.range(total_len, delta=1.0), tf.float32), spatial_shape)
    else:
        raise ValueError(f"Unsupported shape: {shape}")

    # We assume that the axis cannot be negative (e.g., -1).
    if any(dim < 0 for dim in axis):
        raise ValueError(f"Unsupported axis: {axis}")
    for i in range(axis[-1] + 1, len(shape) - 1, 1):
        position = tf.expand_dims(position, axis=-1)

    half_size = shape[-1] // 2  # type: ignore[operator]
    freq_seq = tf.cast(tf.range(half_size), tf.float32) / float(half_size)
    inv_freq = 10000**-freq_seq
    sinusoid = tf.einsum("...,d->...d", position, inv_freq)
    sin = tf.cast(tf.sin(sinusoid), dtype=x.dtype)
    cos = tf.cast(tf.cos(sinusoid), dtype=x.dtype)
    x1, x2 = tf.split(x, 2, axis=-1)
    return tf.concat(  # type: ignore[no-any-return]
        [x1 * cos - x2 * sin, x2 * cos + x1 * sin],
        axis=-1,
    )


def toeplitz_matrix_rope(
    n: int,
    a: Tensor,
    b: Tensor,
) -> Tensor:
    """Obtain Toeplitz matrix using rope."""
    a = rope(tf.tile(a[None, :], [n, 1]), axis=[0])
    b = rope(tf.tile(b[None, :], [n, 1]), axis=[0])
    return tf.einsum("mk,nk->mn", a, b)  # type: ignore[no-any-return]
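# --- Illustrative sketch (not from the scaaml repository): the relative
# --- position bias built by `toeplitz_matrix_rope`. The random vectors stand
# --- in for the learned weights `a` and `b` created in `GAU.build`.
#
#     n = 8                                 # sequence length (GAU's max_len)
#     a = tf.random.normal([n])
#     b = tf.random.normal([n])
#     toeplitz_matrix_rope(n, a, b).shape   # TensorShape([8, 8])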


@keras.saving.register_keras_serializable()
class GAU(layers.Layer):  # type: ignore[type-arg]
    """Gated Attention Unit layer introduced in Transformer
    Quality in Linear Time.

    Paper reference: https://arxiv.org/abs/2202.10447
    """

    def __init__(
            self,
            *,  # key-word only arguments
            dim: int,
            max_len: int = 128,
            shared_dim: int = 128,
            expansion_factor: int = 2,
            activation: str = "swish",
            attention_activation: str = "sqrrelu",
            dropout_rate: float = 0.0,
            attention_dropout_rate: float = 0.0,
            spatial_dropout_rate: float = 0.0,
            **kwargs: Any) -> None:
        """
        Initialize a GAU layer.

        Args:
            dim: Dimension of GAU block.

            max_len: Maximum seq len of input.

            shared_dim: Size of shared dim. Defaults to 128.

            expansion_factor: Hidden dim expansion factor. Defaults to 2.

            activation: Activation to use in projection layers. Defaults
                to 'swish'.

            attention_activation: Activation to use on attention scores.
                Defaults to 'sqrrelu'.

            dropout_rate: Feature dropout rate. Defaults to 0.0.

            attention_dropout_rate: Feature dropout rate after attention.
                Defaults to 0.0

            spatial_dropout_rate: Spatial dropout rate. Defaults to 0.0.
        """
        super().__init__(**kwargs)

        self.dim = dim
        self.max_len = max_len
        self.shared_dim = shared_dim
        self.expansion_factor = expansion_factor
        self.activation = activation
        self.attention_activation: str = attention_activation
        self.dropout_rate = dropout_rate
        self.spatial_dropout_rate = spatial_dropout_rate
        self.attention_dropout_rate = attention_dropout_rate

        # compute projection dimension
        self.expand_dim = self.dim * self.expansion_factor
        self.proj_dim = 2 * self.expand_dim + self.shared_dim

        # define layers
        self.norm = layers.LayerNormalization()
        self.proj1 = layers.Dense(
            self.proj_dim,
            use_bias=True,
            activation=self.activation,
        )
        self.proj2 = layers.Dense(self.dim, use_bias=True)

        # dropout layers
        self.dropout1 = layers.Dropout(self.dropout_rate)
        self.dropout2 = layers.Dropout(self.dropout_rate)

        if self.attention_dropout_rate:
            self.attention_dropout = layers.Dropout(self.attention_dropout_rate)

        if self.spatial_dropout_rate:
            self.spatial_dropout = layers.SpatialDropout1D(
                self.spatial_dropout_rate)

        # attention activation function
        self.attention_activation_layer = tf.keras.layers.Activation(
            self.attention_activation)

    def build(self, input_shape: tuple[int, ...]) -> None:
        del input_shape  # unused

        # setting up position encoding
        self.a = self.add_weight(
            name="a",
            shape=(self.max_len,),
            initializer=lambda *args, **kwargs: self.weight_initializer(
                shape=[self.max_len]),
            trainable=True,
        )
        self.b = self.add_weight(
            name="b",
            shape=(self.max_len,),
            initializer=lambda *args, **kwargs: self.weight_initializer(
                shape=[self.max_len]),
            trainable=True,
        )

        # offset scaling values
        self.gamma = self.add_weight(
            name="gamma",
            shape=(2, self.shared_dim),
            initializer=lambda *args, **kwargs: self.weight_initializer(
                shape=[2, self.shared_dim]),
            trainable=True,
        )
        self.beta = self.add_weight(
            name="beta",
            shape=(2, self.shared_dim),
            initializer=lambda *args, **kwargs: self.zeros_initializer(
                shape=[2, self.shared_dim]),
            trainable=True,
        )

    def call(self, x: Any, training: bool = False) -> Any:

        shortcut = x
        x = self.norm(x)

        # input dropout
        if self.spatial_dropout_rate:
            x = self.spatial_dropout(x, training=training)

        x = self.dropout1(x, training=training)

        # initial projection to generate uv
        uv = self.proj1(x)
        uv = self.dropout2(uv, training=training)

        u, v, base = tf.split(
            uv, [self.expand_dim, self.expand_dim, self.shared_dim], axis=-1)

        # generate q, k by scaled offset
        base = tf.einsum("bnr,hr->bnhr", base, self.gamma) + self.beta
        q, k = tf.unstack(base, axis=-2)

        # compute key-query scores
        qk = tf.einsum("bnd,bmd->bnm", q, k)
        qk = qk / self.max_len

        # add relative position bias for attention
        qk += toeplitz_matrix_rope(self.max_len, self.a, self.b)

        # apply attention activation
        kernel = self.attention_activation_layer(qk)

        if self.attention_dropout_rate:
            kernel = self.attention_dropout(kernel)

        # apply values and project
        x = u * tf.einsum("bnm,bme->bne", kernel, v)

        x = self.proj2(x)
        return x + shortcut

    def get_config(self) -> dict[str, Any]:
        config = super().get_config()
        config.update({
            "dim": self.dim,
            "max_len": self.max_len,
            "shared_dim": self.shared_dim,
            "expansion_factor": self.expansion_factor,
            "activation": self.activation,
            "attention_activation": self.attention_activation,
            "dropout_rate": self.dropout_rate,
            "spatial_dropout_rate": self.spatial_dropout_rate,
            "attention_dropout_rate": self.attention_dropout_rate,
        })
        return config

    @property
    def weight_initializer(self) -> Any:
        return clone_initializer(tf.random_normal_initializer(stddev=0.02))

    @property
    def zeros_initializer(self) -> Any:
        return clone_initializer(tf.initializers.zeros())
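# --- Illustrative sketch (not from the scaaml repository): applying a GAU
# --- block to a dummy batch. Input and output shapes are both
# --- (batch, max_len, dim); max_len must match the actual sequence length.
#
#     gau = GAU(dim=192, max_len=128, attention_activation="softsign")
#     gau(tf.random.normal([4, 128, 192])).shape   # TensorShape([4, 128, 192])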


@keras.saving.register_keras_serializable()
class StopGradient(
        keras.layers.Layer,  # type: ignore[misc,no-any-unimported]
):
    """Stop gradient as a Keras layer.
    """

    def __init__(
        self,
        stop_gradient: bool = False,
        **kwargs: Any,
    ) -> None:
        """Stop gradient, or not, depending on the configuration.

        Args:

            stop_gradient (bool): If `True` then this layer stops gradient,
            otherwise it is a no-op. Defaults to `False`.

            **kwargs: Additional arguments for keras.layers.Layer.__init__.
        """
        super().__init__(**kwargs)
        self._stop_gradient = stop_gradient

    def call(self, inputs):  # type: ignore[no-untyped-def]
        if self._stop_gradient:
            # Stopping gradient.
            return keras.ops.stop_gradient(inputs)

        return inputs

    def get_config(self) -> dict[str, Any]:
        config = super().get_config()
        config.update({
            "stop_gradient": self._stop_gradient,
        })
        return config  # type: ignore[no-any-return]


def _make_head(  # type: ignore[no-any-unimported]
    x: keras.layers.Layer,
    heads: dict[str, keras.layers.Layer],
    name: str,
    relations: list[str],
    dim: int,
) -> keras.layers.Layer:
    """Make a single head.

    Args:

      x (keras.layers.Layer): Stem of the neural network.

      heads (dict[str, keras.layers.Layer]): A dictionary of previous heads
      (those that are sooner in the topologically sorted outputs).

      name (str): Name of this output.

      relations (list[str]): Which outputs should be routed to this one. All of
      these must be already constructed and present in `heads`.

      dim (int): Number of classes of this output.
    """
    activation: str = "swish"
    dense_dropout: float = 0.05

    head = x

    # Construct relation layers if needed.
    if relations:
        related_outputs = []
        for rname in relations:
            related_outputs.append(
                StopGradient(stop_gradient=True)(heads[rname]))
        related_outputs.append(x)
        head = layers.Concatenate(name=f"{name}_relations")(related_outputs)
        for _ in range(3):
            head = layers.Dense(256)(head)
            head = layers.Activation(activation)(head)

    head = layers.Dropout(dense_dropout, name=f"{name}_dropout")(head)

    # Dense block
    head = layers.Dense(dim, activation=activation)(head)
    head = layers.Dense(dim, activation=activation)(head)
    head = layers.Dense(dim, activation=activation)(head)
    head = layers.Dropout(dense_dropout)(head)
    head = layers.Dense(dim, activation=activation)(head)

    # Prediction
    return layers.Dense(dim, activation="softmax", name=name)(head)
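# --- Illustrative sketch (not from the scaaml repository): a head with no
# --- relations is a dense block on top of the trunk ending in a softmax over
# --- `dim` classes (the attack point name "ap_a_0" is made up).
#
#     trunk = layers.Input(shape=(256,))
#     head = _make_head(trunk, heads={}, name="ap_a_0", relations=[], dim=256)
#     head.shape                            # (None, 256)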


def get_dag(
    outputs: dict[str, dict[str, int]],
    output_relations: list[tuple[str, str]],
) -> Any:
    """Return graph of output relation dependencies.

    Both `outputs` and `output_relations` are needed so that even the outputs
    which are not a part of any relation are included.

    Args:
      outputs (dict[str, dict]): Description of outputs as returned by
        scaaml.io.Dataset.as_tfdataset.
      output_relations (list[tuple[str, str]]): List of arcs (oriented edges)
        between attack point names (full -- with the index), the first being
        required for the second one. When `(ap_1, ap_2)` is present the
        interpretation is that `ap_2` depends on the value of `ap_1`.

    Returns: A networkx.DiGraph representation of relations.
    """
    # Create graph of relations that will be topologically sorted and contains
    # all head names.
    relation_graph: nx.DiGraph[str]  # pylint: disable=unsubscriptable-object
    relation_graph = nx.DiGraph()
    # Add all output names into the relation_graph (even if they appear in no
    # relations).
    for name in outputs:
        relation_graph.add_node(name)
    # Add all relation edges.
    for ap_1, ap_2 in output_relations:
        # When ap_2 depends on ap_1 then ap_1 must be created before ap_2.
        relation_graph.add_edge(ap_1, ap_2)

    return relation_graph


def get_topological_order(
    outputs: dict[str, dict[str, int]],
    output_relations: list[tuple[str, str]],
) -> list[str]:
    """Return iterator of vertices in topological order (if attack point ap_2
    depends on ap_1 then ap_1 appears before ap_2).

    Both outputs and output_relations are needed so that even the outputs
    which are not a part of any relation are included.

    Args:

      outputs (dict[str, dict[str, int]]): Description of outputs as returned
      by scaaml.io.Dataset.as_tfdataset.

      output_relations (list[tuple[str, str]]): List of arcs (oriented edges)
      between attack point names (full -- with the index), the first being
      required for the second one. When (ap_1, ap_2) is present the
      interpretation is that ap_2 depends on the value of ap_1.

    """
    if output_relations:
        if nx is None:
            raise ImportError("To use the relational heads please install "
                              "networkx[default]")

        # We need to create the heads in a topological order.
        return nx.topological_sort(  # type: ignore[return-value]
            get_dag(outputs=outputs, output_relations=output_relations))
    else:
        return list(outputs)
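# --- Illustrative sketch (not from the scaaml repository): ordering two made
# --- up attack points where "ap_b_0" depends on "ap_a_0" (this branch needs
# --- networkx installed).
#
#     outs = {"ap_a_0": {"max_val": 256}, "ap_b_0": {"max_val": 256}}
#     rels = [("ap_a_0", "ap_b_0")]
#     list(get_topological_order(outputs=outs, output_relations=rels))
#     # -> ["ap_a_0", "ap_b_0"]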


def create_heads_outputs(  # type: ignore[no-any-unimported]
    x: Tensor,
    outputs: dict[str, dict[str, int]],
    output_relations: list[tuple[str, str]],
) -> dict[str, keras.layers.Layer]:
    """Make a mapping of all heads (name to Layer).

    Args:

      x (FloatTensor): The trunk.

      outputs (dict[str, dict[str, int]]): Description of outputs as returned
      by scaaml.io.Dataset.as_tfdataset.

      output_relations (list[tuple[str, str]]): List of arcs (oriented edges)
      between attack point names (full -- with the index), the first being
      required for the second one. When (ap_1, ap_2) is present the
      interpretation is that ap_2 depends on the value of ap_1.

    Returns: A mapping of all head outputs (name to Layer).
    """
    # Create relations represented by lists of ingoing edges (attack points:
    # list of all attack points it depends on).
    ingoing_relations: dict[str, list[str]] = defaultdict(list)
    for ap_1, ap_2 in output_relations:
        ingoing_relations[ap_2].append(ap_1)
    # Freeze the dict
    ingoing_relations = dict(ingoing_relations)

    # Dictionary containing the actual network heads
    heads: dict[str, keras.layers.Layer] = {}  # type: ignore[no-any-unimported]

    # Get iterator of outputs that are in topological order (if ap_2 depends on
    # ap_1 then ap_1 appears before ap_2).
    topological_order = get_topological_order(
        outputs=outputs,
        output_relations=output_relations,
    )

    # Create heads.
    for name in topological_order:
        # Get relations (possibly an empty list).
        relations = ingoing_relations.get(name, [])

        # Get parameters for head creation.
        dim = outputs[name]["max_val"] if outputs[name]["max_val"] > 2 else 1
        head = _make_head(x, heads, name, relations, dim)
        heads[name] = head

    # Return all head outputs in a dict.
    heads_outputs = {name: heads[name] for name in outputs.keys()}
    return heads_outputs


def get_gpam_model(  # type: ignore[no-any-unimported]
    *,  # key-word only arguments
    inputs: dict[str, dict[str, float]],
    outputs: dict[str, dict[str, int]],
    output_relations: list[tuple[str, str]],
    trace_len: int,
    merge_filter_1: int,
    merge_filter_2: int,
    patch_size: int,
) -> keras.models.Model:
    """Get a GPAM model instance.

    Args:

      inputs (dict[str, dict[str, float]]): The following dictionary:
      {"trace1": {"min": MIN, "delta": DELTA}} where `MIN` is the minimum value
      across all traces and time and `DELTA` is the maximum value minus `MIN`.

      outputs (dict[str, dict[str, int]]): A dictionary with output name and
      "max_val" being the number of possible classes. Example:
      `outputs={"sub_bytes_in_0": {"max_val": 256}}`.

      output_relations (list[tuple[str, str]]): A list of related outputs. Each
      relation is a pair where the output of the first is fed to the second.
      Must form a directed acyclic graph.

      trace_len (int): The trace is assumed to be one-dimensional of length
      `trace_len`. Must be divisible by `patch_size`.

      merge_filter_1 (int): The number of filters in the first layer of
      convolutions.

      merge_filter_2 (int): The number of filters in the second layer of
      convolutions.

      patch_size (int): Cut the trace into patches of this length. Must divide
      `trace_len`.

    ```
    @article{bursztein2023generic,
      title={Generalized Power Attacks against Crypto Hardware using Long-Range
      Deep Learning},
      author={Bursztein, Elie and Invernizzi, Luca and Kr{\'a}l, Karel and
      Moghimi, Daniel and Picod, Jean-Michel and Zhang, Marina},
      journal={arXiv preprint arXiv:2306.07249},
      year={2023}
    }
    ```
    """
    # Constants:
    if trace_len % patch_size:
        raise ValueError(f"{trace_len = } is not divisible by {patch_size = }")
    steps: int = trace_len // patch_size
    combine_kernel_size: int = 3
    activation: str = "swish"
    combine_strides: int = 1
    filters: int = 192

    # Input
    model_input = layers.Input(shape=(trace_len,), name="trace1")
    x = model_input

    # Reshape the trace.
    x = layers.Reshape((steps, patch_size))(x)
    x = Rescale(  # to the interval [-1, 1].
        trace_min=inputs["trace1"]["min"],
        trace_delta=inputs["trace1"]["delta"],
    )(x)

    # Single dense after preprocess.
    x = layers.Dense(filters)(x)

    # Dropout
    x = layers.SpatialDropout1D(0.1)(x)

    # Transformer layers (with intermediate results).
    s = x
    gau_results = []  # Intermediate results (after 1st, 2nd...).
    for _ in range(3):
        s = GAU(
            dim=filters,
            max_len=steps,
            expansion_factor=2,
            attention_activation="softsign",
        )(s)
        gau_results.append(s)
    x = layers.Concatenate()(gau_results)

    # Norm after concatenate
    x = layers.BatchNormalization()(x)

    # merge blocks
    if merge_filter_1:
        x = layers.Conv1D(merge_filter_1,
                          combine_kernel_size,
                          activation=activation,
                          strides=combine_strides)(x)
        # MaxPool1D if applicable
        x = layers.MaxPool1D(pool_size=2)(x)
        # Second merge block
        if merge_filter_2:
            x = ScaledNorm()(x)
            x = layers.Conv1D(merge_filter_2,
                              combine_kernel_size,
                              activation=activation,
                              strides=combine_strides)(x)

    # post merge dropouts
    x = layers.Dropout(0.1)(x)

    # flattening
    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)

    # Normalizing
    x = layers.BatchNormalization()(x)

    # Make head outputs
    heads_outputs = create_heads_outputs(
        x=x,
        outputs=outputs,
        output_relations=output_relations,
    )

    model = keras.models.Model(model_input, heads_outputs)
    return model
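Example usage (a minimal sketch, not from the repository): building a GPAM model for a single 16,384-sample trace and one 256-class attack point. The hyper-parameter values are placeholders chosen only to satisfy the documented constraints (`trace_len` divisible by `patch_size`); see the paper and the scaaml repository for the values used in practice.

    model = get_gpam_model(
        inputs={"trace1": {"min": -1.0, "delta": 2.0}},
        outputs={"sub_bytes_in_0": {"max_val": 256}},
        output_relations=[],
        trace_len=16_384,
        merge_filter_1=0,
        merge_filter_2=0,
        patch_size=128,
    )
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    model.summary()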