• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 13056267505

30 Jan 2025 04:04PM UTC coverage: 94.525% (+0.6%) from 93.934%
13056267505

push

github

web-flow
Merge pull request #48 from tonegas/develop

Develop merge on main release 1.0.0

1185 of 1215 new or added lines in 21 files covered. (97.53%)

3 existing lines in 2 files now uncovered.

9426 of 9972 relevant lines covered (94.52%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.08
/nnodely/interpolation.py
1
import copy, inspect, textwrap, torch
1✔
2

3
import torch.nn as nn
1✔
4

5
from collections.abc import Callable
1✔
6

7
from nnodely.relation import NeuObj, Stream, AutoToStream
1✔
8
from nnodely.model import Model
1✔
9
from nnodely.utils import check, merge, enforce_types
1✔
10

11
from nnodely.logger import logging, nnLogger
1✔
12
log = nnLogger(__name__, logging.CRITICAL)

interpolation_relation_name = 'Interpolation'


class Interpolation(NeuObj):
    """
    Represents an Interpolation relation in the neural network model.
    This class performs piecewise-linear interpolation of an input stream `x`
    over a lookup table defined by two vectors of points.

    Parameters
    ----------
    x_points : list[int]|list[float]|list[torch.Tensor]
        The x-coordinates of the lookup-table points. Must be a 1D list of at
        least two distinct values; the layer sorts them, so order does not matter.
    y_points : list[int]|list[float]|list[torch.Tensor]
        The y-coordinates of the lookup-table points, paired element-wise with
        `x_points`. Must be 1D and have the same length as `x_points`.
    mode : str, optional
        The type of interpolation to perform. Supported modes: ['linear'].
        Default is 'linear'.

    Raises
    ------
    TypeError
        If `x_points` or `y_points` is not provided.
    ValueError
        If the point lists differ in length, are not 1D, contain fewer than two
        points or duplicate x values, or if `mode` is not supported.

    Examples
    --------
    Example - basic usage:
        >>> x_points = [1.0, 2.0, 3.0, 4.0]
        >>> y_points = [1.0, 4.0, 9.0, 16.0]
        >>> x = Input('x')

        >>> rel1 = Interpolation(x_points=x_points,y_points=y_points, mode='linear')(x.last())

        >>> out = Output('out',rel1)
    """

    @enforce_types
    def __init__(self, x_points:list|None = None,
                 y_points:list|None = None,
                 mode:str|None = 'linear'):

        self.relation_name = interpolation_relation_name
        self.x_points = x_points
        self.y_points = y_points
        self.mode = mode

        # Only modes that Interpolation_Layer.forward actually implements are
        # accepted, so an unsupported mode fails here rather than at the first
        # forward pass.
        self.available_modes = ['linear']

        super().__init__('P' + interpolation_relation_name + str(NeuObj.count))
        # Both point lists are required; the None defaults only exist for the signature.
        check(x_points is not None and y_points is not None, TypeError,
              'Both x_points and y_points must be provided.')
        check(len(x_points) == len(y_points), ValueError, 'The x_points and y_points must have the same length.')
        check(mode in self.available_modes, ValueError, f'The mode must be one of {self.available_modes}.')
        x_tensor = torch.tensor(x_points)
        check(len(x_tensor.shape) == 1, ValueError, 'The x_points must be a 1D tensor.')
        check(len(torch.tensor(y_points).shape) == 1, ValueError, 'The y_points must be a 1D tensor.')
        # Linear interpolation needs at least one segment, i.e. two points.
        check(len(x_points) >= 2, ValueError, 'At least two points are required for interpolation.')
        # A repeated x value would create a zero-width segment (division by zero).
        check(torch.unique(x_tensor).numel() == x_tensor.numel(), ValueError,
              'The x_points must be distinct.')

    def __call__(self, obj:Stream) -> Stream:
        """
        Apply the interpolation relation to a stream.

        Parameters
        ----------
        obj : Stream
            The input stream to interpolate.

        Returns
        -------
        Stream
            A new stream with the same dimension as `obj`. Its JSON records the
            relation as [relation name, [obj.name], x_points, y_points, mode].

        Raises
        ------
        TypeError
            If `obj` is not exactly of type Stream.
        """
        stream_name = interpolation_relation_name + str(Stream.count)
        check(type(obj) is Stream, TypeError, f"The type of {obj} is {type(obj)} and is not supported for Interpolation operation.")

        stream_json = merge(self.json,obj.json)
        stream_json['Relations'][stream_name] = [interpolation_relation_name, [obj.name], self.x_points, self.y_points, self.mode]
        return Stream(stream_name, stream_json, obj.dim)
67

68

69
class Interpolation_Layer(nn.Module):
    """
    PyTorch module that evaluates a piecewise-linear lookup table.

    The table points are sorted by x-coordinate at construction time (y values
    are reordered with them), so they may be supplied in any order.

    Parameters
    ----------
    x_points : list|torch.Tensor
        1D x-coordinates of the table points.
    y_points : list|torch.Tensor
        1D y-coordinates, paired element-wise with `x_points`.
    mode : str, optional
        Interpolation mode; only 'linear' is implemented. Default is 'linear'.
    """
    def __init__(self, x_points, y_points, mode='linear'):
        super(Interpolation_Layer, self).__init__()
        self.mode = mode
        ## Sort the points by x, carrying the matching y values along
        if type(x_points) is not torch.Tensor:
            x_points = torch.tensor(x_points)
        if type(y_points) is not torch.Tensor:
            y_points = torch.tensor(y_points)
        self.x_points, indices = torch.sort(x_points)
        self.y_points = y_points[indices]

        # Stored as column vectors of shape [Q, 1]
        self.x_points = self.x_points.unsqueeze(-1)
        self.y_points = self.y_points.unsqueeze(-1)

    def forward(self, x):
        """Interpolate `x` with the configured mode.

        Raises NotImplementedError for any mode other than 'linear'.
        """
        if self.mode == 'linear':
            return self.linear_interpolation(x)
        raise NotImplementedError(f"Interpolation mode '{self.mode}' is not implemented.")

    def linear_interpolation(self, x):
        """
        Piecewise-linear interpolation of every element of `x`.

        Queries outside [x_min, x_max] are saturated to the table range, so they
        return the y value of the first/last table point.

        Parameters
        ----------
        x : torch.Tensor
            Query points of any shape (e.g. [N, 1, 1]); each element is
            interpolated independently.

        Returns
        -------
        torch.Tensor
            Interpolated values with the same shape as `x`.
        """
        # Flat 1D views of the [Q, 1] tables, on the same device as the input
        x_table = self.x_points.squeeze(-1).to(x.device)
        y_table = self.y_points.squeeze(-1).to(x.device)

        # Saturate x to the range of the table
        x = torch.min(torch.max(x, x_table[0]), x_table[-1])

        # Segment i spans [x_table[i], x_table[i+1]]. Counting the interior
        # breakpoints that are <= x yields the segment containing x, in [0, Q-2].
        # (Picking the nearest breakpoint instead would select the neighbouring
        # segment for queries just left of a breakpoint.)
        idx = (x_table[1:-1] <= x.unsqueeze(-1)).sum(dim=-1)

        x0, x1 = x_table[idx], x_table[idx + 1]
        y0, y1 = y_table[idx], y_table[idx + 1]
        return y0 + (y1 - y0) / (x1 - x0) * (x - x0)
127

128
def createInterpolation(self, *inputs):
    """
    Build the PyTorch layer for an 'Interpolation' relation.

    This function is installed on Model under the relation name by the
    setattr call below, so `self` is the Model instance. `inputs` is read
    positionally as (x_points, y_points, mode) — presumably the trailing
    fields of the stored relation entry; confirm against Model.
    """
    x_points, y_points, mode = inputs[0], inputs[1], inputs[2]
    return Interpolation_Layer(x_points=x_points, y_points=y_points, mode=mode)


# Expose the factory on Model under the relation name.
setattr(Model, interpolation_relation_name, createInterpolation)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc