
skinniderlab / CLM / build 20309855018
17 Dec 2025 04:27PM UTC. Coverage: 41.938% (-21.7%) from 63.664%.

Pull Request #276: updated black version in pre-commit (merge 5453ac0d2 into 98d7c449a, via GitHub / web-flow)

30 of 1579 new or added lines in 20 files covered (1.9%).
7 existing lines in 4 files are now uncovered.
1883 of 4490 relevant lines covered (41.94%).
0.42 hits per line.

Source file: /src/clm/module_library/sequence_model.py (0.0% covered; every new or added line below is uncovered)
from functools import partial
from typing import Sequence, Mapping
import torch
import torch.nn as nn
from einops import rearrange

from .sequence_residual_block import SequenceResidualBlock
from .sequence_module import SequenceModule
from .util_modules import Normalization, DropoutNd


def is_list(x):
    return isinstance(x, Sequence) and not isinstance(x, str)


def is_dict(x):
    return isinstance(x, Mapping)


def to_dict(x, recursive=True):
    """Convert a Sequence or Mapping object to a dict.

    Lists get converted to {0: x[0], 1: x[1], ...}.
    """
    if is_list(x):
        x = {i: v for i, v in enumerate(x)}
    if is_dict(x):
        if recursive:
            return {k: to_dict(v, recursive=recursive) for k, v in x.items()}
        else:
            return dict(x)
    else:
        return x


def to_list(x, recursive=False):
    """Convert an object to a list.

    If x is a Sequence (e.g. list, tuple, ListConfig), return it as a list.

    Special case: if non-recursive and x is not a list, wrap it in a list.
    """
    if is_list(x):
        if recursive:
            return [to_list(_x) for _x in x]
        else:
            return list(x)
    else:
        if recursive:
            return x
        else:
            return [x]
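
A minimal usage sketch of the conversion helpers above (assumes the functions are in scope; the values are illustrative only):

    assert to_list(3) == [3]
    assert to_list((1, 2)) == [1, 2]
    assert to_dict(["a", "b"]) == {0: "a", 1: "b"}
    assert to_dict({"k": [1, 2]}, recursive=True) == {"k": {0: 1, 1: 2}}
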
class SequenceModel(SequenceModule):
    def __init__(
        self,
        d_model,  # Resize input (useful for deep models with residuals)
        n_layers=1,  # Number of layers
        transposed=False,  # Transpose inputs so each layer receives (batch, dim, length)
        dropout=0.0,  # Dropout parameter applied on every residual and every layer
        tie_dropout=False,  # Tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d
        prenorm=True,  # Pre-norm vs. post-norm
        n_repeat=1,  # Each layer is repeated n times per stage before applying pooling
        layer=None,  # Layer config, must be specified
        residual="R",  # Residual config (default changed from None to "R")
        norm="layer",  # Normalization config, e.g. layer vs. batch (default changed from None to "layer")
        pool=None,  # Config for pooling layer per stage
        track_norms=False,  # Log norms of each layer output (default changed from True to False)
        dropinp=0.0,  # Input dropout
    ):
        super().__init__()
        # Save arguments needed for forward pass
        self.d_model = d_model
        self.transposed = transposed
        self.track_norms = track_norms

        # Input dropout (not really used)
        dropout_fn = (
            partial(DropoutNd, transposed=self.transposed)
            if tie_dropout
            else nn.Dropout
        )
        self.drop = dropout_fn(dropinp) if dropinp > 0.0 else nn.Identity()

        layer = to_list(layer, recursive=False)

        # Some special arguments are passed into each layer
        for _layer in layer:
            # If layers don't specify dropout, add it
            if _layer.get("dropout", None) is None:
                _layer["dropout"] = dropout
            # Ensure all layers are shaped the same way
            _layer["transposed"] = transposed

        # Duplicate layers
        layers = layer * n_layers * n_repeat

        # Instantiate layers
        _layers = []
        d = d_model
        for layer_idx, layer in enumerate(layers):
            # Pool at the end of every n_repeat blocks
            pool_cfg = pool if (layer_idx + 1) % n_repeat == 0 else None
            block = SequenceResidualBlock(
                d,
                layer_idx + 1,
                prenorm=prenorm,
                dropout=dropout,
                tie_dropout=tie_dropout,
                transposed=transposed,
                layer_config=layer,
                residual=residual,
                norm=norm,
                pool=pool_cfg,
            )
            _layers.append(block)
            d = block.d_output

        self.d_output = d
        self.layers = nn.ModuleList(_layers)
        if prenorm:
            if norm is None:
                self.norm = None
            elif isinstance(norm, str):
                self.norm = Normalization(
                    self.d_output, transposed=self.transposed, _name_=norm
                )
            else:
                self.norm = Normalization(
                    self.d_output, transposed=self.transposed, **norm
                )
        else:
            self.norm = nn.Identity()

    def forward(self, inputs, *args, state=None, **kwargs):
        """Inputs assumed to be (batch, sequence, dim)"""
        if self.transposed:
            inputs = rearrange(inputs, "b ... d -> b d ...")
        inputs = self.drop(inputs)

        # Track norms
        if self.track_norms:
            output_norms = [torch.mean(inputs.detach() ** 2)]

        # Apply layers
        outputs = inputs
        prev_states = [None] * len(self.layers) if state is None else state
        next_states = []
        for layer, prev_state in zip(self.layers, prev_states):
            outputs, state = layer(outputs, *args, state=prev_state, **kwargs)
            next_states.append(state)
            if self.track_norms:
                output_norms.append(torch.mean(outputs.detach() ** 2))
        if self.norm is not None:
            outputs = self.norm(outputs)

        if self.transposed:
            outputs = rearrange(outputs, "b d ... -> b ... d")

        if self.track_norms:
            metrics = to_dict(output_norms, recursive=False)
            self.metrics = {f"norm/{i}": v for i, v in metrics.items()}

        return outputs, next_states

    @property
    def d_state(self):
        # Total state dimension across all layers (layers without state contribute nothing)
        d_states = [layer.d_state for layer in self.layers]
        return sum([d for d in d_states if d is not None])

    @property
    def state_to_tensor(self):
        # Slightly hacky way to implement this in a curried manner (so that the function can be extracted from an instance)
        # Somewhat more sound may be to turn this into a @staticmethod and grab subclasses using hydra.utils.get_class
        def fn(state):
            x = [
                _layer.state_to_tensor(_state)
                for (_layer, _state) in zip(self.layers, state)
            ]
            x = [_x for _x in x if _x is not None]
            return torch.cat(x, dim=-1)

        return fn

    def default_state(self, *batch_shape, device=None):
        return [
            layer.default_state(*batch_shape, device=device) for layer in self.layers
        ]

    def step(self, x, state, **kwargs):
        # Advance the model by one time step, threading the recurrent state through each layer
        prev_states = [None] * len(self.layers) if state is None else state
        next_states = []
        layer_idx = 0
        for layer, prev_state in zip(self.layers, prev_states):
            x, state = layer.step(x, state=prev_state, **kwargs)
            next_states.append(state)
            layer_idx += 1

        x = self.norm(x)
        return x, next_states
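
For reference, a hypothetical instantiation sketch (not from the repository; the layer config keys are assumptions, since they are interpreted by SequenceResidualBlock, which is not shown here):

    model = SequenceModel(
        d_model=64,
        n_layers=2,
        layer={"_name_": "s4"},  # hypothetical layer config
    )
    x = torch.randn(8, 128, 64)  # (batch, sequence, dim)
    y, next_states = model(x)    # y: (8, 128, model.d_output)
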