georgeberry / blayers / build 17146046591
Pull Request #23: beta 3 ("update", via github)

22 Aug 2025 04:25AM UTC coverage: 92.507% (-7.5%) from 100.0%

6 of 32 new or added lines in 1 file covered. (18.75%)
1 existing line in 1 file now uncovered.
321 of 347 relevant lines covered (92.51%)
0.93 hits per line

Source File: /blayers/layers.py (86.39% covered)
"""
Implements Bayesian Layers using Jax and Numpyro.

Design:
  - There are three levels of complexity here: class-level, instance-level,
    and call-level.
  - The class-level handles things like choosing the generic model form and
    how to multiply coefficients with data. Defined by the
    ``class Layer(BLayer)`` definition itself.
  - The instance-level handles the specific distributions that fit into a
    generic model and the initial parameters for those distributions. Defined
    by creating an instance of the class: ``Layer(*args, **kwargs)``.
  - The call-level handles seeing a batch of data, sampling from the
    distributions defined on the class, and multiplying coefficients and data
    to produce an output. Works like
    ``result = Layer(*args, **kwargs)(name, data)``.

Notation:
  - ``n``: observations in a batch
  - ``c``: number of categories of things for time, random effects, etc.
  - ``d``: number of coefficients
  - ``l``: low rank dimension of low rank models
  - ``m``: embedding dimension
  - ``u``: units aka output dimension
"""

from abc import ABC, abstractmethod
from typing import Any

import jax
import jax.numpy as jnp
from numpyro import distributions, sample

from blayers._utils import add_trailing_dim

# ---- Matmul functions ------------------------------------------------------ #


def pairwise_interactions(x: jax.Array, z: jax.Array) -> jax.Array:
    """
    Compute all pairwise interactions between features in ``x`` and ``z``.

    Args:
        x: Input matrix of shape `(n, d1)`.
        z: Input matrix of shape `(n, d2)`.

    Returns:
        jax.Array: Interactions of shape `(n, d1 * d2)`.
    """

    n, d1 = x.shape
    _, d2 = z.shape
    return jnp.reshape(x[:, :, None] * z[:, None, :], (n, d1 * d2))


def _matmul_dot_product(x: jax.Array, beta: jax.Array) -> jax.Array:
    """Standard dot product between x and beta.

    Args:
        x: Input matrix of shape `(n, d)`.
        beta: Coefficient matrix of shape `(d, u)`.

    Returns:
        jax.Array: Output of shape `(n, u)`.
    """
    return jnp.einsum("nd,du->nu", x, beta)


def _matmul_factorization_machine(x: jax.Array, theta: jax.Array) -> jax.Array:
    """Apply second-order factorization machine interaction.

    Based on Rendle (2010). Computes:

    .. math::
        0.5 * sum((xV)^2 - (x^2 V^2))

    Args:
        x: Input data of shape `(n, d)`.
        theta: Weight matrix of shape `(d, l, u)`.

    Returns:
        jax.Array: Output of shape `(n, u)`.
    """
    vx2 = jnp.einsum("nd,dlu->nlu", x, theta) ** 2
    v2x2 = jnp.einsum("nd,dlu->nlu", x**2, theta**2)
    return 0.5 * jnp.einsum("nlu->nu", vx2 - v2x2)
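
# Sanity-check sketch (hypothetical, not part of the library): the Rendle
# identity above equals the explicit sum of pairwise products
# sum_{i<j} <v_i, v_j> x_i x_j. For a single unit:
#
#     key = jax.random.PRNGKey(0)
#     x = jax.random.normal(key, (4, 3))
#     theta = jax.random.normal(key, (3, 2, 1))
#     fast = _matmul_factorization_machine(x, theta)
#     slow = sum(
#         (theta[i, :, 0] @ theta[j, :, 0]) * x[:, i] * x[:, j]
#         for i in range(3) for j in range(i + 1, 3)
#     )
#     assert jnp.allclose(fast[:, 0], slow, atol=1e-5)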


def _matmul_fm3(x: jax.Array, theta: jax.Array) -> jax.Array:
    """Apply third-order factorization machine interaction.

    Based on Blondel et al. (2016). Computes:

    .. math::
        ((xV)^3 - 3 (x^2 V^2)(xV) + 2 (x^3 V^3)) / 6

    Args:
        x: Input data of shape `(n, d)`.
        theta: Weight matrix of shape `(d, l, u)`.

    Returns:
        jax.Array: Output of shape `(n, u)`.
    """
    linear_sum = jnp.einsum("nd,dlu->nlu", x, theta)
    square_sum = jnp.einsum("nd,dlu->nlu", x**2, theta**2)
    cube_sum = jnp.einsum("nd,dlu->nlu", x**3, theta**3)

    term = (
        linear_sum**3 - 3.0 * square_sum * linear_sum + 2.0 * cube_sum
    ) / 6.0
    # sum over the low-rank dimension, leaving shape (n, u)
    return jnp.einsum("nlu->nu", term)


def _matmul_uv_decomp(
    theta1: jax.Array,
    theta2: jax.Array,
    x: jax.Array,
    z: jax.Array,
) -> jax.Array:
    """Implements low rank multiplication.

    According to ChatGPT this is a "factorized bilinear interaction".
    Basically, you just need to project x and z down to a common number of
    low rank terms and then multiply those terms.

    This is equivalent to a UV decomposition where you use n=low_rank_dim
    on the columns of the U/V matrices.

    Args:
        theta1: Weight matrix of shape `(d1, l, u)`.
        theta2: Weight matrix of shape `(d2, l, u)`.
        x: Input data of shape `(n, d1)`.
        z: Input data of shape `(n, d2)`.

    Returns:
        jax.Array: Output of shape `(n, u)`.
    """
    xb = jnp.einsum("nd,dlu->nlu", x, theta1)
    zb = jnp.einsum("nd,dlu->nlu", z, theta2)
    return jnp.einsum("nlu->nu", xb * zb)


def _matmul_randomwalk(
    theta: jax.Array,
    idx: jax.Array,
) -> jax.Array:
    """Vertical cumsum, then picks out rows by index.

    We do a vertical cumsum of ``theta`` down its ``c`` category rows, and
    then pick out the row for each index.

    Args:
        theta: Weight matrix of shape `(c, m)`.
        idx: Integer indexes of shape `(n, 1)` or `(n,)` with indexes up to `c`.

    Returns:
        jax.Array: Output of shape `(n, m)`.
    """
    theta_cumsum = jnp.cumsum(theta, axis=0)
    idx_flat = idx.squeeze().astype(jnp.int32)
    return theta_cumsum[idx_flat]
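
# Illustrative sketch (hypothetical values): with c=3 categories and m=1, the
# cumulative sum turns per-step increments into a random-walk level:
#
#     theta = jnp.array([[1.0], [2.0], [3.0]])   # per-step increments
#     idx = jnp.array([0, 2])
#     _matmul_randomwalk(theta, idx)             # [[1.0], [6.0]]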


def _matmul_interaction(
    beta: jax.Array,
    x: jax.Array,
    z: jax.Array,
) -> jax.Array:
    """Full interaction between `x` and `z`.

    Args:
        beta: Weight matrix of shape `(d1 * d2, u)`, one row per interaction
            between `x` and `z`.
        x: First feature matrix of shape `(n, d1)`.
        z: Second feature matrix of shape `(n, d2)`.

    Returns:
        jax.Array: Output of shape `(n, u)`.
    """

    # thanks chat GPT
    interactions = pairwise_interactions(x, z)

    return jnp.einsum("nd,du->nu", interactions, beta)


# ---- Classes --------------------------------------------------------------- #


class BLayer(ABC):
    """Abstract base class for Bayesian layers. Lays out an interface."""

    @abstractmethod
    def __init__(self, *args: Any) -> None:
        """Initialize layer parameters."""

    @abstractmethod
    def __call__(self, *args: Any) -> Any:
        """
        Run the layer's forward pass.

        Args:
            name: Name scope for sampled variables. Note: due to mypy stuff we
                only write the ``name`` arg explicitly in subclasses.
            *args: Inputs to the layer.

        Returns:
            jax.Array: The result of the forward computation.
        """


class AdaptiveLayer(BLayer):
    """Bayesian layer with adaptive prior using hierarchical modeling.

    Generates coefficients from the hierarchical model

    .. math::
        \\lambda \\sim HalfNormal(1.)

    .. math::
        \\beta \\sim Normal(0., \\lambda)
    """

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: NumPyro distribution class for the scale (λ) of the
                prior.
            coef_dist: NumPyro distribution class for the coefficient prior.
            coef_kwargs: Parameters for the prior distribution.
            lmbda_kwargs: Parameters for the scale distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
    ) -> jax.Array:
        """
        Forward pass with adaptive prior on coefficients.

        Args:
            name: Variable name scope.
            x: Input data array of shape (n, d).

        Returns:
            jax.Array: Output array of shape (n, u).
        """

        x = add_trailing_dim(x)
        input_shape = x.shape[1]

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape, self.units]
            ),
        )

        # matmul and return
        return _matmul_dot_product(x, beta)
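
# Usage sketch (hypothetical model; assumes the standard numpyro.infer API):
#
#     from numpyro.infer import MCMC, NUTS
#
#     def model(x, y=None):
#         mu = AdaptiveLayer()("x", x)
#         sigma = sample("sigma", distributions.HalfNormal(1.0))
#         sample("y", distributions.Normal(mu.squeeze(-1), sigma), obs=y)
#
#     mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
#     mcmc.run(jax.random.PRNGKey(0), x, y)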


class FixedPriorLayer(BLayer):
    """Bayesian layer with a fixed prior distribution over coefficients.

    Generates coefficients from the model

    .. math::

        \\beta \\sim Normal(0., 1.)
    """

    def __init__(
        self,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0, "scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            coef_dist: NumPyro distribution class for the coefficients.
            coef_kwargs: Parameters to initialize the prior distribution.
            units: The number of outputs.
        """
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
    ) -> jax.Array:
        """
        Forward pass with fixed prior.

        Args:
            name: Variable name prefix.
            x: Input data array of shape (n, d).

        Returns:
            jax.Array: Output array of shape (n, u).
        """

        x = add_trailing_dim(x)
        input_shape = x.shape[1]

        # sampling block
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(**self.coef_kwargs).expand(
                [input_shape, self.units]
            ),
        )
        # matmul and return
        return _matmul_dot_product(x, beta)


class InterceptLayer(BLayer):
    """Bayesian intercept layer with a fixed prior distribution.

    Generates the intercept from the model

    .. math::

        \\beta \\sim Normal(0., 1.)
    """

    def __init__(
        self,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0, "scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            coef_dist: NumPyro distribution class for the coefficients.
            coef_kwargs: Parameters to initialize the prior distribution.
            units: The number of outputs.
        """
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
    ) -> jax.Array:
        """
        Forward pass with fixed prior.

        Args:
            name: Variable name prefix.

        Returns:
            jax.Array: Output array of shape (1, u).
        """

        # sampling block
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(**self.coef_kwargs).expand([1, self.units]),
        )
        return beta


class EmbeddingLayer(BLayer):
    """Bayesian embedding layer for sparse categorical features."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: NumPyro distribution class for the scale (λ) of the
                prior.
            coef_dist: Prior distribution for embedding weights.
            coef_kwargs: Parameters for the prior distribution.
            lmbda_kwargs: Parameters for the scale distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_categories: int,
        embedding_dim: int,
    ) -> jax.Array:
        """
        Forward pass through embedding lookup.

        Args:
            name: Variable name scope.
            x: Integer indices of shape (n,) indicating embeddings to use.
            num_categories: The number of distinct things getting an embedding.
            embedding_dim: The size of each embedding, e.g. 2, 4, 8, etc.

        Returns:
            jax.Array: Embedding vectors of shape (n, m).
        """

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs),
        )
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [num_categories, embedding_dim]
            ),
        )
        # matmul and return
        return beta[x.squeeze()]
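
# Usage sketch (hypothetical data; indices must be integers in [0, c)):
#
#     def model(user_ids):
#         # one 8-dimensional embedding per user, c=100 users
#         emb = EmbeddingLayer()(
#             "user", user_ids, num_categories=100, embedding_dim=8
#         )  # shape (n, 8)
#         ...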


class RandomEffectsLayer(BLayer):
    """Exactly like the EmbeddingLayer but with ``embedding_dim=1``."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: NumPyro distribution class for the scale (λ) of the
                prior.
            coef_dist: Prior distribution for the random effects.
            coef_kwargs: Parameters for the prior distribution.
            lmbda_kwargs: Parameters for the scale distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_categories: int,
    ) -> jax.Array:
        """
        Forward pass through embedding lookup.

        Args:
            name: Variable name scope.
            x: Integer indices of shape (n,) indicating embeddings to use.
            num_categories: The number of distinct things getting an embedding.

        Returns:
            jax.Array: Random effects of shape (n, 1).
        """

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs),
        )
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [num_categories, 1]
            ),
        )
        return beta[x.squeeze()]


class FMLayer(BLayer):
    """Bayesian factorization machine layer with adaptive priors.

    Generates coefficients from the hierarchical model

    .. math::

        \\lambda \\sim HalfNormal(1.)

    .. math::

        \\beta \\sim Normal(0., \\lambda)

    The shape of ``beta`` is ``(d, l)``, where ``d`` is the number
    of input covariates and ``l`` is the low rank dim.

    Then performs matrix multiplication using the formula in `Rendle (2010) <https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf>`_.
    """

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: Distribution for scaling factor λ.
            coef_dist: Prior for beta parameters.
            coef_kwargs: Arguments for prior distribution.
            lmbda_kwargs: Arguments for λ distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        low_rank_dim: int,
    ) -> jax.Array:
        """
        Forward pass through the factorization machine layer.

        Args:
            name: Variable name scope.
            x: Input matrix of shape (n, d).
            low_rank_dim: Dimensionality of the low-rank approximation.

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        # get shapes and reshape if necessary
        x = add_trailing_dim(x)
        input_shape = x.shape[1]

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        theta = sample(
            name=f"{self.__class__.__name__}_{name}_theta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape, low_rank_dim, self.units]
            ),
        )
        # matmul and return
        return _matmul_factorization_machine(x, theta)
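
# Usage sketch (hypothetical model): second-order interactions with a rank-4
# approximation, combined with a linear term:
#
#     def model(x, y=None):
#         mu = (
#             AdaptiveLayer()("linear", x)
#             + FMLayer()("fm", x, low_rank_dim=4)
#         )
#         sigma = sample("sigma", distributions.HalfNormal(1.0))
#         sample("y", distributions.Normal(mu.squeeze(-1), sigma), obs=y)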


class FM3Layer(BLayer):
    """Order 3 FM. See `Blondel et al 2016 <https://proceedings.neurips.cc/paper/2016/file/158fc2ddd52ec2cf54d3c161f2dd6517-Paper.pdf>`_."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: Distribution for scaling factor λ.
            coef_dist: Prior for beta parameters.
            coef_kwargs: Arguments for prior distribution.
            lmbda_kwargs: Arguments for λ distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        low_rank_dim: int,
    ) -> jax.Array:
        """
        Forward pass through the third-order factorization machine layer.

        Args:
            name: Variable name scope.
            x: Input matrix of shape (n, d).
            low_rank_dim: Dimensionality of the low-rank approximation.

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        # get shapes and reshape if necessary
        x = add_trailing_dim(x)
        input_shape = x.shape[1]

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        theta = sample(
            name=f"{self.__class__.__name__}_{name}_theta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape, low_rank_dim, self.units]
            ),
        )
        # matmul and return
        return _matmul_fm3(x, theta)


class LowRankInteractionLayer(BLayer):
    """Takes two sets of features and learns a low-rank interaction matrix."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        z: jax.Array,
        low_rank_dim: int,
    ) -> jax.Array:
        """
        Forward pass through the low-rank interaction layer.

        Args:
            name: Variable name scope.
            x: Input matrix of shape (n, d1).
            z: Input matrix of shape (n, d2).
            low_rank_dim: Dimensionality of the low-rank approximation.

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        # get shapes and reshape if necessary
        x = add_trailing_dim(x)
        z = add_trailing_dim(z)
        input_shape1 = x.shape[1]
        input_shape2 = z.shape[1]

        # sampling block
        lmbda1 = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda1",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        theta1 = sample(
            name=f"{self.__class__.__name__}_{name}_theta1",
            fn=self.coef_dist(scale=lmbda1, **self.coef_kwargs).expand(
                [input_shape1, low_rank_dim, self.units]
            ),
        )
        lmbda2 = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda2",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        theta2 = sample(
            name=f"{self.__class__.__name__}_{name}_theta2",
            fn=self.coef_dist(scale=lmbda2, **self.coef_kwargs).expand(
                [input_shape2, low_rank_dim, self.units]
            ),
        )
        return _matmul_uv_decomp(theta1, theta2, x, z)


class RandomWalkLayer(BLayer):
    """Random walk of embedding dim ``m``, defaults to Gaussian walk."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_categories: int,
        embedding_dim: int,
    ) -> jax.Array:
        """
        Forward pass through the random walk lookup.

        Args:
            name: Variable name scope.
            x: Integer indices of shape (n,) or (n, 1) indicating positions
                in the walk.
            num_categories: The number of steps in the walk.
            embedding_dim: The size of the embedding at each step.

        Returns:
            jax.Array: Output of shape (n, m).
        """

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs),
        )
        theta = sample(
            name=f"{self.__class__.__name__}_{name}_theta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [
                    num_categories,
                    embedding_dim,
                ]
            ),
        )
        # matmul and return
        return _matmul_randomwalk(theta, x)


class InteractionLayer(BLayer):
    """Computes every interaction coefficient between two sets of inputs."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        z: jax.Array,
    ) -> jax.Array:
        """
        Forward pass through the full interaction layer.

        Args:
            name: Variable name scope.
            x: Input matrix of shape (n, d1).
            z: Input matrix of shape (n, d2).

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        # get shapes and reshape if necessary
        x = add_trailing_dim(x)
        z = add_trailing_dim(z)
        input_shape1 = x.shape[1]
        input_shape2 = z.shape[1]

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda1",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta1",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape1 * input_shape2, self.units]
            ),
        )

        return _matmul_interaction(beta, x, z)


class BilinearLayer(BLayer):
    """Bayesian bilinear interaction layer: computes x^T W z."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: prior on scale of coefficients
            coef_dist: distribution for coefficients
            coef_kwargs: kwargs for coef distribution
            lmbda_kwargs: kwargs for scale prior
            units: number of output dimensions
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(self, name: str, x: jax.Array, z: jax.Array) -> jax.Array:
        # ensure inputs are [batch, dim]
        x = add_trailing_dim(x)
        z = add_trailing_dim(z)
        input_shape1, input_shape2 = x.shape[1], z.shape[1]

        # sample coefficient scales
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        # full W: [input_shape1, input_shape2, units]
        W = sample(
            name=f"{self.__class__.__name__}_{name}_W",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape1, input_shape2, self.units]
            ),
        )
        # bilinear form: x^T W z for each unit
        return jnp.einsum("bi,ijc,bj->bc", x, W, z)


class LowRankBilinearLayer(BLayer):
    """Bayesian bilinear interaction layer: computes x^T W z. W low rank."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: prior on scale of coefficients
            coef_dist: distribution for coefficients
            coef_kwargs: kwargs for coef distribution
            lmbda_kwargs: kwargs for scale prior
            units: number of output dimensions
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self, name: str, x: jax.Array, z: jax.Array, low_rank_dim: int
    ) -> jax.Array:
        # ensure inputs are [batch, dim]
        x = add_trailing_dim(x)
        z = add_trailing_dim(z)
        input_shape1, input_shape2 = x.shape[1], z.shape[1]

        # sample coefficient scales
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )

        A = sample(
            name=f"{self.__class__.__name__}_{name}_A",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape1, low_rank_dim, self.units]
            ),
        )
        B = sample(
            name=f"{self.__class__.__name__}_{name}_B",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape2, low_rank_dim, self.units]
            ),
        )
        # project x and z into rank-r space, then take dot product
        x_proj = jnp.einsum("bi,irk->brk", x, A)  # [batch, rank, units]
        z_proj = jnp.einsum("bj,jrk->brk", z, B)  # [batch, rank, units]
        out = jnp.sum(x_proj * z_proj, axis=1)  # [batch, units]

        return out
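
# Note on the low-rank form (a sketch, not library API): the rank-r product
# above is equivalent to a full bilinear form with, per unit k,
# W_full[i, j, k] = sum_r A[i, r, k] * B[j, r, k]:
#
#     W_full = jnp.einsum("irk,jrk->ijk", A, B)
#     # jnp.einsum("bi,ijk,bj->bk", x, W_full, z) matches `out` above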