• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 13056267505

30 Jan 2025 04:04PM UTC coverage: 94.525% (+0.6%) from 93.934%
13056267505

push

github

web-flow
Merge pull request #48 from tonegas/develop

Develop merge on main release 1.0.0

1185 of 1215 new or added lines in 21 files covered. (97.53%)

3 existing lines in 2 files now uncovered.

9426 of 9972 relevant lines covered (94.52%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.69
/nnodely/model.py
1
from itertools import product
1✔
2
from nnodely.utils import TORCH_DTYPE
1✔
3
import numpy as np
1✔
4

5
import torch.nn as nn
1✔
6
import torch
1✔
7

8
import copy
1✔
9

10
@torch.fx.wrap
def connect(data_in, rel, shift):
    """Roll ``data_in`` one step along dim 1 and overwrite its tail with ``rel``.

    ``torch.roll`` returns a new tensor, so ``data_in`` itself is left untouched.

    :param data_in: tensor indexed with three dimensions; dim 1 is the one rolled.
    :param rel: value assigned to the last ``shift`` positions of dim 1.
    :param shift: number of trailing positions along dim 1 replaced by ``rel``.
    :return: the rolled copy with its tail replaced.
    """
    # NOTE(review): shift == 0 makes the slice ``-0:`` cover the whole of dim 1.
    shifted = data_in.roll(shifts=-1, dims=1)
    shifted[:, -shift:, :] = rel
    return shifted
1✔
15

16
class Model(nn.Module):
    """Executable network assembled from a model-definition dictionary.

    The definition is deep-copied and read through the keys ``'Inputs'``,
    ``'Outputs'``, ``'Relations'``, ``'Parameters'``, ``'Constants'``,
    ``'Functions'``, ``'States'``, ``'Minimizers'`` and
    ``['Info']['SampleTime']``.

    Every relation entry has the form ``[relation_name, input_names, *args]``.
    ``relation_name`` is looked up as an attribute of the model and called with
    the resolved ``args``; the returned callable is stored in
    ``relation_forward`` and executed by :meth:`forward`.
    NOTE(review): the relation builders are not defined in this class; they are
    presumably attached to ``Model`` elsewhere in the package - confirm.
    """

    def __init__(self, model_def):
        """Build parameters, constants and relation callables from ``model_def``.

        :param model_def: model-definition dictionary (deep-copied, so the
            caller's object is never modified).
        :raises AttributeError: if a relation names a builder that is not an
            attribute of the model.
        """
        super(Model, self).__init__()
        model_def = copy.deepcopy(model_def)
        self.inputs = model_def['Inputs']
        self.outputs = model_def['Outputs']
        self.relations = model_def['Relations']
        self.params = model_def['Parameters']
        self.constants = model_def['Constants']
        self.sample_time = model_def['Info']['SampleTime']
        self.functions = model_def['Functions']
        self.state_model_main = model_def['States']
        self.minimizers = model_def['Minimizers']
        self.state_model = copy.deepcopy(self.state_model_main)

        inputs_and_states = model_def['Inputs'] | model_def['States']
        self.input_ns_backward = {key: value['ns'][0] for key, value in inputs_and_states.items()}
        self.input_n_samples = {key: value['ntot'] for key, value in inputs_and_states.items()}
        # All 'A' entries first, then all 'B' entries.
        self.minimizers_keys = (
            [minimizer['A'] for minimizer in self.minimizers.values()]
            + [minimizer['B'] for minimizer in self.minimizers.values()]
        )

        ## Build the network
        self.all_parameters = {}
        self.all_constants = {}
        self.relation_forward = {}
        self.relation_inputs = {}
        self.closed_loop_update = {}
        self.connect_update = {}

        ## Define the correct slicing
        # The bounds (and optional offset) of SamplePart/TimePart relations are
        # rewritten in place inside self.relations: TimePart values are divided
        # by the sample time and rounded, and when the source is an input or a
        # state its 'ns'[0] value is added.
        json_inputs = self.inputs | self.state_model
        for items in self.relations.values():
            kind = items[0]
            if kind not in ('SamplePart', 'TimePart'):
                continue
            source = items[1][0]
            has_offset = len(items) > 4
            if kind == 'SamplePart':
                if source in json_inputs:
                    backward = self.input_ns_backward[source]
                    items[3][0] = backward + items[3][0]
                    items[3][1] = backward + items[3][1]
                    if has_offset:
                        items[4] = backward + items[4]
            else:  # TimePart
                backward = self.input_ns_backward[source] if source in json_inputs else 0
                items[3][0] = backward + round(items[3][0] / self.sample_time)
                items[3][1] = backward + round(items[3][1] / self.sample_time)
                if has_offset:
                    items[4] = backward + round(items[4] / self.sample_time)

        ## Create all the parameters
        for name, param_data in self.params.items():
            # Window length in samples: 'tw' is divided by the sample time,
            # 'sw' is used as given, and no window means a single sample.
            if 'tw' in param_data:
                sample_window = round(param_data['tw'] / self.sample_time)
            elif 'sw' in param_data:
                sample_window = round(param_data['sw'] / 1)
            else:
                sample_window = 1

            dim = param_data['dim']
            if isinstance(dim, list):
                param_size = (sample_window,) + tuple(dim)
            else:
                param_size = (sample_window, dim)

            if 'values' in param_data:
                initial = torch.tensor(param_data['values'], dtype=TORCH_DTYPE)
            elif 'init_fun' in param_data:
                init_fun = param_data['init_fun']
                # SECURITY NOTE(review): this executes code carried by the model
                # definition and registers it in the module globals. Only load
                # definitions from trusted sources.
                exec(init_fun['code'], globals())
                function_to_call = globals()[init_fun['name']]
                with_params = 'params' in init_fun
                values = np.zeros(param_size)
                # The initializer is called once per element of the parameter.
                for indexes in product(*(range(v) for v in param_size)):
                    if with_params:
                        values[indexes] = function_to_call(indexes, param_size, init_fun['params'])
                    else:
                        values[indexes] = function_to_call(indexes, param_size)
                initial = torch.tensor(values.tolist(), dtype=TORCH_DTYPE)
            else:
                initial = torch.rand(size=param_size, dtype=TORCH_DTYPE)
            self.all_parameters[name] = nn.Parameter(initial, requires_grad=True)

        ## Create all the constants
        for name, param_data in self.constants.items():
            self.all_constants[name] = nn.Parameter(torch.tensor(param_data['values'], dtype=TORCH_DTYPE), requires_grad=False)
        all_params_and_consts = self.all_parameters | self.all_constants

        ## Create all the relations
        # Membership is tested against lists on purpose: relation arguments may
        # be unhashable (e.g. the [start, end] bounds of SamplePart/TimePart),
        # and hashing them against a dict would raise TypeError.
        param_names = list(self.params.keys())
        constant_names = list(self.constants.keys())
        function_names = list(self.functions.keys())
        for relation, inputs in self.relations.items():
            ## Take the relation name and the inputs needed to solve the relation
            rel_name, input_var = inputs[0], inputs[1]

            # A default is required here: without it a missing builder raises a
            # generic AttributeError before the dedicated message can be used.
            func = getattr(self, rel_name, None)
            if not func:
                raise AttributeError(f"Key Error: [{rel_name}] Relation not defined")

            layer_inputs = []
            for item in inputs[2:]:
                if item in param_names:  ## the relation takes parameters
                    layer_inputs.append(self.all_parameters[item])
                elif item in constant_names:  ## the relation takes a constant
                    layer_inputs.append(self.all_constants[item])
                elif item in function_names:  ## the relation takes a custom function
                    function_def = self.functions[item]
                    layer_inputs.append(function_def)
                    if 'params_and_consts' in function_def:  ## Parametric function that takes parameters
                        layer_inputs.append([all_params_and_consts[par] for par in function_def['params_and_consts']])
                    if 'map_over_dim' in function_def:
                        layer_inputs.append(function_def['map_over_dim'])
                else:
                    layer_inputs.append(item)

            # A first argument of -1 is replaced by the total number of samples
            # of the source; other TimePart values are converted to samples.
            if rel_name == 'SamplePart':
                if layer_inputs[0] == -1:
                    layer_inputs[0] = self.input_n_samples[input_var[0]]
            elif rel_name == 'TimePart':
                if layer_inputs[0] == -1:
                    layer_inputs[0] = self.input_n_samples[input_var[0]]
                else:
                    layer_inputs[0] = round(layer_inputs[0] / self.sample_time)

            ## Initialize the relation
            self.relation_forward[relation] = func(*layer_inputs)
            ## Save the inputs needed for the relative relation
            self.relation_inputs[relation] = input_var

        ## Add the gradient to all the relations and parameters that requires it
        self.relation_forward = nn.ParameterDict(self.relation_forward)
        self.all_constants = nn.ParameterDict(self.all_constants)
        self.all_parameters = nn.ParameterDict(self.all_parameters)

        ## list of network outputs
        self.network_output_predictions = set(self.outputs.values())

        ## list of network minimization outputs
        # A minimizer side that names an output is replaced by that output's
        # relation; otherwise the name is used directly.
        self.network_output_minimizers = {
            self.outputs[name] if name in self.outputs else name
            for minimizer in self.minimizers.values()
            for name in (minimizer['A'], minimizer['B'])
        }

        ## list of all the network Outputs
        self.network_outputs = self.network_output_predictions.union(self.network_output_minimizers)

    def forward(self, kwargs):
        """Evaluate the relations until every requested output is available.

        :param kwargs: mapping from input/state name to its value.
        :return: tuple ``(output_dict, minimize_dict, closed_loop_update_dict,
            connect_update_dict)``.
        :raises RuntimeError: if a full pass over the relations computes nothing
            while outputs are still missing (the loop would never terminate).
        """
        result_dict = {}

        ## Initially i have only the inputs from the dataset, the parameters, and the constants
        available_inputs = [key for key in self.inputs.keys() if key not in self.connect_update.keys()]  ## remove connected inputs
        available_states = [key for key in self.state_model.keys() if key not in self.connect_update.keys()]  ## remove connected states
        available_keys = set(available_inputs + list(self.all_parameters.keys()) + list(self.all_constants.keys()) + available_states)

        ## Forward pass through the relations
        while not self.network_outputs.issubset(available_keys):  ## i need to climb the relation tree until i get all the outputs
            progressed = False
            for relation in self.relations.keys():
                ## if i have all the variables i can calculate the relation
                if relation in available_keys or not set(self.relation_inputs[relation]).issubset(available_keys):
                    continue

                ## Collect all the necessary inputs for the relation
                layer_inputs = []
                for key in self.relation_inputs[relation]:
                    if key in self.all_constants.keys():  ## relation that takes a constant
                        layer_inputs.append(self.all_constants[key])
                    elif key in available_states or key in available_inputs:  ## relation that takes a state or an input
                        layer_inputs.append(kwargs[key])
                    elif key in self.all_parameters.keys():  ## relation that takes parameters
                        layer_inputs.append(self.all_parameters[key])
                    else:  ## relation than takes another relation or a connect variable
                        layer_inputs.append(result_dict[key])

                ## Execute the current relation
                result_dict[relation] = self.relation_forward[relation](*layer_inputs)
                available_keys.add(relation)
                progressed = True

                ## Check if the relation is inside the connect
                for connect_input, connect_rel in self.connect_update.items():
                    if relation == connect_rel:
                        result_dict[connect_input] = connect(kwargs[connect_input], result_dict[relation], result_dict[relation].size(1))
                        available_keys.add(connect_input)

            if not progressed:
                # Nothing new became available in a whole pass, so nothing ever
                # will: stop instead of spinning forever.
                missing = sorted(str(name) for name in self.network_outputs - available_keys)
                raise RuntimeError(f"Unable to compute the network outputs {missing}: no relation can be solved with the available variables")

        ## Return a dictionary with all the connected inputs
        connect_update_dict = {key: result_dict[key] for key in self.connect_update}
        ## Return a dictionary with all the relations that updates the state variables
        closed_loop_update_dict = {key: result_dict[value] for key, value in self.closed_loop_update.items()}
        ## Return a dictionary with all the outputs final values
        output_dict = {key: result_dict[value] for key, value in self.outputs.items()}
        ## Return a dictionary with the minimization relations
        minimize_dict = {}
        for key in self.minimizers_keys:
            minimize_dict[key] = result_dict[self.outputs[key]] if key in self.outputs.keys() else result_dict[key]

        return output_dict, minimize_dict, closed_loop_update_dict, connect_update_dict

    def update(self, closed_loop=None, connect=None):
        """Rebuild the closed-loop and connect maps used by :meth:`forward`.

        Both maps are reset, refilled from the ``'connect'`` / ``'closedLoop'``
        entries of the states (``'connect'`` wins when a state has both), and
        then extended with the given overrides.

        :param closed_loop: optional mapping ``{variable: output_name}``; each
            output name is resolved through ``self.outputs``.
        :param connect: optional mapping ``{variable: output_name}``, resolved
            the same way.
        :raises KeyError: if an override names an output missing from
            ``self.outputs``.
        """
        # None defaults avoid sharing one mutable dict between calls.
        closed_loop = {} if closed_loop is None else closed_loop
        connect = {} if connect is None else connect

        self.closed_loop_update = {}
        self.connect_update = {}

        for key, state in self.state_model.items():
            if 'connect' in state.keys():
                self.connect_update[key] = state['connect']
            elif 'closedLoop' in state.keys():
                self.closed_loop_update[key] = state['closedLoop']

        for connect_in, connect_rel in connect.items():
            self.connect_update[connect_in] = self.outputs[connect_rel]

        for close_in, close_rel in closed_loop.items():
            self.closed_loop_update[close_in] = self.outputs[close_rel]
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc