• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 13056267505

30 Jan 2025 04:04PM UTC coverage: 94.525% (+0.6%) from 93.934%
13056267505

push

github

web-flow
Merge pull request #48 from tonegas/develop

Develop merge on main release 1.0.0

1185 of 1215 new or added lines in 21 files covered. (97.53%)

3 existing lines in 2 files now uncovered.

9426 of 9972 relevant lines covered (94.52%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.37
/nnodely/utils.py
1
import copy, torch, inspect
1✔
2
import numpy as np
1✔
3

4
from pprint import pformat
1✔
5
from functools import wraps
1✔
6
from typing import get_type_hints
1✔
7

8
from nnodely.logger import logging, nnLogger
1✔
9
# Module-level logger for this file, constructed with level logging.CRITICAL
# (so the log.debug calls below are presumably silenced unless the level is
# lowered — depends on nnLogger, confirm in nnodely.logger).
log = nnLogger(__name__, logging.CRITICAL)

# Single-precision float dtypes for the torch and numpy sides of the package.
TORCH_DTYPE, NP_DTYPE = torch.float32, np.float32
13

14
def enforce_types(func):
    """Decorator that checks call arguments against ``func``'s type hints.

    On every call, each argument that was actually passed (positionally or by
    keyword) and that has an annotation on ``func`` is tested with
    ``isinstance``. Parameters left at their default are not checked, and the
    return annotation is ignored.

    Positional values are matched only to positional parameters, so a value
    captured by ``*args`` is never checked against a later keyword-only
    parameter's hint. ``*args``/``**kwargs`` contents themselves are not
    checked.

    NOTE(review): hints that are not plain classes (e.g. ``list[int]``) make
    ``isinstance`` itself raise ``TypeError``.

    Raises:
        TypeError: if an annotated argument is not an instance of its hint.
    """
    # The signature cannot change after decoration, so compute it once.
    # Only POSITIONAL_ONLY / POSITIONAL_OR_KEYWORD parameters receive
    # positional values one-to-one; zipping against *all* parameters would
    # pair extra positionals with *args and then with keyword-only params.
    positional_names = [
        name
        for name, param in inspect.signature(func).parameters.items()
        if param.kind in (inspect.Parameter.POSITIONAL_ONLY,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Resolved per call (not at decoration time) so forward references
        # to names defined later in the module can still be resolved.
        hints = get_type_hints(func)
        all_args = kwargs.copy()
        all_args.update(zip(positional_names, args))

        for arg, arg_type in hints.items():
            if arg in all_args and not isinstance(all_args[arg], arg_type):
                raise TypeError(
                    f"Expected argument '{arg}' to be of type {arg_type.__name__}, but got {type(all_args[arg]).__name__}")

        return func(*args, **kwargs)

    return wrapper
29

30
def linear_interp(x,x_data,y_data):
    """Piecewise-linear interpolation of a 1-D lookup table on a batch of queries.

    Args:
        x: query points, a tensor of shape ``(N, 1, 1)``.
        x_data: breakpoints sorted in ascending order, shape ``(Q, 1)``;
            at least two breakpoints are required (``idx + 1`` must exist).
        y_data: table values at the breakpoints, shape ``(Q, 1)``.

    Returns:
        Interpolated values, shape ``(N, 1, 1)``. Queries outside
        ``[x_data[0], x_data[-1]]`` are clamped into that range first, so the
        output saturates at ``y_data[0]`` / ``y_data[-1]``.
    """
    # Saturate x to the range of x_data.
    x = torch.min(torch.max(x, x_data[0]), x_data[-1])

    # Left end of the segment containing x: the number of interior
    # breakpoints x_data[1:-1] that are <= x. Broadcasting (Q-2, 1) against
    # (N, 1, 1) gives (N, Q-2, 1); summing over dim 1 yields idx of shape
    # (N, 1) with values in [0, Q-2], so idx + 1 is always a valid index and
    # x == x_data[-1] lands on the last segment.
    # (Choosing the *closest* breakpoint instead would pick the right end of
    # the segment once x passes its midpoint and extrapolate the next
    # segment backwards.)
    idx = (x_data[1:-1] <= x).sum(dim=1)

    # Linear interpolation on segment [idx, idx + 1].
    y = y_data[idx] + (y_data[idx+1] - y_data[idx])/(x_data[idx+1] - x_data[idx])*(x - x_data[idx])
    return y
48

49
def tensor_to_list(data):
    """Recursively replace every ``torch.Tensor`` inside ``data`` with ``Tensor.tolist()``.

    - Tensors are converted with ``.tolist()``.
    - Dicts and ``ParameterDict`` objects become plain ``dict`` objects with converted values.
    - Lists stay lists and tuples stay tuples; their elements are converted.
    - Any other value is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.tolist()

    # Mapping-like containers (ParameterDict is not a dict subclass).
    if isinstance(data, (dict, torch.nn.modules.container.ParameterDict)):
        return {name: tensor_to_list(entry) for name, entry in data.items()}

    # Sequences: convert the elements, then rebuild the same container kind.
    if isinstance(data, (list, tuple)):
        converted = [tensor_to_list(element) for element in data]
        return converted if isinstance(data, list) else tuple(converted)

    return data
68

69
def merge(source, destination, main = True):
    """Recursively merge the nested dict ``source`` into ``destination``.

    Rules, applied per key of ``source``:
      - dict value: merged recursively into ``destination[key]`` (created as
        an empty dict when missing);
      - key absent from the destination, or destination value not a list:
        the source value is stored (by reference, not copied);
      - destination value is a list and key is ``'tw'`` or ``'sw'``: treated
        as a ``[low, high]`` window and widened in place to also cover the
        source window;
      - destination value is a list under any other key: left unchanged.

    Args:
        source: dict whose entries are merged in.
        destination: dict to merge into.
        main: when truthy, ``destination`` is deep-copied first (the caller's
            dict is untouched) and the inputs are logged at debug level; when
            falsy (recursive calls), ``destination`` is modified in place.

    Returns:
        The merged dict.
    """
    if not main:
        result = destination
    else:
        log.debug("Merge Source")
        log.debug("\n"+pformat(source))
        log.debug("Merge Destination")
        log.debug("\n"+pformat(destination))
        result = copy.deepcopy(destination)

    for key, value in source.items():
        if isinstance(value, dict):
            # Merge into the existing sub-dict, creating it if missing.
            merge(value, result.setdefault(key, {}), False)
        elif key not in result or type(result[key]) is not list:
            result[key] = value
        elif key in ('tw', 'sw'):
            # Widen the [low, high] window so it covers both inputs.
            window = result[key]
            if window[0] > value[0]:
                window[0] = value[0]
            if window[1] < value[1]:
                window[1] = value[1]
        # Any other list already present in the destination is kept as-is.

    # NOTE(review): unlike the `if main` test above, this is an equality test,
    # so truthy non-bool values such as 'yes' skip the result log.
    if main == True:
        log.debug("Merge Result")
        log.debug("\n" + pformat(result))
    return result
96

97
def check(condition, exception, string):
    """Raise ``exception(string)`` when ``condition`` is falsy; otherwise return ``None``.

    Args:
        condition: value tested for truthiness.
        exception: exception class (or callable) to instantiate and raise.
        string: message passed to ``exception``.
    """
    if condition:
        return
    raise exception(string)
100

101
def argmax_max(iterable):
    """Return ``(index, value)`` for the largest element of ``iterable``.

    Ties resolve to the earliest index (``max`` keeps the first maximum).
    An empty iterable raises ``ValueError``.
    """
    def _element(indexed_pair):
        return indexed_pair[1]

    return max(enumerate(iterable), key=_element)
103

104
def argmin_min(iterable):
    """Return ``(index, value)`` for the smallest element of ``iterable``.

    Ties resolve to the earliest index (``min`` keeps the first minimum).
    An empty iterable raises ``ValueError``.
    """
    def _element(indexed_pair):
        return indexed_pair[1]

    return min(enumerate(iterable), key=_element)
106

107
def argmax_dict(iterable: dict):
    """Return the ``(key, value)`` item of ``iterable`` whose value is largest.

    Ties resolve to the first such item in iteration order; an empty dict
    raises ``ValueError``.
    """
    def _value(item):
        return item[1]

    entries = iterable.items()
    return max(entries, key=_value)
109

110
def argmin_dict(iterable: dict):
    """Return the ``(key, value)`` item of ``iterable`` whose value is smallest.

    Ties resolve to the first such item in iteration order; an empty dict
    raises ``ValueError``.
    """
    def _value(item):
        return item[1]

    entries = iterable.items()
    return min(entries, key=_value)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc