bb515 / diffusionjax / build 12854864454

26 Aug 2024 11:58AM UTC. Coverage: 12.865% (+0.03%) from 12.832%.
Push build from GitHub (committer: web-flow): Merge pull request #34 from bb515/develop ("Develop").

3 of 23 new or added lines in 5 files covered (13.04%).
1 existing line in 1 file is now uncovered.
198 of 1539 relevant lines covered (12.87%).
0.51 hits per line.
Source File: /diffusionjax/models/networks_edm2.py (0.0% of relevant lines covered)

"""JAX port of the improved diffusion model architecture proposed in the paper
"Analyzing and Improving the Training Dynamics of Diffusion Models".
Ported from the code https://github.com/NVlabs/edm2/blob/main/training/networks_edm2.py
"""
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Any


def jax_unstack(x, axis=0):
  """https://github.com/google/jax/discussions/11028"""
  return [
    jax.lax.index_in_dim(x, i, axis, keepdims=False) for i in range(x.shape[axis])
  ]


def pixel_normalize(x, channel_axis, eps=1e-4):
  """
  Normalize the given tensor to unit magnitude with respect to the given
  channel axis.
  Args:
    x: Assume (N, C, H, W)
  """
  norm = jnp.float32(jnp.linalg.vector_norm(x, axis=channel_axis, keepdims=True))
  norm = eps + jnp.sqrt(norm.size / x.size) * norm
  return x / jnp.array(norm, dtype=x.dtype)


def weight_normalize(x, eps=1e-4):
  """
  Normalize the given tensor to unit magnitude with respect to all the
  dimensions except the first.
  Args:
    x: Assume (N, C, H, W)
  """
  norm = jnp.float32(jax.vmap(lambda x: jnp.linalg.vector_norm(x, keepdims=True))(x))
  norm = eps + jnp.sqrt(norm.size / x.size) * norm
  return x / jnp.array(norm, dtype=x.dtype)


def forced_weight_normalize(x, eps=1e-4):
  """
  Normalize the given tensor to unit magnitude with respect to all the
  dimensions except the first. Don't take gradients through the computation.
  Args:
    x: Assume (N, C, H, W)
  """
  norm = jax.lax.stop_gradient(
    jnp.float32(jax.vmap(lambda x: jnp.linalg.vector_norm(x, keepdims=True))(x))
  )
  norm = eps + jnp.sqrt(norm.size / x.size) * norm
  return x / jnp.array(norm, dtype=x.dtype)
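

# Illustrative sketch (not part of the NVlabs port): the helpers above normalize
# to unit root-mean-square magnitude, so the L2 norm of each normalized channel
# vector is approximately sqrt(C). The hypothetical helper below only
# demonstrates that property.
def _demo_pixel_normalize():
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (2, 8, 4, 4))  # (N, C, H, W)
  y = pixel_normalize(x, channel_axis=1)
  # Each per-pixel channel vector has norm close to sqrt(8).
  return jnp.linalg.vector_norm(y, axis=1)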


def resample(x, f=[1, 1], mode="keep"):
  """
  Upsample or downsample the given tensor with the given filter,
  or keep it as is.

  Args:
    x: Assume (N, C, H, W)
  """
  if mode == "keep":
    return x
  f = jnp.array(f, dtype=x.dtype)
  assert f.ndim == 1 and len(f) % 2 == 0
  f = f / f.sum()
  f = jnp.outer(f, f)[jnp.newaxis, jnp.newaxis, :, :]
  c = x.shape[1]

  if mode == "down":
    return jax.lax.conv_general_dilated(
      x,
      jnp.tile(f, (c, 1, 1, 1)),
      window_strides=(2, 2),
      feature_group_count=c,
      padding="SAME",
    )
  assert mode == "up"

  pad = (len(f) - 1) // 2 + 1
  return jax.lax.conv_general_dilated(
    x,
    jnp.tile(f * 4, (c, 1, 1, 1)),
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
    window_strides=(1, 1),
    lhs_dilation=(2, 2),
    feature_group_count=c,
    padding=((pad, pad), (pad, pad)),
  )
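

# Illustrative sketch (not part of the NVlabs port): "down" halves the spatial
# resolution with a stride-2 depthwise filter, "up" doubles it via a dilated
# convolution, and "keep" returns the input unchanged. The hypothetical helper
# below only checks the resulting shapes.
def _demo_resample():
  x = jnp.ones((1, 3, 8, 8))  # (N, C, H, W)
  down = resample(x, f=[1, 1], mode="down")  # (1, 3, 4, 4)
  up = resample(x, f=[1, 1], mode="up")  # (1, 3, 16, 16)
  return down.shape, up.shape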


def mp_silu(x):
  """Magnitude-preserving SiLU (Equation 81)."""
  return nn.activation.silu(x) / 0.596


def mp_sum(a, b, t=0.5):
  """Magnitude-preserving sum (Equation 88)."""
  return (a + t * (b - a)) / jnp.sqrt((1 - t) ** 2 + t**2)


def mp_cat(a, b, dim=1, t=0.5):
  """Magnitude-preserving concatenation (Equation 103)."""
  Na = a.shape[dim]
  Nb = b.shape[dim]
  C = jnp.sqrt((Na + Nb) / ((1 - t) ** 2 + t**2))
  wa = C / jnp.sqrt(Na) * (1 - t)
  wb = C / jnp.sqrt(Nb) * t
  return jax.lax.concatenate([wa * a, wb * b], dimension=dim)
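

# Illustrative sketch (not part of the NVlabs port): mp_sum blends two signals
# with weight t and rescales so that independent unit-variance inputs give a
# roughly unit-variance output; the hypothetical check below should return a
# value close to 1.0.
def _demo_mp_sum():
  key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
  a = jax.random.normal(key_a, (10000,))
  b = jax.random.normal(key_b, (10000,))
  return jnp.std(mp_sum(a, b, t=0.3))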


class MPFourier(nn.Module):
  """Magnitude-preserving Fourier features (Equation 75)."""

  num_channels: int
  bandwidth: float = 1.0

  @nn.compact
  def __call__(self, x):
    freqs = self.param(
      "freqs",
      jax.nn.initializers.normal(stddev=2 * jnp.pi * self.bandwidth),
      (self.num_channels,),
    )
    freqs = jax.lax.stop_gradient(freqs)
    phases = self.param(
      "phases", jax.nn.initializers.normal(stddev=2 * jnp.pi), (self.num_channels,)
    )
    phases = jax.lax.stop_gradient(phases)
    y = jnp.float32(x)
    y = jnp.float32(jnp.outer(x, freqs))
    y = y + jnp.float32(phases)
    y = jnp.cos(y) * jnp.sqrt(2)
    return jnp.array(y, dtype=x.dtype)
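

# Illustrative sketch (not part of the NVlabs port): MPFourier maps a batch of
# scalar noise labels of shape (N,) to features of shape (N, num_channels).
# The hypothetical helper below shows the flax.linen init/apply pattern.
def _demo_mp_fourier():
  module = MPFourier(num_channels=4)
  c_noise = jnp.linspace(-1.0, 1.0, 3)  # (N,) = (3,)
  variables = module.init(jax.random.PRNGKey(0), c_noise)
  return module.apply(variables, c_noise)  # shape (3, 4)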


class MPConv(nn.Module):
  """Magnitude-preserving convolution or fully-connected layer (Equation 47)
  with forced weight normalization (Equation 66).
  """

  in_channels: int
  out_channels: int
  kernel_shape: tuple
  training: bool = True

  @nn.compact
  def __call__(self, x, gain=1.0):
    w = jnp.float32(
      self.param(
        "w",
        jax.nn.initializers.normal(stddev=1.0),
        (self.out_channels, self.in_channels, *self.kernel_shape),
      )
    )  # TODO: type promotion required in JAX?
    if self.training:
      w = forced_weight_normalize(w)  # forced weight normalization

    w = weight_normalize(w)  # traditional weight normalization
    w = w * (gain / jnp.sqrt(w[0].size))  # magnitude-preserving scaling
    w = jnp.array(w, dtype=x.dtype)
    if w.ndim == 2:
      return x @ w.T
    assert w.ndim == 4

    return jax.lax.conv(
      x,
      w,
      window_strides=(1, 1),
      padding="SAME",
    )
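

# Illustrative sketch (not part of the NVlabs port): with an empty kernel_shape
# MPConv acts as a magnitude-preserving fully-connected layer mapping
# (N, in_channels) -> (N, out_channels); with a (3, 3) kernel it is a
# same-padded NCHW convolution. The hypothetical helper below shows the
# fully-connected case.
def _demo_mp_conv():
  dense = MPConv(in_channels=8, out_channels=16, kernel_shape=())
  x = jnp.ones((2, 8))
  variables = dense.init(jax.random.PRNGKey(0), x)
  return dense.apply(variables, x)  # shape (2, 16)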


class Block(nn.Module):
  """
  U-Net encoder/decoder block with optional self-attention (Figure 21).
  """

  in_channels: int  # Number of input channels.
  out_channels: int  # Number of output channels.
  emb_channels: int  # Number of embedding channels.
  flavor: str = "enc"  # Flavor: 'enc' or 'dec'.
  resample_mode: str = "keep"  # Resampling: 'keep', 'up', or 'down'.
  resample_filter: tuple = (1, 1)  # Resampling filter.
  attention: bool = False  # Include self-attention?
  channels_per_head: int = 64  # Number of channels per attention head.
  dropout: float = 0.0  # Dropout probability.
  res_balance: float = 0.3  # Balance between main branch (0) and residual branch (1).
  attn_balance: float = 0.3  # Balance between main branch (0) and self-attention (1).
  clip_act: int = 256  # Clip output activations. None = do not clip.
  training: bool = True

  @nn.compact
  def __call__(self, x, emb):
    # Main branch
    x = resample(x, f=self.resample_filter, mode=self.resample_mode)
    if self.flavor == "enc":
      if self.in_channels != self.out_channels:
        x = MPConv(
          self.in_channels, self.out_channels, kernel_shape=(1, 1), name="conv_skip"
        )(x)
      x = pixel_normalize(x, channel_axis=1)  # pixel norm

    # Residual branch
    y = MPConv(
      self.out_channels if self.flavor == "enc" else self.in_channels,
      self.out_channels,
      kernel_shape=(3, 3),
      name="conv_res0",
    )(mp_silu(x))

    c = (
      MPConv(self.emb_channels, self.out_channels, kernel_shape=(), name="emb_linear")(
        emb, gain=self.param("emb_gain", jax.nn.initializers.zeros, (1,))
      )
      + 1
    )
    y = jnp.array(
      mp_silu(y * jnp.expand_dims(jnp.expand_dims(c, axis=2), axis=3)), dtype=y.dtype
    )
    if self.dropout:
      y = nn.Dropout(self.dropout)(y, deterministic=not self.training)
    y = MPConv(
      self.out_channels, self.out_channels, kernel_shape=(3, 3), name="conv_res1"
    )(y)

    # Connect the branches
    if self.flavor == "dec" and self.in_channels != self.out_channels:
      x = MPConv(
        self.in_channels, self.out_channels, kernel_shape=(1, 1), name="conv_skip"
      )(x)
    x = mp_sum(x, y, t=self.res_balance)

    # Self-attention
    # TODO: test if flax.linen.SelfAttention can be used instead here?
    num_heads = self.out_channels // self.channels_per_head if self.attention else 0
    if num_heads != 0:
      y = MPConv(
        self.out_channels, self.out_channels * 3, kernel_shape=(1, 1), name="attn_qkv"
      )(x)
      y = y.reshape(y.shape[0], num_heads, -1, 3, y.shape[2] * y.shape[3])
      q, k, v = jax_unstack(
        pixel_normalize(y, channel_axis=2), axis=3
      )  # pixel normalization and split
      # NOTE: quadratic cost in last dimension
      w = nn.softmax(jnp.einsum("nhcq,nhck->nhqk", q, k / jnp.sqrt(q.shape[2])), axis=3)
      y = jnp.einsum("nhqk,nhck->nhcq", w, v)
      y = MPConv(
        self.out_channels, self.out_channels, kernel_shape=(1, 1), name="attn_proj"
      )(y.reshape(*x.shape))
      x = mp_sum(x, y, t=self.attn_balance)

    # Clip activations
    if self.clip_act is not None:
      x = jnp.clip(x, -self.clip_act, self.clip_act)
    return x
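

# Illustrative sketch (not part of the NVlabs port): a Block maps an NCHW
# activation plus an embedding vector to an NCHW activation of out_channels.
# The hypothetical helper below shows the expected call signature and shapes
# with default settings (encoder flavor, no resampling, no attention).
def _demo_block():
  block = Block(in_channels=8, out_channels=8, emb_channels=16)
  x = jnp.ones((1, 8, 8, 8))  # (N, C, H, W)
  emb = jnp.ones((1, 16))  # (N, emb_channels)
  variables = block.init(jax.random.PRNGKey(0), x, emb)
  return block.apply(variables, x, emb)  # shape (1, 8, 8, 8)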


class UNet(nn.Module):
  """EDM2 U-Net model (Figure 21)."""

  img_resolution: int  # Image resolution.
  img_channels: int  # Image channels.
  label_dim: int  # Class label dimensionality. 0 = unconditional.
  model_channels: int = 192  # Base multiplier for the number of channels.
  channel_mult: tuple = (
    1,
    2,
    3,
    4,
  )  # Per-resolution multipliers for the number of channels.
  channel_mult_noise: Any = None  # Multiplier for noise embedding dimensionality. None = select based on channel_mult.
  channel_mult_emb: Any = None  # Multiplier for final embedding dimensionality. None = select based on channel_mult.
  num_blocks: int = 3  # Number of residual blocks per resolution.
  attn_resolutions: tuple = (16, 8)  # List of resolutions with self-attention.
  label_balance: float = (
    0.5  # Balance between noise embedding (0) and class embedding (1).
  )
  concat_balance: float = 0.5  # Balance between skip connections (0) and main path (1).

  # **block_kwargs - arguments for Block
  resample_filter: tuple = (1, 1)  # Resampling filter.
  channels_per_head: int = 64  # Number of channels per attention head.
  dropout: float = 0.0  # Dropout probability.
  res_balance: float = 0.3  # Balance between main branch (0) and residual branch (1).
  attn_balance: float = 0.3  # Balance between main branch (0) and self-attention (1).
  clip_act: int = 256  # Clip output activations. None = do not clip.
  out_gain: Any = None
  block_kwargs = {
    "resample_filter": resample_filter,
    "channels_per_head": channels_per_head,
    "dropout": dropout,
    "res_balance": res_balance,
    "attn_balance": attn_balance,
    "clip_act": clip_act,
  }

  @nn.compact
  def __call__(self, x, noise_labels, class_labels):
    cblock = [self.model_channels * x for x in self.channel_mult]
    cnoise = (
      self.model_channels * self.channel_mult_noise
      if self.channel_mult_noise is not None
      else cblock[0]
    )
    cemb = (
      self.model_channels * self.channel_mult_emb
      if self.channel_mult_emb is not None
      else max(cblock)
    )

    if self.out_gain is None:
      out_gain = self.param("out_gain", jax.nn.initializers.zeros, (1,))
    else:
      out_gain = self.out_gain

    # Encoder
    enc = {}
    cout = self.img_channels + 1
    for level, channels in enumerate(cblock):
      res = self.img_resolution >> level
      if level == 0:
        cin = cout
        cout = channels
        enc[f"{res}x{res}_conv"] = MPConv(
          cin, cout, kernel_shape=(3, 3), name=f"enc_{res}x{res}_conv"
        )
      else:
        enc[f"{res}x{res}_down"] = Block(
          cout,
          cout,
          cemb,
          flavor="enc",
          resample_mode="down",
          name=f"enc_{res}x{res}_down",
          **self.block_kwargs,
        )
      for idx in range(self.num_blocks):
        cin = cout
        cout = channels
        enc[f"{res}x{res}_block{idx}"] = Block(
          cin,
          cout,
          cemb,
          flavor="enc",
          attention=(res in self.attn_resolutions),
          name=f"enc_{res}x{res}_block{idx}",
          **self.block_kwargs,
        )

    # Decoder
    dec = {}
    skips = [block.out_channels for block in enc.values()]
    for level, channels in reversed(list(enumerate(cblock))):
      res = self.img_resolution >> level
      if level == len(cblock) - 1:
        dec[f"{res}x{res}_in0"] = Block(
          cout,
          cout,
          cemb,
          flavor="dec",
          attention=True,
          name=f"dec_{res}x{res}_in0",
          **self.block_kwargs,
        )
        dec[f"{res}x{res}_in1"] = Block(
          cout,
          cout,
          cemb,
          flavor="dec",
          name=f"dec_{res}x{res}_in1",
          **self.block_kwargs,
        )
      else:
        dec[f"{res}x{res}_up"] = Block(
          cout,
          cout,
          cemb,
          flavor="dec",
          resample_mode="up",
          name=f"dec_{res}x{res}_up",
          **self.block_kwargs,
        )
      for idx in range(self.num_blocks + 1):
        cin = cout + skips.pop()
        cout = channels
        dec[f"{res}x{res}_block{idx}"] = Block(
          cin,
          cout,
          cemb,
          flavor="dec",
          attention=(res in self.attn_resolutions),
          name=f"dec_{res}x{res}_block{idx}",
          **self.block_kwargs,
        )

    # Embedding
    emb = MPConv(cnoise, cemb, kernel_shape=(), name="emb_noise")(
      MPFourier(cnoise, name="emb_fourier")(noise_labels)
    )
    if self.label_dim != 0:
      emb = mp_sum(
        emb,
        MPConv(self.label_dim, cemb, kernel_shape=(), name="emb_label")(
          class_labels * jnp.sqrt(class_labels.shape[1])
        ),
        t=self.label_balance,
      )
    emb = mp_silu(emb)

    # Encoder
    x = jax.lax.concatenate([x, jnp.ones_like(x[:, :1])], dimension=1)
    skips = []
    for name, block in enc.items():
      x = block(x) if "conv" in name else block(x, emb)
      skips.append(x)

    # Decoder
    for name, block in dec.items():
      if "block" in name:
        x = mp_cat(x, skips.pop(), t=self.concat_balance)
      x = block(x, emb)
    x = MPConv(cout, self.img_channels, kernel_shape=(3, 3), name="out_conv")(
      x, gain=out_gain
    )
    return x


class Precond(nn.Module):
  """Preconditioning and uncertainty estimation."""

  img_resolution: int  # Image resolution.
  img_channels: int  # Image channels.
  label_dim: int  # Class label dimensionality. 0 = unconditional.
  # **precond_kwargs
  use_fp16: bool = True  # Run the model at FP16 precision?
  sigma_data: float = 0.5  # Expected standard deviation of the training data.
  logvar_channels: int = 128  # Intermediate dimensionality for uncertainty estimation.
  return_logvar: bool = False
  # **unet_kwargs  # Keyword arguments for UNet.
  model_channels: int = 192  # Base multiplier for the number of channels.
  channel_mult: tuple = (
    1,
    2,
    3,
    4,
  )  # Per-resolution multipliers for the number of channels.
  channel_mult_noise: Any = None  # Multiplier for noise embedding dimensionality. None = select based on channel_mult.
  channel_mult_emb: Any = None  # Multiplier for final embedding dimensionality. None = select based on channel_mult.
  num_blocks: int = 3  # Number of residual blocks per resolution.
  attn_resolutions: tuple = (16, 8)  # List of resolutions with self-attention.
  label_balance: float = (
    0.5  # Balance between noise embedding (0) and class embedding (1).
  )
  concat_balance: float = 0.5  # Balance between skip connections (0) and main path (1).
  out_gain: float = 1.0
  unet_kwargs = {
    "model_channels": model_channels,
    "channel_mult": channel_mult,
    "channel_mult_noise": channel_mult_noise,
    "channel_mult_emb": channel_mult_emb,
    "num_blocks": num_blocks,
    "attn_resolutions": attn_resolutions,
    "label_balance": label_balance,
    "concat_balance": concat_balance,
    "out_gain": out_gain,
  }

  # **block_kwargs  # Keyword arguments for Block
  resample_filter: tuple = (1, 1)  # Resampling filter.
  channels_per_head: int = 64  # Number of channels per attention head.
  dropout: float = 0.0  # Dropout probability.
  res_balance: float = 0.3  # Balance between main branch (0) and residual branch (1).
  attn_balance: float = 0.3  # Balance between main branch (0) and self-attention (1).
  clip_act: int = 256  # Clip output activations. None = do not clip.
  block_kwargs = {
    "resample_filter": resample_filter,
    "channels_per_head": channels_per_head,
    "dropout": dropout,
    "res_balance": res_balance,
    "attn_balance": attn_balance,
    "clip_act": clip_act,
  }

  @nn.compact
  def __call__(
    self,
    x,
    sigma,
    class_labels=None,
    force_fp32=False,
  ):
    x = jnp.float32(x)
    sigma = jnp.float32(sigma).reshape(-1, 1, 1, 1)
    class_labels = (
      None
      if self.label_dim == 0
      else jnp.zeros((1, self.label_dim), device=x.device)
      if class_labels is None
      else jnp.float32(class_labels).reshape(-1, self.label_dim)
    )
    dtype = jnp.float16 if (self.use_fp16 and not force_fp32) else jnp.float32

    # Preconditioning weights
    c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
    c_out = sigma * self.sigma_data / jnp.sqrt(sigma**2 + self.sigma_data**2)
    c_in = 1 / jnp.sqrt(self.sigma_data**2 + sigma**2)
    c_noise = jnp.log(sigma.flatten()) / 4

    # Run the model
    x_in = jnp.array(c_in * x, dtype=dtype)

    F_x = UNet(
      img_resolution=self.img_resolution,
      img_channels=self.img_channels,
      label_dim=self.label_dim,
      **self.unet_kwargs,
      **self.block_kwargs,
      name="unet",
    )(x_in, c_noise, class_labels)
    D_x = c_skip * x + c_out * jnp.float32(F_x)

    # Estimate uncertainty if requested
    if self.return_logvar:
      logvar = MPConv(self.logvar_channels, 1, kernel_shape=(), name="logvar_linear")(
        MPFourier(self.logvar_channels, name="logvar_fourier")(c_noise)
      ).reshape(-1, 1, 1, 1)
      return D_x, logvar  # u(sigma) in Equation 21
    return D_x
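

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the NVlabs port, and untested against it):
# Precond wraps UNet with the EDM preconditioning
#   D_x = c_skip * x + c_out * F(c_in * x, c_noise),
# as implemented above. The tiny, arbitrary hyperparameters below are
# assumptions chosen only to keep the example cheap to run, not recommended
# values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  key = jax.random.PRNGKey(0)
  unet = UNet(
    img_resolution=16,
    img_channels=3,
    label_dim=0,  # unconditional
    model_channels=16,
    channel_mult=(1, 2),
    num_blocks=1,
  )
  x = jax.random.normal(key, (1, 3, 16, 16))  # (N, C, H, W)
  noise_labels = jnp.zeros((1,))  # one noise level per sample
  variables = unet.init(key, x, noise_labels, None)
  out = unet.apply(variables, x, noise_labels, None)
  print(out.shape)  # expected: (1, 3, 16, 16)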