• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 14319828903

07 Apr 2025 09:27PM UTC coverage: 97.259% (+0.2%) from 97.035%
14319828903

Pull #86

github

web-flow
Merge 44b7c25ee into e9c323c4f
Pull Request #86: Smallclasses

2275 of 2409 new or added lines in 54 files covered. (94.44%)

1 existing line in 1 file now uncovered.

11637 of 11965 relevant lines covered (97.26%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.38
/nnodely/operators/memory.py
1
import  torch
1✔
2

3
from nnodely.support.utils import  TORCH_DTYPE, check, enforce_types
1✔
4

5
class Memory:
    """Base class holding the state tensors of a model between evaluations.

    Direct instantiation is rejected in ``__init__`` (a ``TypeError`` is
    requested via ``check`` when ``type(self) is Memory``); presumably it is
    meant to be mixed into a concrete model class. Methods below read
    ``self._model_def``, which this class never assigns — assumed to be
    provided by the concrete subclass; TODO confirm.
    """

    def __init__(self):
        # The message must name this class; it previously said "Loader".
        check(type(self) is not Memory, TypeError, "Memory class cannot be instantiated directly")

        # Model definition
        self._states = {}  # state name -> stored state tensor
        self._input_n_samples = {}  # name -> window size in samples (used by resetStates)
        self._input_ns_backward = {}
        self._input_ns_forward = {}
        self._max_samples_backward = None
        self._max_samples_forward = None
        self._max_n_samples = 0

    def _removeVirtualStates(self, connect, closed_loop):
        """Discard stored states whose names are keys of ``connect`` or ``closed_loop``.

        Only the keys of the two mappings are used; names without a stored
        state are silently ignored.
        """
        for key in connect.keys() | closed_loop.keys():
            self._states.pop(key, None)

    def _updateState(self, X, out_closed_loop, out_connect):
        """Feed outputs back into the input mapping ``X`` and record them as states.

        Args:
            X: mapping name -> tensor, updated in place. Entries touched via
                ``out_closed_loop`` are indexed as 3-D with the time window on
                axis 1 (assumed (batch, window, dim) — TODO confirm).
            out_closed_loop: name -> predicted tensor. The window ``X[name]``
                is rolled back by one sample along axis 1 and its trailing
                ``val.shape[1]`` samples are overwritten with the prediction.
            out_connect: name -> tensor that replaces ``X[name]`` outright.

        A detached copy of every updated ``X`` entry is stored in ``self._states``.
        """
        ## Update
        for key, val in out_closed_loop.items():
            shift = val.shape[1]  # number of time samples carried by the output
            # Advance the window by exactly one step, then overwrite the last
            # `shift` samples with the prediction.
            # NOTE(review): the roll is by 1, not by `shift` — presumably one
            # simulation step per call with overlapping output windows; confirm.
            X[key] = torch.roll(X[key], shifts=-1, dims=1)
            X[key][:, -shift:, :] = val
            # detach() first so the copy is not recorded in the autograd graph.
            self._states[key] = X[key].detach().clone()
        for key, value in out_connect.items():
            X[key] = value
            self._states[key] = X[key].detach().clone()

    @enforce_types
    def resetStates(self, states:set={}, batch:int=1) -> None:
        """Reset states to zero tensors of shape (batch, window_size, dim).

        Args:
            states: names of the states to reset. When empty (the default),
                the state dict is rebuilt from ``self._model_def['States']``,
                discarding any other stored entry.
            batch: size of the leading dimension of the new tensors.
        """
        # NOTE: the default ``{}`` is an empty *dict* despite the ``set``
        # annotation. It is only truth-tested and iterated, never mutated, so
        # it is harmless; the signature is kept identical for enforce_types.
        if states:  # reset only the requested states
            for key in states:
                self._states[key] = self._zeroState(key, batch)
        else:  # reset all states declared in the model definition
            self._states = {key: self._zeroState(key, batch) for key in self._model_def['States']}

    def _zeroState(self, key, batch):
        """Return a zero state tensor for ``key`` with ``batch`` rows.

        Shape is (batch, self._input_n_samples[key],
        self._model_def['States'][key]['dim']) with dtype ``TORCH_DTYPE``.
        """
        window_size = self._input_n_samples[key]
        dim = self._model_def['States'][key]['dim']
        return torch.zeros(size=(batch, window_size, dim), dtype=TORCH_DTYPE, requires_grad=False)
1✔
47

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc