• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 18676707115

16 Oct 2025 04:28PM UTC coverage: 96.644% (+0.001%) from 96.643%
18676707115

push

github

MisterMandarino
Merge branch 'develop' of https://github.com/tonegas/nnodely into develop

55 of 55 new or added lines in 3 files covered. (100.0%)

13 existing lines in 3 files now uncovered.

12727 of 13169 relevant lines covered (96.64%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.84
/nnodely/operators/composer.py
1
import copy, torch
1✔
2

3
import numpy as np
1✔
4

5
from nnodely.operators.network import Network
1✔
6

7
from nnodely.basic.modeldef import ModelDef
1✔
8
from nnodely.basic.model import Model
1✔
9
from nnodely.support.utils import check, TORCH_DTYPE, NP_DTYPE, enforce_types
1✔
10
from nnodely.support.mathutils import argmax_dict, argmin_dict
1✔
11
from nnodely.basic.relation import Stream
1✔
12
from nnodely.layers.input import Input
1✔
13
from nnodely.layers.output import Output
1✔
14

15
from nnodely.support.logger import logging, nnLogger
1✔
16
# Module-level logger for the Composer operators, created with level logging.WARNING.
log = nnLogger(__name__, logging.WARNING)
1✔
17

18
class Composer(Network):
    """
    Network extension that composes, wires, neuralizes and runs nnodely models.

    It exposes the user API to add/remove sub-models, to declare ``connect`` and
    ``closedLoop`` relations between streams and inputs, to build the torch
    ``Model`` from the model definition (``neuralizeModel``) and to run
    inference (``__call__``). It cannot be instantiated directly; use it through
    a subclass (e.g. ``Modely`` in the examples below).
    """

    @enforce_types
    def __init__(self):
        # Only concrete subclasses may be instantiated.
        check(type(self) is not Composer, TypeError, "Composer class cannot be instantiated directly")
        super().__init__()

    def __addInfo(self) -> None:
        """
        Record metadata in ``self._model_def['Info']``: the number of trainable
        parameters of ``self._model`` (those with ``requires_grad``) and the
        installed nnodely version.
        """
        total_params = sum(p.numel() for p in self._model.parameters() if p.requires_grad)
        self._model_def['Info']['num_parameters'] = total_params
        # Local import, presumably to avoid a circular import with the package __init__.
        from nnodely import __version__
        self._model_def['Info']['nnodely_version'] = __version__

    @enforce_types
    def addModel(self, name:str, stream_list:list|Output) -> None:
        """
        Adds a new model with the given name along with a list of Outputs.

        Parameters
        ----------
        name : str
            The name of the model.
        stream_list : Output or list of Output
            The Output stream, or the list of Output streams, that make up the model.

        Example
        -------
        Example usage:
            >>> model = Modely()
            >>> x = Input('x')
            >>> out = Output('out', Fir(x.last()))
            >>> model.addModel('example_model', [out])
        """
        self._model_def.addModel(name, stream_list)

    @enforce_types
    def removeModel(self, name_list:list|str) -> None:
        """
        Removes the models with the given name or list of names.

        Parameters
        ----------
        name_list : str or list of str
            The name, or the list of names, of the models to remove.

        Example
        -------
        Example usage:
            >>> model.removeModel(['sub_model1', 'sub_model2'])
        """
        self._model_def.removeModel(name_list)

    @enforce_types
    def addConnect(self, stream_out:str|Output|Stream, input_in:str|Input, *, local:bool=False) -> None:
        """
        Adds a connection from a relation stream to an input.

        Parameters
        ----------
        stream_out : str, Output or Stream
            The relation stream to connect from (a str presumably names an existing stream/output).
        input_in : str or Input
            The Input, or the name of the input, to connect to.
        local : bool, optional
            Forwarded unchanged to ``ModelDef.addConnection``. Default is False.

        Examples
        --------
        .. image:: https://colab.research.google.com/assets/colab-badge.svg
            :target: https://colab.research.google.com/github/tonegas/nnodely/blob/main/examples/states.ipynb
            :alt: Open in Colab

        Example:
            >>> model = Modely()
            >>> x = Input('x')
            >>> y = Input('y')
            >>> relation = Fir(x.last())
            >>> model.addConnect(relation, y)
        """
        self._model_def.addConnection(stream_out, input_in, 'connect', local)

    @enforce_types
    def addClosedLoop(self, stream_out:str|Output|Stream, input_in:str|Input, *, local:bool=False) -> None:
        """
        Adds a closed loop connection from a relation stream to an input.

        Parameters
        ----------
        stream_out : str, Output or Stream
            The relation stream to connect from (a str presumably names an existing stream/output).
        input_in : str or Input
            The Input, or the name of the input, to connect to.
        local : bool, optional
            Forwarded unchanged to ``ModelDef.addConnection``. Default is False.

        Examples
        --------
        .. image:: https://colab.research.google.com/assets/colab-badge.svg
            :target: https://colab.research.google.com/github/tonegas/nnodely/blob/main/examples/states.ipynb
            :alt: Open in Colab

        Example:
            >>> model = Modely()
            >>> x = Input('x')
            >>> y = Input('y')
            >>> relation = Fir(x.last())
            >>> model.addClosedLoop(relation, y)
        """
        self._model_def.addConnection(stream_out, input_in, 'closedLoop', local)

    @enforce_types
    def removeConnection(self, input_in:str|Input) -> None:
        """
        Remove a closed loop or connect connection from an input.

        Parameters
        ----------
        input_in : str or Input
            The Input, or the name of the input, to disconnect.

        Examples
        --------
        .. image:: https://colab.research.google.com/assets/colab-badge.svg
            :target: https://colab.research.google.com/github/tonegas/nnodely/blob/main/examples/states.ipynb
            :alt: Open in Colab

        Example:
            >>> model = Modely()
            >>> x = Input('x')
            >>> y = Input('y')
            >>> relation = Fir(x.last())
            >>> model.addConnect(relation, y)
            >>> model.removeConnection(y)
        """
        input_name = input_in.name if isinstance(input_in, Input) else input_in
        self._model_def.removeConnection(input_name)

    @enforce_types
    def neuralizeModel(self, sample_time:float|int|None = None, *, clear_model:bool = False, model_def:dict|None = None) -> None:
        """
        Neuralizes the model, preparing it for inference and training. This method creates a neural network model starting from the model definition.
        It will also create all the time windows and correct slicing for all the inputs defined.

        Besides building ``self._model``, it computes the per-input sample counts
        (``_input_ns_backward``, ``_input_ns_forward``, ``_input_n_samples``) and
        their maxima, resets the states, and marks the network as neuralized and
        not traced.

        Parameters
        ----------
        sample_time : float, int or None, optional
            The sample time for the model, forwarded to ``ModelDef.setBuildWindow``. Default is None,
            in which case ``setBuildWindow`` chooses the value (previously documented as 1.0 — confirm there).
        clear_model : bool, optional
            Whether to clear the existing model definition. Default is False.
        model_def : dict or None, optional
            A dictionary defining the model. If provided, it replaces the existing model definition. Default is None.

        Raises
        ------
        ValueError
            If sample_time is not None and model_def is provided.
            If clear_model is True and model_def is provided.

        Example
        -------
        Example usage:
            >>> model = Modely(name='example_model')
            >>> model.neuralizeModel(sample_time=0.1, clear_model=True)
        """
        if model_def is not None:
            check(sample_time is None, ValueError, 'The sample_time must be None if a model_def is provided')
            check(not clear_model, ValueError, 'The clear_model must be False if a model_def is provided')
            self._model_def = ModelDef(model_def)
        else:
            self._model_def.updateParameters(model = None, clear_model = clear_model)

        self._model_def.setBuildWindow(sample_time)
        self._model = Model(self._model_def.getJson())
        self.__addInfo()

        inputs_def = self._model_def['Inputs']
        # 'ns' holds [samples backward, samples forward] for each input window.
        self._input_ns_backward = {key: value['ns'][0] for key, value in inputs_def.items()}
        self._input_ns_forward = {key: value['ns'][1] for key, value in inputs_def.items()}
        self._max_samples_backward = max(self._input_ns_backward.values())
        self._max_samples_forward = max(self._input_ns_forward.values())
        self._input_n_samples = {}
        for key, value in inputs_def.items():
            # NOTE(review): this also triggers when ns_forward == 0; confirm whether '> 0' was intended.
            if self._input_ns_forward[key] >= 0:
                if 'closedLoop' in value:
                    log.warning(f"Closed loop on {key} with sample in the future.")
                if 'connect' in value:
                    log.warning(f"Connect on {key} with sample in the future.")
            self._input_n_samples[key] = self._input_ns_backward[key] + self._input_ns_forward[key]
        self._max_n_samples = self._max_samples_backward + self._max_samples_forward

        ## Initialize States
        self.resetStates()

        self._neuralized = True
        self._traced = False
        self._model_def.updateParameters(self._model)
        self.visualizer.showModel(self._model_def.getJson())
        self.visualizer.showModelInputWindow()
        self.visualizer.showBuiltModel()

    @enforce_types
    def __call__(self, inputs:dict={}, *, sampled:bool=False, closed_loop:dict={}, connect:dict={}, prediction_samples:str|int='auto', num_of_samples:int|None=None) -> dict:
        """
        Performs inference on the model.

        Parameters
        ----------
        inputs : dict, optional
            A dictionary of input data. The keys are input names and the values are the corresponding data.
            Inputs not used by the network are dropped with a warning. Default is an empty dictionary.
        sampled : bool, optional
            A boolean indicating whether the inputs are already sampled (one window per element). Default is False.
        closed_loop : dict, optional
            A dictionary specifying closed loop connections. The keys are input names and the values are output names. Default is an empty dictionary.
        connect : dict, optional
            A dictionary specifying direct connections. The keys are input names and the values are output names. Default is an empty dictionary.
        prediction_samples : str or int, optional
            The number of prediction samples: 'auto' or an integer. It is normalized by
            ``_setup_recurrent_variables``; a negative result disconnects the recurrent
            relations for this call. Default is 'auto'.
        num_of_samples : int or None, optional
            If set (and sampled is False), every provided input is zero-padded so that exactly
            num_of_samples windows are evaluated. Ignored, with a warning, when sampled is True.
            Default is None.

        Returns
        -------
        dict
            A dictionary mapping each model output name to a list with one prediction per evaluated window.

        Raises
        ------
        RuntimeError
            If the network is not neuralized.
        StopIteration
            If no complete input window can be formed from the provided data.
        ValueError
            If a multi-dimensional input has the wrong number of dimensions or the wrong feature size.

        Examples
        --------
        .. image:: https://colab.research.google.com/assets/colab-badge.svg
            :target: https://colab.research.google.com/github/tonegas/nnodely/blob/main/examples/inference.ipynb
            :alt: Open in Colab

        Example usage:
            >>> model = Modely()
            >>> x = Input('x')
            >>> out = Output('out', Fir(x.last()))
            >>> model.addModel('example_model', [out])
            >>> model.neuralizeModel()
            >>> predictions = model(inputs={'x': [1, 2, 3]})
        """

        ## Copy the dicts so neither the caller's data nor the mutable defaults are modified
        inputs = copy.deepcopy(inputs)
        all_closed_loop = copy.deepcopy(closed_loop)
        all_connect = copy.deepcopy(connect)

        ## Check neuralize
        check(self.neuralized, RuntimeError, "The network is not neuralized.")

        ## Check closed loop integrity
        prediction_samples = self._setup_recurrent_variables(prediction_samples, all_closed_loop, all_connect)

        ## List of keys
        json_inputs = self._model_def['Inputs']
        model_inputs = list(json_inputs.keys())
        extra_inputs = list(set(inputs.keys()) - set(model_inputs))
        non_mandatory_inputs = list(all_closed_loop.keys()) + list(all_connect.keys()) + list(self._model_def.recurrentInputs().keys())
        mandatory_inputs = list(set(model_inputs) - set(non_mandatory_inputs))

        ## Remove extra inputs
        for key in extra_inputs:
            log.warning(
                f'The provided input {key} is not used inside the network. the inference will continue without using it')
            del inputs[key]

        ## Get the number of data windows for each input
        if sampled:
            num_of_windows = {key: len(value) for key, value in inputs.items()}
        else:
            num_of_windows = {key: len(value) - self._input_n_samples[key] + 1 for key, value in inputs.items()}

        if num_of_samples is not None and sampled:
            log.warning('num_of_samples is ignored if sampled is equal to True')

        ## Get the maximum inference window
        if num_of_samples and not sampled:
            window_dim = num_of_samples
            for key in inputs.keys():
                new_samples = num_of_samples - (len(inputs[key]) - self._input_n_samples[key] + 1)
                if new_samples > 0:
                    log.warning(f'The variable {key} is filled with {new_samples} samples equal to zeros.')
                    # Concatenate instead of list '+=': for NumPy-array inputs '+=' adds element-wise
                    # rather than appending. Padding rows keep the trailing shape of the data.
                    data = np.asarray(inputs[key])
                    padding = np.zeros((new_samples,) + data.shape[1:], dtype=data.dtype)
                    inputs[key] = np.concatenate([data, padding])
        elif inputs:
            def n_windows(key):
                # Number of windows that can be formed from the data of one input.
                n_data = len(inputs[key])
                return n_data if sampled else n_data - json_inputs[key]['ntot'] + 1

            windows = [n_windows(key) for key in inputs.keys() if key in mandatory_inputs]
            if not windows:
                ## Only non-mandatory (recurrent/connected) inputs were provided
                windows = [n_windows(key) for key in inputs.keys() if key in non_mandatory_inputs]
            window_dim = min(windows) if windows else 0
        else:  ## No inputs
            window_dim = 1 if non_mandatory_inputs else 0
        check(window_dim > 0, StopIteration, 'Missing samples in the input window')

        if len(set(num_of_windows.values())) > 1:
            max_ind_key, max_dim = argmax_dict(num_of_windows)
            min_ind_key, min_dim = argmin_dict(num_of_windows)
            log.warning(
                f'Different number of samples between inputs [MAX {max_ind_key} = {max_dim}; MIN {min_ind_key} = {min_dim}]')

        ## Autofill the missing inputs
        missing_inputs = list(set(mandatory_inputs) - set(inputs.keys()))
        if missing_inputs:
            log.warning(f'Inputs not provided: {missing_inputs}. Autofilling with zeros..')
            for key in missing_inputs:
                inputs[key] = np.zeros(
                    shape=(self._input_n_samples[key] + window_dim - 1, json_inputs[key]['dim']),
                    dtype=NP_DTYPE).tolist()

        ## Transform inputs in 3D Tensors
        for key in inputs.keys():
            input_dim = json_inputs[key]['dim']
            inputs[key] = torch.from_numpy(np.array(inputs[key])).to(TORCH_DTYPE)

            if input_dim > 1:
                correct_dim = 3 if sampled else 2
                check(len(inputs[key].shape) == correct_dim, ValueError,
                      f'The input {key} must have {correct_dim} dimensions')
                check(inputs[key].shape[correct_dim - 1] == input_dim, ValueError,
                      f'The second dimension of the input "{key}" must be equal to {input_dim}')

            if input_dim == 1 and inputs[key].shape[-1] != 1:  ## add the input dimension
                inputs[key] = inputs[key].unsqueeze(-1)
            if inputs[key].ndim <= 1:  ## add the batch dimension
                inputs[key] = inputs[key].unsqueeze(0)
            if inputs[key].ndim <= 2:  ## add the time dimension
                inputs[key] = inputs[key].unsqueeze(0)

        ## initialize the resulting dictionary
        result_dict = {key: [] for key in self._model_def['Outputs'].keys()}

        ## Inference
        with (torch.enable_grad() if self._get_gradient_on_inference() else torch.inference_mode()):
            ## Update with virtual states
            if prediction_samples == 'auto' or prediction_samples >= 0:
                self._model.update(closed_loop=all_closed_loop, connect=all_connect)
            else:
                self._model.update(disconnect=True)
                prediction_samples = 0
            X = {}
            count = 0
            first = True
            for idx in range(window_dim):
                ## Get mandatory data inputs
                for key in mandatory_inputs:
                    X[key] = inputs[key][idx:idx + 1] if sampled else inputs[key][:, idx:idx + self._input_n_samples[key]]
                    if 'type' in json_inputs[key].keys():
                        X[key] = X[key].requires_grad_(True)
                ## Reset the non-mandatory inputs at the start of every prediction horizon
                if count == 0 or prediction_samples == 'auto':
                    count = prediction_samples
                    for key in non_mandatory_inputs:  ## Get non mandatory data (from inputs, from states, or with zeros)
                        ## If it is given as input AND
                        ## if prediction_samples is 'auto' and there are enough samples OR
                        ## if prediction_samples is NOT 'auto'
                        if key in inputs.keys() and (
                                (prediction_samples == 'auto' and idx < num_of_windows[key]) or
                                (prediction_samples != 'auto')
                        ):
                            X[key] = inputs[key][idx:idx + 1] if sampled else inputs[key][:, idx:idx + self._input_n_samples[key]]
                        ## if it is a state AND
                        ## if prediction_samples = 'auto' and there are not enough samples OR
                        ## it is the first iteration with prediction_samples = None
                        elif key in self._states.keys() and (
                                prediction_samples == 'auto' or
                                (first and prediction_samples is None)
                        ):
                            X[key] = self._states[key]
                        else:
                            ## if there are no samples
                            window_size = self._input_n_samples[key]
                            dim = json_inputs[key]['dim']
                            X[key] = torch.zeros(size=(1, window_size, dim), dtype=TORCH_DTYPE, requires_grad=False)

                        if 'type' in json_inputs[key].keys():
                            X[key] = X[key].requires_grad_(True)
                    first = False
                else:
                    # Remove the gradient of the previous forward
                    for key in X.keys():
                        if 'type' in json_inputs[key].keys():
                            X[key] = X[key].detach().requires_grad_(True)
                    count -= 1
                ## Forward pass
                result, _, out_closed_loop, out_connect = self._model(X)

                ## Append the prediction of the current sample to the result dictionary
                for key in self._model_def['Outputs'].keys():
                    if result[key].shape[-1] == 1:
                        result[key] = result[key].squeeze(-1)
                        if result[key].shape[-1] == 1:
                            result[key] = result[key].squeeze(-1)
                    result_dict[key].append(result[key].detach().squeeze(dim=0).tolist())

                ## Update closed_loop and connect
                if prediction_samples:
                    self._update_state(X, out_closed_loop, out_connect)

        ## Remove virtual states
        self._remove_virtual_states(connect, closed_loop)

        return result_dict
439

440

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc