tonegas / nnodely / 14268084640

04 Apr 2025 02:51PM UTC coverage: 97.035% (+0.04%) from 96.995%
push · github · web-flow
Merge pull request #82 from tonegas/develop

Added some new features:

1. import from a pandas DataFrame, with a resample feature (see the sketch after this list)
2. new example files for each layer
3. categorical loss
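
A minimal sketch of the resampling step such an import builds on, using plain pandas (the DataFrame, column name, and rates below are assumptions for illustration; the nnodely-side call is not shown):

    import pandas as pd

    # A signal sampled at 10 Hz, indexed by timestamp
    df = pd.DataFrame({'x': range(10)},
                      index=pd.date_range('2025-01-01', periods=10, freq='100ms'))
    # Resample onto a uniform 200 ms grid before building the dataset
    df_uniform = df.resample('200ms').mean()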

407 of 430 new or added lines in 20 files covered. (94.65%)

7 existing lines in 2 files now uncovered.

11453 of 11803 relevant lines covered (97.03%)

0.97 hits per line

Source File: /nnodely/utils.py (91.53% covered)
import copy, torch, inspect
from collections import OrderedDict

import numpy as np

from pprint import pformat
from functools import wraps
from typing import get_type_hints
import keyword

from nnodely.logger import logging, nnLogger
log = nnLogger(__name__, logging.CRITICAL)

TORCH_DTYPE = torch.float32
NP_DTYPE = np.float32

ForbiddenTags = keyword.kwlist

def get_window(obj):
    return 'tw' if 'tw' in obj.dim else ('sw' if 'sw' in obj.dim else None)

def get_inputs(json, relation, inputs):
    # Get all the inputs needed to compute a specific relation from the json graph
    for rel in json['Relations'][relation][1]:
        if rel in (json['Inputs'] | json['States']): ## found an input
            inputs.append(rel)
        else: ## another relation: recurse to collect its inputs (branch uncovered in this build)
            get_inputs(json, rel, inputs)
    return inputs
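
# Example (illustrative json fragment, using the 'Inputs'/'States'/'Relations' keys
# the function indexes): relation 'r' adds the input 'x' to itself.
#   json = {'Inputs': {'x': {}}, 'States': {}, 'Relations': {'r': ['Add', ['x', 'x']]}}
#   get_inputs(json, 'r', [])   ->  ['x', 'x']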


def enforce_types(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        hints = get_type_hints(func)
        all_args = kwargs.copy()

        sig = OrderedDict(inspect.signature(func).parameters)
        if len(sig) != len(args):
            # Expand a *args parameter into one signature entry per extra positional argument
            var_type = None
            for ind, arg in enumerate(args):
                if ind < len(list(sig.values())) and list(sig.values())[ind].kind == inspect.Parameter.VAR_POSITIONAL:
                    var_name = list(sig.keys())[ind]
                    var_type = sig.pop(var_name)
                if var_type:
                    sig[var_name + str(ind)] = var_type

        all_args.update(dict(zip(sig, args)))
        if 'self' in sig.keys():
            sig.pop('self')

        for arg_name, arg in all_args.items():
            if (arg_name in hints.keys() or arg_name in sig.keys()) and not isinstance(arg, sig[arg_name].annotation):
                raise TypeError(
                    f"In function or class {func}: expected argument '{arg_name}' to be of type {sig[arg_name].annotation}, but got {type(arg)}")

        return func(*args, **kwargs)

    return wrapper
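
# Usage sketch (illustrative, not part of the original file): the decorator checks
# annotated arguments at call time and raises a TypeError on mismatch.
#   @enforce_types
#   def scale(x: int, factor: float):
#       return x * factor
#   scale(2, 1.5)    # ok
#   scale(2, 'a')    # raises TypeError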


# Linear interpolation function, operating on batches of input data and returning batches of output data
def linear_interp(x, x_data, y_data):
    # Inputs:
    # x: query points, a tensor of shape torch.Size([N, 1, 1])
    # x_data: map of x values, sorted in ascending order, a tensor of shape torch.Size([Q, 1])
    # y_data: map of y values, a tensor of shape torch.Size([Q, 1])
    # Output:
    # y: interpolated values at x, a tensor of shape torch.Size([N, 1, 1])

    # Saturate x to the range of x_data
    x = torch.min(torch.max(x, x_data[0]), x_data[-1])

    # Index of the nearest value in x_data[:-1], taken as the left endpoint of the interpolation segment
    idx = torch.argmin(torch.abs(x_data[:-1] - x), dim=1)

    # Linear interpolation
    y = y_data[idx] + (y_data[idx+1] - y_data[idx]) / (x_data[idx+1] - x_data[idx]) * (x - x_data[idx])
    return y
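
# Example (illustrative): two query points on the map y = 2*x.
#   x_data = torch.linspace(0, 1, 5).unsqueeze(1)      # torch.Size([5, 1])
#   y_data = 2 * x_data                                # torch.Size([5, 1])
#   x = torch.tensor([[[0.25]], [[0.9]]])              # torch.Size([2, 1, 1])
#   linear_interp(x, x_data, y_data)                   # ≈ [[[0.5]], [[1.8]]]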

def tensor_to_list(data):
    if isinstance(data, torch.Tensor):
        # Convert the tensor to a list
        return data.tolist()
    elif isinstance(data, dict):
        # Recurse into dictionaries
        return {key: tensor_to_list(value) for key, value in data.items()}
    elif isinstance(data, list):
        # Recurse into lists
        return [tensor_to_list(item) for item in data]
    elif isinstance(data, tuple):
        # Recurse into tuples (branch uncovered in this build)
        return tuple(tensor_to_list(item) for item in data)
    elif isinstance(data, torch.nn.modules.container.ParameterDict):
        # Recurse into ParameterDicts
        return {key: tensor_to_list(value) for key, value in data.items()}
    else:
        # Other data types are returned unchanged
        return data
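
# Example (illustrative):
#   tensor_to_list({'w': torch.tensor([1.0, 2.0]), 'n': 3})   ->  {'w': [1.0, 2.0], 'n': 3}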

# Code to compress relations (kept as a commented-out sketch):
        #print(self.json['Relations'])
        # used_rel = {string for values in self.json['Relations'].values() for string in values[1]}
        # if obj1.name not in used_rel and obj1.name in self.json['Relations'].keys() and self.json['Relations'][obj1.name][0] == add_relation_name:
        #     self.json['Relations'][self.name] = [add_relation_name, self.json['Relations'][obj1.name][1]+[obj2.name]]
        #     del self.json['Relations'][obj1.name]
        # else:
        # TODO: add an operation that removes an Add, Sub, Mul, or Div operation when it can be merged with another operation of the same type
        #
def merge(source, destination, main=True):
    if main:
        for key, value in destination["Functions"].items():
            if key in source["Functions"].keys() and 'n_input' in value.keys() and 'n_input' in source["Functions"][key].keys():
                check(value == {} or source["Functions"][key] == {} or value['n_input'] == source["Functions"][key]['n_input'],
                      TypeError,
                      f"The ParamFun {key} is present multiple times, with different numbers of inputs. "
                      f"The ParamFun {key} is called with {value['n_input']} parameters and with {source['Functions'][key]['n_input']} parameters.")
        for key, value in destination["Parameters"].items():
            if key in source["Parameters"].keys():
                if 'dim' in value.keys() and 'dim' in source["Parameters"][key].keys():
                    check(value['dim'] == source["Parameters"][key]['dim'],
                          TypeError,
                          f"The Parameter {key} is present multiple times, with different dimensions. "
                          f"The Parameter {key} is called with {value['dim']} dimension and with {source['Parameters'][key]['dim']} dimension.")
                window_dest = 'tw' if 'tw' in value else ('sw' if 'sw' in value else None)
                window_source = 'tw' if 'tw' in source["Parameters"][key] else ('sw' if 'sw' in source["Parameters"][key] else None)
                if window_dest is not None:
                    check(window_dest == window_source and value[window_dest] == source["Parameters"][key][window_source],
                          TypeError,
                          f"The Parameter {key} is present multiple times, with different windows. "
                          f"The Parameter {key} is called with {window_dest}={value[window_dest]} dimension and with {window_source}={source['Parameters'][key][window_source]} dimension.")

        log.debug("Merge Source")
        log.debug("\n" + pformat(source))
        log.debug("Merge Destination")
        log.debug("\n" + pformat(destination))
        result = copy.deepcopy(destination)
    else:
        result = destination
    for key, value in source.items():
        if isinstance(value, dict):
            # Get the node or create one
            node = result.setdefault(key, {})
            merge(value, node, False)
        else:
            if key in result and type(result[key]) is list:
                if key == 'tw' or key == 'sw':
                    # Widen the time/sample window so that it covers both definitions
                    if result[key][0] > value[0]:
                        result[key][0] = value[0]
                    if result[key][1] < value[1]:
                        result[key][1] = value[1]
            else:
                result[key] = value
    if main:
        log.debug("Merge Result")
        log.debug("\n" + pformat(result))
    return result
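
# Sketch of the window-merging rule (illustrative values; main=False skips the
# top-level "Functions"/"Parameters" consistency checks):
#   merge({'a': {'tw': [-2, 1]}}, {'a': {'tw': [-1, 3]}}, main=False)
#   ->  {'a': {'tw': [-2, 3]}}   (the window is widened to cover both definitions)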

def check(condition, exception, string):
    if not condition:
        raise exception(string)

def argmax_max(iterable):
    # Return (index, value) of the maximum element (uncovered in this build)
    return max(enumerate(iterable), key=lambda x: x[1])

def argmin_min(iterable):
    # Return (index, value) of the minimum element (uncovered in this build)
    return min(enumerate(iterable), key=lambda x: x[1])

def argmax_dict(iterable: dict):
    # Return (key, value) of the entry with the maximum value
    return max(iterable.items(), key=lambda x: x[1])

def argmin_dict(iterable: dict):
    # Return (key, value) of the entry with the minimum value
    return min(iterable.items(), key=lambda x: x[1])
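
# Examples (illustrative):
#   argmax_max([3, 9, 5])            ->  (1, 9)
#   argmax_dict({'a': 1, 'b': 3})    ->  ('b', 3)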

def count_gradient_operations(grad_fn):
    # Walk the autograd graph behind grad_fn and count its nodes
    # (this traversal is new in this build and not yet covered)
    count = 0
    if grad_fn is None:
        return count
    nodes = [grad_fn]
    while nodes:
        node = nodes.pop()
        count += 1
        nodes.extend(next_fn[0] for next_fn in node.next_functions if next_fn[0] is not None)
    return count

def check_gradient_operations(X: dict):
    # Sum the autograd-node counts over all tensors in the dictionary
    count = 0
    for key in X.keys():
        count += count_gradient_operations(X[key].grad_fn)
    return count
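
# Example (illustrative): each tensor's grad_fn chain contributes its node count.
#   a = torch.ones(2, requires_grad=True)
#   check_gradient_operations({'y': a * 2 + 1})
#   # -> 3 on recent PyTorch: AddBackward0, MulBackward0, AccumulateGrad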