• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 20067870175

09 Dec 2025 02:52PM UTC coverage: 96.745% (-0.001%) from 96.746%
20067870175

push

github

MisterMandarino
Merge branch 'develop' of https://github.com/tonegas/nnodely into develop

41 of 43 new or added lines in 14 files covered. (95.35%)

20 existing lines in 2 files now uncovered.

13020 of 13458 relevant lines covered (96.75%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.37
/nnodely/basic/model.py
1
import torch
1✔
2
import copy
1✔
3

4
import torch.nn as nn
1✔
5
import numpy as np
1✔
6

7
from itertools import product
1✔
8

9
from nnodely.support.utils import TORCH_DTYPE
1✔
10
from nnodely.support import initializer
1✔
11

12
@torch.fx.wrap
def update_state(data_in, rel):
    """Return a copy of ``data_in`` whose trailing samples are overwritten by ``rel``.

    Both tensors are sliced as ``[:, window, :]`` so they must be rank 3; dim 1
    is presumably the sample window axis (TODO confirm). The last
    ``min(rel.size(1), data_in.size(1))`` entries along dim 1 of the copy are
    replaced by the last entries of ``rel``. ``data_in`` itself is not modified.

    Registered with ``torch.fx.wrap`` so that FX symbolic tracing records it as
    a single call node instead of tracing through its body.
    """
    n_in = data_in.size(1)
    n_rel = rel.size(1)
    max_dim = min(n_rel, n_in)
    data_out = data_in.clone()
    # Explicit start indices rather than ``-max_dim:``: when max_dim is 0,
    # ``-0:`` selects the whole axis and the assignment fails to broadcast.
    data_out[:, n_in - max_dim:, :] = rel[:, n_rel - max_dim:, :]
    return data_out
19

20
class Model(nn.Module):
    """Neural network assembled from an nnodely model definition dict.

    The definition must contain the keys ``'Inputs'``, ``'Outputs'``,
    ``'Relations'``, ``'Parameters'``, ``'Constants'``, ``'Functions'`` and
    ``'Info'`` (with ``'SampleTime'``); ``'Minimizers'`` is optional. It is
    deep-copied, so the caller's dict is never modified.

    Each ``'Relations'`` entry is a list ``[relation_name, input_names, *layer_args]``.
    ``relation_name`` is looked up with ``getattr`` on the model and called with
    the resolved ``layer_args`` to build the layer (these factories are presumably
    attached to the class elsewhere in nnodely). ``input_names`` lists the inputs,
    relations, parameters or constants that :meth:`forward` passes to the layer.
    """

    def __init__(self, model_def):
        super().__init__()
        model_def = copy.deepcopy(model_def)

        # Inputs that are closed-loop or connected to a relation.
        self.states = {key: value for key, value in model_def['Inputs'].items()
                       if 'closedLoop' in value or 'connect' in value}

        self.inputs = model_def['Inputs']
        self.outputs = model_def['Outputs']
        self.relations = model_def['Relations']
        self.params = model_def['Parameters']
        self.constants = model_def['Constants']
        self.sample_time = model_def['Info']['SampleTime']
        self.functions = model_def['Functions']

        self.minimizers = model_def.get('Minimizers', {})
        # Every minimizer operand name: all 'A' operands first, then all 'B' operands.
        self.minimizers_keys = ([minimizer['A'] for minimizer in self.minimizers.values()]
                                + [minimizer['B'] for minimizer in self.minimizers.values()])

        # ns[0] is added to part indexes of each input (see _resolve_part_indexes);
        # ntot replaces a -1 part length.
        self.input_ns_backward = {key: value['ns'][0] for key, value in self.inputs.items()}
        self.input_n_samples = {key: value['ntot'] for key, value in self.inputs.items()}

        ## Build the network
        self.all_parameters = {}
        self.all_constants = {}
        self.relation_forward = {}
        self.relation_inputs = {}
        self.closed_loop_update = {}
        self.connect_update = {}

        ## Fill connect_update and closed_loop_update from the declared states
        self.update()

        self._resolve_part_indexes()
        self._build_parameters()
        self._build_constants()
        self._build_relations()  # needs the plain dicts built above

        ## Register parameters, constants and relation layers with this nn.Module
        self.relation_forward = nn.ParameterDict(self.relation_forward)
        self.all_constants = nn.ParameterDict(self.all_constants)
        self.all_parameters = nn.ParameterDict(self.all_parameters)

        self._collect_network_outputs()

    def _part_index(self, value, shift, in_time):
        """Return ``shift`` plus ``value`` in samples (``round(value / sample_time)`` when ``in_time``)."""
        return shift + (round(value / self.sample_time) if in_time else value)

    def _resolve_part_indexes(self):
        """Rewrite SamplePart/TimePart indexes (and optional offset) in place as sample indexes.

        Parts of a model input are shifted by that input's ``ns[0]``. TimePart values
        are converted from time to samples by dividing by the sample time and rounding.
        SamplePart entries whose source is not a model input are left unchanged.
        """
        for items in self.relations.values():
            kind = items[0]
            if kind not in ('SamplePart', 'TimePart'):
                continue
            source = items[1][0]
            from_input = source in self.inputs
            if kind == 'SamplePart' and not from_input:
                continue
            shift = self.input_ns_backward[source] if from_input else 0
            in_time = kind == 'TimePart'
            items[3][0] = self._part_index(items[3][0], shift, in_time)
            items[3][1] = self._part_index(items[3][1], shift, in_time)
            if len(items) > 4:  ## Offset
                items[4] = self._part_index(items[4], shift, in_time)

    def _parameter_size(self, param_data):
        """Return the tensor shape of a parameter: optional window length followed by ``dim``."""
        window = 'tw' if 'tw' in param_data else ('sw' if 'sw' in param_data else None)
        if window is None:
            sample_window = None
        else:
            # 'tw' is a time window (converted to samples), 'sw' is already in samples.
            scale = self.sample_time if window == 'tw' else 1
            sample_window = round(param_data[window] / scale)
        dim = param_data['dim']
        dim_shape = tuple(dim) if isinstance(dim, (list, tuple)) else (dim,)
        return dim_shape if sample_window is None else (sample_window,) + dim_shape

    @staticmethod
    def _init_values(init_fun, param_size):
        """Fill a numpy array of shape ``param_size`` by calling the init function element-wise.

        The function is called as ``f(indexes, param_size[, params])`` for every index tuple.
        """
        if 'code' in init_fun:
            # NOTE(security): this executes source code taken from the model
            # definition in this module's global namespace. Only load model
            # definitions from trusted sources.
            exec(init_fun['code'], globals())
            function_to_call = globals()[init_fun['name']]
        else:
            function_to_call = getattr(initializer, init_fun['name'])
        extra_args = (init_fun['params'],) if 'params' in init_fun else ()
        values = np.zeros(param_size)
        for indexes in product(*(range(v) for v in param_size)):
            values[indexes] = function_to_call(indexes, param_size, *extra_args)
        return values

    def _build_parameters(self):
        """Create one trainable ``nn.Parameter`` per entry of ``self.params``.

        Uses explicit ``'values'`` if given, else an ``'init_fun'``, else ``torch.rand``.
        """
        for name, param_data in self.params.items():
            param_size = self._parameter_size(param_data)
            if 'values' in param_data:
                tensor = torch.tensor(param_data['values'], dtype=TORCH_DTYPE)
            elif 'init_fun' in param_data:
                values = self._init_values(param_data['init_fun'], param_size)
                tensor = torch.tensor(values.tolist(), dtype=TORCH_DTYPE)
            else:
                tensor = torch.rand(size=param_size, dtype=TORCH_DTYPE)
            self.all_parameters[name] = nn.Parameter(tensor, requires_grad=True)

    def _build_constants(self):
        """Create one non-trainable ``nn.Parameter`` per entry of ``self.constants``."""
        for name, const_data in self.constants.items():
            self.all_constants[name] = nn.Parameter(
                torch.tensor(const_data['values'], dtype=TORCH_DTYPE), requires_grad=False)

    def _build_relations(self):
        """Instantiate every relation layer and record the names it consumes."""
        all_params_and_consts = self.all_parameters | self.all_constants
        # Plain lists on purpose: layer arguments may be unhashable (e.g. [start, end]
        # index lists), which would raise TypeError in a dict/set membership test.
        param_names = list(self.params)
        const_names = list(self.constants)
        function_names = list(self.functions)

        for relation, definition in self.relations.items():
            rel_name, input_var = definition[0], definition[1]
            factory = getattr(self, rel_name)

            layer_inputs = []
            for item in definition[2:]:
                if item in param_names:
                    layer_inputs.append(self.all_parameters[item])
                elif item in const_names:
                    layer_inputs.append(self.all_constants[item])
                elif item in function_names:  ## custom function
                    fun_def = self.functions[item]
                    layer_inputs.append(fun_def)
                    if 'params_and_consts' in fun_def:  ## parametric function (list may be empty)
                        layer_inputs.append([all_params_and_consts[par] for par in fun_def['params_and_consts']])
                    if 'map_over_dim' in fun_def:
                        layer_inputs.append(fun_def['map_over_dim'])
                else:
                    layer_inputs.append(item)

            if rel_name in ('SamplePart', 'TimePart'):
                # -1 requests the full window length (ntot) of the source input;
                # otherwise a TimePart length given in time is converted to samples.
                if layer_inputs[0] == -1:
                    layer_inputs[0] = self.input_n_samples[input_var[0]]
                elif rel_name == 'TimePart':
                    layer_inputs[0] = round(layer_inputs[0] / self.sample_time)

            self.relation_forward[relation] = factory(*layer_inputs)
            self.relation_inputs[relation] = input_var

    def _collect_network_outputs(self):
        """Compute the set of relation names that :meth:`forward` must evaluate."""
        self.network_output_predictions = set(self.outputs.values())
        # Minimizer operands may name an output (mapped to its relation) or a relation directly.
        self.network_output_minimizers = {self.outputs.get(key, key) for key in self.minimizers_keys}
        self.network_outputs = self.network_output_predictions | self.network_output_minimizers

    def forward(self, kwargs):
        """Evaluate the relation graph once.

        Args:
            kwargs: dict mapping input names to values. Non-connected inputs read
                by a relation must be present. Connected inputs must be present
                too: their value is the state that :func:`update_state`
                overwrites with the connected relation's output.

        Returns:
            ``(output_dict, minimize_dict, closed_loop_update_dict, connect_update_dict)``:
            outputs by output name, minimizer operands by the names in
            ``self.minimizers_keys``, the relation value feeding each closed-loop
            input, and the updated value of each connected input.

        Raises:
            RuntimeError: if a full sweep over the relations computes nothing new
                while some network outputs are still missing, instead of looping
                forever.
        """
        result_dict = {}

        # Values supplied directly by the caller; connected inputs are produced during the pass.
        available_inputs = {key for key in self.inputs if key not in self.connect_update}
        available_keys = (available_inputs
                          | set(self.all_parameters.keys())
                          | set(self.all_constants.keys()))

        # Sweep the relations, evaluating every one whose inputs are ready, until all outputs exist.
        while not self.network_outputs.issubset(available_keys):
            progress = False
            for relation in self.relations:
                needed = self.relation_inputs[relation]
                if relation in available_keys or not set(needed).issubset(available_keys):
                    continue

                layer_inputs = []
                for key in needed:
                    if key in self.all_constants:
                        layer_inputs.append(self.all_constants[key])
                    elif key in available_inputs:
                        layer_inputs.append(kwargs[key])
                    elif key in self.all_parameters:
                        layer_inputs.append(self.all_parameters[key])
                    else:  # output of another relation, or a connected input
                        layer_inputs.append(result_dict[key])

                result_dict[relation] = self.relation_forward[relation](*layer_inputs)
                available_keys.add(relation)
                progress = True

                # Push the new value into every input connected to this relation.
                for connect_input, connect_rel in self.connect_update.items():
                    if relation == connect_rel:
                        result_dict[connect_input] = update_state(kwargs[connect_input], result_dict[relation])
                        available_keys.add(connect_input)

            if not progress:
                missing = sorted(self.network_outputs - available_keys, key=str)
                raise RuntimeError(
                    f"Unable to compute the network outputs {missing}: "
                    "their relations depend on values that are never available.")

        connect_update_dict = {key: result_dict[key] for key in self.connect_update}
        closed_loop_update_dict = {key: result_dict[value] for key, value in self.closed_loop_update.items()}
        output_dict = {key: result_dict[value] for key, value in self.outputs.items()}
        minimize_dict = {key: result_dict[self.outputs.get(key, key)] for key in self.minimizers_keys}
        return output_dict, minimize_dict, closed_loop_update_dict, connect_update_dict

    def update(self, *, closed_loop=None, connect=None, disconnect=False):
        """Rebuild ``connect_update`` and ``closed_loop_update``.

        Both maps are cleared. With ``disconnect=True`` they stay empty and the
        other arguments are ignored. Otherwise they are refilled from the
        ``'connect'``/``'closedLoop'`` entries declared on the inputs, and then
        from ``connect``/``closed_loop`` (input name -> output or relation name).
        Output names are translated to their relation via ``self.outputs``.
        """
        self.closed_loop_update = {}
        self.connect_update = {}

        if disconnect:
            return

        for key, state in self.states.items():
            if 'connect' in state:
                self.connect_update[key] = state['connect']
            elif 'closedLoop' in state:
                self.closed_loop_update[key] = state['closedLoop']

        for connect_in, connect_rel in (connect or {}).items():
            self.connect_update[connect_in] = self.outputs.get(connect_rel, connect_rel)
        for close_in, close_rel in (closed_loop or {}).items():
            self.closed_loop_update[close_in] = self.outputs.get(close_rel, close_rel)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc