• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 14359872492

09 Apr 2025 02:33PM UTC coverage: 97.602% (+0.6%) from 97.035%
14359872492

Pull #86

github

web-flow
Merge ec719935a into e9c323c4f
Pull Request #86: Smallclasses

2291 of 2418 new or added lines in 54 files covered. (94.75%)

3 existing lines in 1 file now uncovered.

11683 of 11970 relevant lines covered (97.6%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.67
/nnodely/support/utils.py
1
import copy, torch, inspect, typing
1✔
2

3
from collections import OrderedDict
1✔
4

5
import numpy as np
1✔
6
from contextlib import suppress
1✔
7
from pprint import pformat
1✔
8
from functools import wraps
1✔
9
from typing import get_type_hints
1✔
10
import keyword
1✔
11

12
from nnodely.support.logger import logging, nnLogger
1✔
13
# Module-level logger; CRITICAL level keeps it silent unless explicitly lowered.
log = nnLogger(__name__, logging.CRITICAL)

# Default floating-point precision used for torch tensors and numpy arrays.
TORCH_DTYPE = torch.float32
NP_DTYPE = np.float32

# Python keywords: reserved, so they are rejected as user-chosen tags/names.
ForbiddenTags = keyword.kwlist
1✔
19

20
class ReadOnlyDict:
    """Read-only, dict-like view over a plain dict.

    Wraps (does not copy) ``data``: later mutation of the underlying dict is
    visible through this view. Nested dicts returned by item access are
    wrapped recursively so the read-only guarantee extends downwards.
    """

    def __init__(self, data):
        # Keep a reference, not a copy — this is a cheap view over the caller's dict.
        self._data = data

    def __getitem__(self, key):
        value = self._data[key]
        if isinstance(value, dict):
            # Wrap nested dicts so callers cannot mutate them through this view.
            return ReadOnlyDict(value)
        return value

    def __contains__(self, key):
        # Explicit O(1) membership test; without this, `in` falls back to the
        # O(n) scan through __iter__.
        return key in self._data

    def get(self, key, default=None):
        """dict-like ``get``; wraps nested dicts the same way __getitem__ does."""
        if key in self._data:
            return self[key]
        return default

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def keys(self):
        return self._data.keys()

    def items(self):
        # NOTE(review): nested dict values are returned unwrapped (mutable)
        # here and in values(), unlike __getitem__ — confirm this is intended.
        return self._data.items()

    def values(self):
        return self._data.values()

    def __or__(self, other):
        if not isinstance(other, ReadOnlyDict):
            return NotImplemented
        # Right operand wins on key collisions, mirroring dict's | operator.
        combined_data = {**self._data, **other._data}
        return ReadOnlyDict(combined_data)

    def __str__(self):
        from nnodely.visualizer.visualizer import color, GREEN
        from pprint import pformat
        return color(pformat(self._data), GREEN)

    def __repr__(self):
        return f"{type(self).__name__}({self._data!r})"

    def __eq__(self, other):
        if not isinstance(other, ReadOnlyDict):
            return NotImplemented
        return self._data == other._data
60

61

62
def get_window(obj):
    """Return the window type of `obj.dim`: 'tw' if present, else 'sw', else None."""
    for window_key in ('tw', 'sw'):
        if window_key in obj.dim:
            return window_key
    return None
64

65
def get_inputs(json, relation, inputs):
    """Collect, in `inputs`, all graph inputs needed to compute `relation`.

    Walks json['Relations'][relation] depth-first: names found in
    json['Inputs'] or json['States'] are appended to `inputs` (in place);
    other names are treated as sub-relations and recursed into.

    Bug fixed: the original returned from inside the loop, so only the FIRST
    dependency of each relation was ever followed. Now every dependency is
    visited, and the accumulated list is returned for convenience (the list
    is still filled in place, so existing callers that ignore the return
    value are unaffected).
    """
    for rel in json['Relations'][relation][1]:
        if rel in (json['Inputs'] | json['States']):  # leaf: an input or state
            inputs.append(rel)
        else:  # another relation: recurse to find its inputs
            get_inputs(json, rel, inputs)
    return inputs
72

73
def enforce_types(func):
    """Decorator that enforces, at call time, the annotations declared on `func`.

    Every argument whose parameter carries an annotation is checked with
    isinstance(); a TypeError is raised on the first mismatch.
    NOTE(review): annotations must be runtime classes (or tuples of classes)
    usable with isinstance() — typing generics such as list[int] or unions
    would make the isinstance() call itself fail; confirm callers only
    annotate with plain classes.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        hints = get_type_hints(func)
        # Start from the keyword arguments; positionals are merged in below.
        all_args = kwargs.copy()

        sig = OrderedDict(inspect.signature(func).parameters)
        if len(sig) != len(args):
            # A count mismatch implies a *args (VAR_POSITIONAL) parameter may
            # be absorbing extra positionals: replace the single *args entry
            # with one synthetic entry per extra argument (named '<var>0',
            # '<var>1', ...) so each one is type-checked below.
            var_type = None
            for ind, arg in enumerate(args):
                if ind < len(list(sig.values())) and list(sig.values())[ind].kind == inspect.Parameter.VAR_POSITIONAL:
                    var_name = list(sig.keys())[ind]
                    var_type = sig.pop(var_name)
                if var_type:
                    sig[var_name+str(ind)] = var_type

        # Map positional arguments onto parameter names (zip stops at the shorter).
        all_args.update(dict(zip(sig, args)))
        if 'self' in sig.keys():
            # Never type-check the bound instance; popping it from sig also
            # makes the loop below skip the 'self' entry of all_args.
            sig.pop('self')

        for arg_name, arg in all_args.items():
            # NOTE(review): a name present in hints but absent from sig would
            # raise KeyError on sig[arg_name] — confirm that case is unreachable.
            if (arg_name in hints.keys() or arg_name in sig.keys()) and not isinstance(arg,sig[arg_name].annotation):
                raise TypeError(
                    f"In Function or Class {func} Expected argument '{arg_name}={arg}' to be of type {sig[arg_name].annotation}, but got {type(arg)}")

        # for arg, arg_type in hints.items():
        #     if arg in all_args and not isinstance(all_args[arg], arg_type):
        #         raise TypeError(
        #             f"In Function or Class {func} Expected argument '{arg}' to be of type {arg_type}, but got {type(all_args[arg]).__name__}")

        return func(*args, **kwargs)

    return wrapper
106

107

108
# Linear interpolation function, operating on batches of input data and returning batches of output data
109
# Linear interpolation function, operating on batches of input data and returning batches of output data
def linear_interp(x,x_data,y_data):
    """Batched 1-D linear interpolation of y_data over x_data at query points x.

    Inputs:
      x:      query points, tensor of shape torch.Size([N, 1, 1])
      x_data: x samples sorted in ascending order, tensor of shape torch.Size([Q, 1])
      y_data: y samples, tensor of shape torch.Size([Q, 1])
    Output:
      y: interpolated values at x, tensor of shape torch.Size([N, 1, 1])

    Bug fixed: the previous implementation picked the bracketing segment with
    a nearest-point search (argmin over |x_data[:-1] - x|), which selects the
    segment to the RIGHT of x whenever x is closer to its upper bracket point
    and then extrapolates that segment backwards, producing a wrong value.
    The segment is now chosen as the largest i with x_data[i] <= x.
    """
    # Saturate x to the sampled range so out-of-range queries clamp to the ends.
    x = torch.min(torch.max(x, x_data[0]), x_data[-1])

    # Index of the segment containing x: count of knots (excluding the last)
    # that are <= x, minus one; clamp handles x exactly at the left endpoint.
    # Broadcasting gives shape [N, Q-1, 1]; the sum reduces it to [N, 1].
    idx = torch.clamp((x_data[:-1] <= x).sum(dim=1) - 1, min=0)

    # Linear interpolation on the segment [idx, idx+1].
    y = y_data[idx] + (y_data[idx+1] - y_data[idx]) / (x_data[idx+1] - x_data[idx]) * (x - x_data[idx])
    return y
126

127
def tensor_to_list(data):
    """Recursively convert every torch.Tensor inside `data` to a plain list.

    Dicts, lists, tuples and torch ParameterDicts are traversed recursively
    (tuples stay tuples); tensors become nested Python lists via .tolist();
    anything else is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        # Tensor leaf: convert to (possibly nested) Python lists.
        return data.tolist()
    if isinstance(data, dict):
        # Recurse into dictionary values, preserving keys.
        return {key: tensor_to_list(value) for key, value in data.items()}
    if isinstance(data, list):
        # Recurse into list elements.
        return [tensor_to_list(item) for item in data]
    if isinstance(data, tuple):
        # Recurse into tuple elements, keeping the tuple type.
        return tuple(tensor_to_list(item) for item in data)
    if isinstance(data, torch.nn.modules.container.ParameterDict):
        # Recurse into a ParameterDict, producing a plain dict of lists.
        return {key: tensor_to_list(value) for key, value in data.items()}
    # Any other type passes through untouched.
    return data
146

147
# Codice per comprimere le relazioni
148
        #print(self.json['Relations'])
149
        # used_rel = {string for values in self.json['Relations'].values() for string in values[1]}
150
        # if obj1.name not in used_rel and obj1.name in self.json['Relations'].keys() and self.json['Relations'][obj1.name][0] == add_relation_name:
151
        #     self.json['Relations'][self.name] = [add_relation_name, self.json['Relations'][obj1.name][1]+[obj2.name]]
152
        #     del self.json['Relations'][obj1.name]
153
        # else:
154
        # Devo aggiungere un operazione che rimuove un operazione di Add,Sub,Mul,Div se può essere unita ad un'altra operazione dello stesso tipo
155
        #
156
def merge(source, destination, main = True):
    """Recursively merge the model-graph dict `source` into `destination`.

    On the top-level call (main=True) the two graphs are first validated for
    conflicting duplicate definitions (ParamFun input counts, Parameter
    dimensions and windows), then `destination` is deep-copied so neither
    input is mutated. Recursive calls (main=False) merge in place into the
    copy. Returns the merged dict.
    """
    if main:
        # Duplicate ParamFun definitions must agree on their number of inputs.
        for key, value in destination["Functions"].items():
            if key in source["Functions"].keys() and 'n_input' in value.keys() and 'n_input' in source["Functions"][key].keys():
                check(value == {} or source["Functions"][key] == {} or value['n_input'] == source["Functions"][key]['n_input'],
                      TypeError,
                      f"The ParamFun {key} is present multiple times, with different number of inputs. "
                      f"The ParamFun {key} is called with {value['n_input']} parameters and with {source['Functions'][key]['n_input']} parameters.")
        for key, value in destination["Parameters"].items():
            if key in source["Parameters"].keys():
                # Duplicate Parameters must agree on their 'dim'.
                if 'dim' in value.keys() and 'dim' in source["Parameters"][key].keys():
                    check(value['dim'] == source["Parameters"][key]['dim'],
                          TypeError,
                          f"The Parameter {key} is present multiple times, with different dimensions. "
                          f"The Parameter {key} is called with {value['dim']} dimension and with {source['Parameters'][key]['dim']} dimension.")
                # Duplicate Parameters must also agree on window type ('tw'/'sw') and size.
                window_dest = 'tw' if 'tw' in value else ('sw' if 'sw' in value else None)
                window_source = 'tw' if 'tw' in source["Parameters"][key] else ('sw' if 'sw' in source["Parameters"][key] else None)
                if window_dest is not None:
                    check(window_dest == window_source and value[window_dest] == source["Parameters"][key][window_source] ,
                          TypeError,
                          f"The Parameter {key} is present multiple times, with different window. "
                          f"The Parameter {key} is called with {window_dest}={value[window_dest]} dimension and with {window_source}={source['Parameters'][key][window_source]} dimension.")

        log.debug("Merge Source")
        log.debug("\n"+pformat(source))
        log.debug("Merge Destination")
        log.debug("\n"+pformat(destination))
        # Deep copy so the top-level merge never mutates the caller's dicts.
        result = copy.deepcopy(destination)
    else:
        # Recursive calls merge directly into the (already copied) subtree.
        result = destination
    for key, value in source.items():
        if isinstance(value, dict):
            # get node or create one
            node = result.setdefault(key, {})
            merge(value, node, False)
        else:
            if key in result and type(result[key]) is list:
                if key == 'tw' or key == 'sw':
                    # Window ranges merge to their union: keep the smaller
                    # lower bound and the larger upper bound.
                    if result[key][0] > value[0]:
                        result[key][0] = value[0]
                    if result[key][1] < value[1]:
                        result[key][1] = value[1]
                # NOTE(review): for list-valued keys other than 'tw'/'sw' the
                # source value is silently discarded — confirm this is intended.
            else:
                result[key] = value
    if main == True:
        log.debug("Merge Result")
        log.debug("\n" + pformat(result))
    return result
204

205
def check(condition, exception, string):
    """Raise `exception(string)` when `condition` is falsy; otherwise do nothing."""
    if condition:
        return
    raise exception(string)
208

209
def argmax_max(iterable):
    """Return the (index, value) pair of the largest element of `iterable`.

    Raises ValueError on an empty iterable (inherited from `max`).
    """
    indexed_items = enumerate(iterable)
    return max(indexed_items, key=lambda pair: pair[1])
211

212
def argmin_min(iterable):
    """Return the (index, value) pair of the smallest element of `iterable`.

    Raises ValueError on an empty iterable (inherited from `min`).
    """
    indexed_items = enumerate(iterable)
    return min(indexed_items, key=lambda pair: pair[1])
214

215
def argmax_dict(iterable: dict):
    """Return the (key, value) pair with the largest value in the dict.

    Raises ValueError on an empty dict (inherited from `max`).
    """
    return max(iterable.items(), key=lambda entry: entry[1])
217

218
def argmin_dict(iterable: dict):
    """Return the (key, value) pair with the smallest value in the dict.

    Raises ValueError on an empty dict (inherited from `min`).
    """
    return min(iterable.items(), key=lambda entry: entry[1])
220

221
def count_gradient_operations(grad_fn):
    """Count the autograd-graph nodes reachable from `grad_fn` (0 for None).

    NOTE(review): there is no visited set, so a node shared by several paths
    is counted once per path — confirm this "hits" semantics is intended.
    """
    if grad_fn is None:
        return 0
    total = 0
    stack = [grad_fn]
    while stack:
        current = stack.pop()
        total += 1
        for next_fn in current.next_functions:
            if next_fn[0] is not None:
                stack.append(next_fn[0])
    return total
231

232
def check_gradient_operations(X:dict):
    """Sum the autograd-graph node counts over every tensor value in `X`."""
    total = 0
    for tensor in X.values():
        total += count_gradient_operations(tensor.grad_fn)
    return total
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc