• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ControlNet / MARLIN / 9011167809

09 May 2024 01:56AM CUT coverage: 65.763%. First build
9011167809

Pull #25

github

web-flow
Merge 7c160e431 into 23491494b
Pull Request #25: hotfix: add lightning dependency for _cosine_scheduler_fn(..)

461 of 701 relevant lines covered (65.76%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.24
/src/marlin_pytorch/model/positional_embedding.py
1
import torch
1✔
2
from torch import Tensor, nn
1✔
3

4
from .modules import Shape
1✔
5

6

7
class PositionalEmbedding(nn.Module):
    """Additive positional embedding with optional dropout.

    Owns a zero-initialised parameter ``emb`` of shape ``(1, *input_shape)``
    that :meth:`forward` adds to its input. The leading axis of size 1 lets the
    addition broadcast over the first dimension of ``x`` (presumably the batch
    axis -- confirm against callers).

    Args:
        input_shape: Shape of the embedding, excluding the leading size-1 axis.
        dropout_rate: Dropout probability applied after the addition. ``None``
            or ``0`` disables dropout, and then no ``nn.Dropout`` submodule is
            created at all.
        trainable: Initial ``requires_grad`` flag of the embedding parameter.
    """

    def __init__(self, input_shape: Shape, dropout_rate: float = 0.5, trainable: bool = True):
        super().__init__()
        self.input_shape = input_shape
        initial = torch.zeros((1, *input_shape))
        self.emb = nn.Parameter(initial, requires_grad=trainable)

        # A missing (None) or zero rate means "no dropout".
        self.use_dropout = not (dropout_rate is None or dropout_rate == 0.)
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x + emb``, passed through dropout when it is enabled."""
        shifted = x + self.emb
        return self.dropout(shifted) if self.use_dropout else shifted

    @property
    def trainable(self):
        """Whether the embedding parameter currently receives gradients."""
        return self.emb.requires_grad

    @trainable.setter
    def trainable(self, value: bool):
        # Toggle gradient tracking on the embedding parameter in place.
        self.emb.requires_grad = value
30

31

32
class SinCosPositionalEmbedding(PositionalEmbedding):
    """Fixed (non-trainable) sinusoidal positional embedding.

    For position ``pos`` and feature index ``j``, the angle is
    ``pos / 10000 ** (2 * (j // 2) / d_hid)``. Even feature indices hold the
    ``sin`` of that angle and odd indices hold the ``cos``.

    Args:
        input_shape: Must contain exactly two entries ``(n_position, d_hid)``,
            because :meth:`make_embedding` unpacks it into those two values.
        dropout_rate: Forwarded unchanged to :class:`PositionalEmbedding`.
    """

    def __init__(self, input_shape: Shape, dropout_rate: float = 0.5):
        super().__init__(input_shape, dropout_rate, trainable=False)
        # Fill the zero-initialised parameter in place instead of rebinding
        # ``.data``. Both sides have shape (1, n_position, d_hid), and the
        # parameter keeps its own dtype/device.
        with torch.no_grad():
            self.emb.copy_(self.make_embedding().unsqueeze(0))

    def make_embedding(self) -> Tensor:
        """Build the ``(n_position, d_hid)`` sinusoid table as float32.

        The table is computed with one broadcast: a ``(n_position, 1)`` column
        of positions times a ``(d_hid,)`` row of inverse frequencies. This
        replaces one tensor op per position, and it also handles
        ``n_position == 0``, where ``torch.stack`` would fail on an empty list.
        """
        n_position, d_hid = self.input_shape

        # 2 * (j // 2) / d_hid maps each pair of features (2i, 2i+1) to the same frequency.
        exponent = 2 * torch.div(torch.arange(d_hid), 2, rounding_mode='trunc') / d_hid
        # Multiply by the reciprocal, matching the original ``position / tensor``
        # (which PyTorch evaluates as ``tensor.reciprocal() * position``).
        inv_freq = torch.tensor(10000).pow(exponent).reciprocal()
        positions = torch.arange(n_position).unsqueeze(1)
        sinusoid_table = positions * inv_freq  # (n_position, d_hid)

        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return sinusoid_table.float()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc