• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WassimTenachi / PhySO / #10

14 Nov 2023 11:52PM UTC coverage: 82.385% (-1.1%) from 83.52%
#10

push

coveralls-python

WassimTenachi
Merge branch 'dev'

16 of 46 new or added lines in 2 files covered. (34.78%)

56 existing lines in 2 files now uncovered.

4635 of 5626 relevant lines covered (82.39%)

0.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.18
/physo/physym/tests/functions_UnitTest.py
1
import time
1✔
2
import unittest
1✔
3
import numpy as np
1✔
4
import torch as torch
1✔
5
import matplotlib.pyplot as plt
1✔
6

7
# Internal imports
8
from physo.physym import functions as Func
1✔
9
from physo.physym.functions import data_conversion, data_conversion_inv
1✔
10

11

12
# Test token and output shapes
13
def test_one_token(tester, token):
    """Check that ``token`` produces outputs whose length is usable downstream.

    The token is evaluated on representative inputs chosen from its arity:
    arrays, an array sharing a zero position with another array, a pi scalar
    and a large scalar. Calls involving at least one array must return
    ``n_data`` values. Calls made only of scalars must return either a single
    value or ``n_data`` values, so that the result can still be combined with
    data afterwards.

    Parameters
    ----------
    tester : unittest.TestCase
        Test case used to report assertion failures.
    token : token object
        Callable token exposing an ``arity`` attribute (0, 1 or 2 are checked).
    """
    data0 = data_conversion(np.arange(-5, 5, 0.5))
    data1 = data_conversion(np.arange(-5, 5, 0.5) + 1)
    data2 = data_conversion(np.arange(-5, 5, 0.5) * 2)   # 0 in same place as data0
    pi    = data_conversion(np.array(np.pi))
    large = data_conversion(np.array(1e10))              # large float
    n_data = len(data0)

    # Shapes accepted when every input is a scalar: one value, or one value
    # per data point. np.shape returns a tuple, so the comparison must be made
    # against tuples (a tuple compared to the int n_data is always False).
    acceptable_scalar_shapes = ((n_data,), (1,))

    def output_shape(output):
        # atleast_1d so that a 0-d output is reported as shape (1,)
        return np.shape(np.atleast_1d(data_conversion_inv(output)))

    # Binary
    if token.arity == 2:
        tester.assertEqual(len(data_conversion_inv(token(data0, data1))), n_data)   # array       , array
        tester.assertEqual(len(data_conversion_inv(token(data0, data2))), n_data)   # array       , array with (0,0)
        tester.assertEqual(len(data_conversion_inv(token(data0, pi   ))), n_data)   # array       , float
        tester.assertEqual(len(data_conversion_inv(token(data0, large))), n_data)   # array       , large float
        tester.assertEqual(len(data_conversion_inv(token(large, data0))), n_data)   # large float , array
        # stacked (2, n_data) tensor unpacked into two positional arguments
        tester.assertEqual(len(data_conversion_inv(token(*torch.stack((data0, data1))))), n_data)
        # large float , large float
        tester.assertIn(output_shape(token(large, large)), acceptable_scalar_shapes)
    # Unary
    if token.arity == 1:
        tester.assertEqual(len(data_conversion_inv(token(data0))), n_data)          # array
        # large float
        tester.assertIn(output_shape(token(large)), acceptable_scalar_shapes)
    # Zero-arity
    if token.arity == 0:
        tester.assertIn(output_shape(token()), acceptable_scalar_shapes)
1✔
48

49

50
class FuncTest(unittest.TestCase):
1✔
51

52
    # Test make tokens function
53
    def test_make_tokens(self):
        """make_tokens must succeed on a minimal configuration without units or complexity."""
        config = dict(
            op_names          = ["mul", "add", "neg", "inv", "sin"],
            input_var_ids     = {"x0" : 0     , "x1" : 1 },
            constants         = {"pi" : np.pi , "c"  : 3e8},
            free_constants    = {"c0", "c1"},
            use_protected_ops = False,
        )
        try:
            Func.make_tokens(**config)
        except Exception:
            self.fail("Make tokens function failed")
×
63

64
    # Test make tokens function with units and complexity
65
    def test_make_tokens_units_and_complexity(self):
        """Units and complexities given to make_tokens must be stored on every token.

        Every input variable, constant and free constant passed in is checked,
        both for its physical units and for its complexity.
        """
        def token_specs():
            # Fresh argument dicts on every call, so the expectations checked
            # below cannot be altered by anything make_tokens does to its inputs.
            return dict(
                # operations
                op_names             = ["mul", "add", "neg", "inv", "sin"],
                use_protected_ops    = False,
                # input variables
                input_var_ids        = {"x" : 0         , "v" : 1          , "t" : 2         },
                input_var_units      = {"x" : [1, 0, 0] , "v" : [1, -1, 0] , "t" : [0, 1, 0] },
                input_var_complexity = {"x" : 0.        , "v" : 1.         , "t" : 0.        },
                # constants
                constants            = {"pi" : np.pi     , "c" : 3e8       , "M" : 1e6       },
                constants_units      = {"pi" : [0, 0, 0] , "c" : [1, -1, 0], "M" : [0, 0, 1] },
                constants_complexity = {"pi" : 0.        , "c" : 0.        , "M" : 1.        },
                # free constants
                free_constants            = {"c0"             , "c1"               , "c2"             },
                free_constants_init_val   = {"c0" : 1.        , "c1"  : 10.        , "c2" : 1.        },
                free_constants_units      = {"c0" : [0, 0, 0] , "c1"  : [1, -1, 0] , "c2" : [0, 0, 1] },
                free_constants_complexity = {"c0" : 0.        , "c1"  : 0.         , "c2" : 1.        },
            )

        # Test creation
        try:
            my_tokens = Func.make_tokens(**token_specs())
        except Exception:
            self.fail("Make tokens function failed")

        # Test that properties were encoded, for every configured token
        my_tokens_dict = {token.name: token for token in my_tokens}
        specs = token_specs()
        expected_units = {**specs["input_var_units"],
                          **specs["constants_units"],
                          **specs["free_constants_units"]}
        expected_complexity = {**specs["input_var_complexity"],
                               **specs["constants_complexity"],
                               **specs["free_constants_complexity"]}

        # Checking units: only the first 3 values, because phy_units is padded
        # to match Lib.UNITS_VECTOR_SIZE
        for name, units in expected_units.items():
            with self.subTest(token=name, attribute="phy_units"):
                is_equal = np.array_equal(my_tokens_dict[name].phy_units[:3], units)
                self.assertEqual(is_equal, True)
        # Checking complexities
        for name, complexity in expected_complexity.items():
            with self.subTest(token=name, attribute="complexity"):
                is_equal = np.array_equal(my_tokens_dict[name].complexity, complexity)
                self.assertEqual(is_equal, True)
1✔
118

119
    # Test make tokens function with units and complexity, missing units or complexity in dict
120
    def test_make_tokens_units_and_complexity_missing_info_warnings(self):
        """make_tokens must warn when one unit, complexity or init value is missing.

        Each case starts from a complete configuration and removes exactly one
        entry from one argument. The expected warning can therefore only be
        caused by that missing entry; no other missing information can satisfy
        ``assertWarns`` on its behalf.
        """
        def complete_kwargs():
            # Fresh, fully specified configuration on every call.
            return dict(
                # operations
                op_names             = ["mul", "add", "neg", "inv", "sin"],
                use_protected_ops    = False,
                # input variables
                input_var_ids        = {"x" : 0         , "v" : 1          , "t" : 2         },
                input_var_units      = {"x" : [1, 0, 0] , "v" : [1, -1, 0] , "t" : [0, 1, 0] },
                input_var_complexity = {"x" : 0.        , "v" : 1.         , "t" : 0.        },
                # constants
                constants            = {"pi" : np.pi     , "c" : 3e8       , "M" : 1e6       },
                constants_units      = {"pi" : [0, 0, 0] , "c" : [1, -1, 0], "M" : [0, 0, 1] },
                constants_complexity = {"pi" : 0.        , "c" : 0.        , "M" : 1.        },
                # free constants
                free_constants            = {"c0"             , "c1"               , "c2"             },
                free_constants_init_val   = {"c0" : 1.        , "c1"  : 10.        , "c2" : 1.        },
                free_constants_units      = {"c0" : [0, 0, 0] , "c1"  : [1, -1, 0] , "c2" : [0, 0, 1] },
                free_constants_complexity = {"c0" : 0.        , "c1"  : 0.         , "c2" : 1.        },
            )

        # (argument of make_tokens, key removed from that argument's dict)
        missing_cases = [
            ("input_var_units"          , "t" ),   # missing units in input variables
            ("input_var_complexity"     , "t" ),   # missing complexity in input variables
            ("constants_units"          , "M" ),   # missing units in constants
            ("constants_complexity"     , "M" ),   # missing complexity in constants
            ("free_constants_units"     , "c2"),   # missing units in free constants
            ("free_constants_complexity", "c2"),   # missing complexity in free constants
            ("free_constants_init_val"  , "c2"),   # missing init_val in free constants
        ]
        for arg_name, missing_key in missing_cases:
            kwargs = complete_kwargs()
            del kwargs[arg_name][missing_key]
            with self.subTest(argument=arg_name, missing_key=missing_key):
                with self.assertWarns(Warning):
                    Func.make_tokens(**kwargs)
262

263
    # Test make tokens function with wrong units
264
    def test_make_tokens_units_and_complexity_wrong_unit(self):
        """Invalid unit vectors (too long, non-numeric, not 1-D) must raise AssertionError."""
        def valid_kwargs():
            # Fresh, fully valid configuration; one unit entry is overwritten per case.
            return dict(
                # operations
                op_names             = ["mul", "add", "neg", "inv", "sin"],
                use_protected_ops    = False,
                # input variables
                input_var_ids        = {"x" : 0         , "v" : 1          , "t" : 2         },
                input_var_units      = {"x" : [1, 0, 0] , "v" : [1, -1, 0] , "t" : [0, 1, 0] },
                input_var_complexity = {"x" : 0.        , "v" : 1.         , "t" : 0.        },
                # constants
                constants            = {"pi" : np.pi     , "c" : 3e8       , "M" : 1e6       },
                constants_units      = {"pi" : [0, 0, 0] , "c" : [1, -1, 0], "M" : [0, 0, 1] },
                constants_complexity = {"pi" : 0.        , "c" : 0.        , "M" : 1.        },
                # free constants
                free_constants            = {"c0"             , "c1"               , "c2"             },
                free_constants_init_val   = {"c0" : 1.        , "c1"  : 10.        , "c2" : 1.        },
                free_constants_units      = {"c0" : [0, 0, 0] , "c1"  : [1, -1, 0] , "c2" : [0, 0, 1] },
                free_constants_complexity = {"c0" : 0.        , "c1"  : 0.         , "c2" : 1.        },
            )

        # (argument of make_tokens, token whose units are replaced, invalid units)
        bad_unit_cases = [
            ("input_var_units"     , "t" , np.ones(10000) ),   # unit too large in input variables
            ("constants_units"     , "M" , np.ones(10000) ),   # unit too large in constants
            ("free_constants_units", "c2", np.ones(10000) ),   # unit too large in free constants
            ("constants_units"     , "M" , ['a', 'b', 1]  ),   # units having wrong variable type
            ("constants_units"     , "M" , np.ones((7, 7))),   # units having wrong shape
        ]
        for arg_name, token_name, bad_units in bad_unit_cases:
            kwargs = valid_kwargs()
            kwargs[arg_name][token_name] = bad_units
            with self.assertRaises(AssertionError):
                Func.make_tokens(**kwargs)
366

367
    # Test unknown function exception
368
    def test_make_tokens_unknown_function(self):
        """Requesting an operation name that does not exist must raise Func.UnknownFunction."""
        requested_ops = ["mul", "function_that_does_not_exist"]
        with self.assertRaises(Func.UnknownFunction):
            Func.make_tokens(op_names=requested_ops)
1✔
371

372
    # Test all protected tokens and output shapes of their underlying functions
373
    def test_protected_tokens(self):
        """Every protected token must produce outputs of usable shape.

        Each token is checked by test_one_token inside its own subTest. A
        failure names the offending token, and the remaining tokens are still
        checked.
        """
        my_tokens = Func.make_tokens(op_names="all",
                                     constants={"pi": np.pi, "c": 3e8},
                                     use_protected_ops=True, )
        for token in my_tokens:
            with self.subTest(token=token.name):
                test_one_token(tester=self, token=token)
1✔
379

380
    # Test all unprotected tokens and output shapes of their underlying functions
381
    def test_unprotected_tokens(self):
        """Every unprotected token must produce outputs of usable shape.

        Each token is checked by test_one_token inside its own subTest. A
        failure names the offending token, and the remaining tokens are still
        checked.
        """
        my_tokens = Func.make_tokens(op_names="all",
                                     constants={"pi": np.pi, "c": 3e8},
                                     use_protected_ops=False, )
        for token in my_tokens:
            with self.subTest(token=token.name):
                test_one_token(tester=self, token=token)
1✔
387

388
    # Test that arity and complexity have the same values in protected and unprotected modes
389
    def test_protected_unprotected_same_attributes(self):
        """Protected and unprotected tokens must agree on every attribute but ``function``.

        For each protected token name, exactly one token with that name must
        exist in each set. All attributes other than the underlying function,
        which legitimately differs between modes, are compared; NaNs compare
        equal except when a str value is involved.
        """
        unprotected_tokens = Func.make_tokens(op_names="all", use_protected_ops=False, )
        protected_tokens   = Func.make_tokens(op_names="all", use_protected_ops=True, )
        unprotected_names  = np.array([tok.name for tok in unprotected_tokens])
        protected_names    = np.array([tok.name for tok in protected_tokens])

        def single_token(tokens, names, wanted):
            # argwhere indexing yields an array of shape (n_matches, 1)
            matches = tokens[np.argwhere(names == wanted)]
            # Exactly one version of the token must exist
            self.assertEqual(np.array_equal(matches.shape, [1, 1]), True)
            return matches[0, 0]

        for name in protected_names:
            prot_tok   = single_token(protected_tokens,   protected_names,   name)
            unprot_tok = single_token(unprotected_tokens, unprotected_names, name)
            for attr, prot_val in prot_tok.__dict__.items():
                unprot_val = unprot_tok.__dict__[attr]
                # The underlying function is bound to differ between modes
                if attr == "function":
                    continue
                # equal_nan=True cannot be used when comparing str values
                involves_str = isinstance(prot_val, str) or isinstance(unprot_val, str)
                is_equal = np.array_equal(prot_val, unprot_val, equal_nan=not involves_str)
                self.assertEqual(is_equal, True)
1✔
422

423
    # Test that arity of functions does not exceed the max nb of children
424
    def test_max_arity(self):
        """No protected token may have an arity above Func.Tok.MAX_ARITY."""
        protected_tokens = Func.make_tokens(op_names="all", use_protected_ops=True, )
        for token in protected_tokens:
            # assertLessEqual reports both values on failure. Unlike
            # assertIs(cond, True), it does not depend on the comparison
            # returning the Python singleton True.
            self.assertLessEqual(token.arity, Func.Tok.MAX_ARITY,
                                 msg="Token %s exceeds MAX_ARITY." % (token.name,))
1✔
428

429
    # Test that tokens pointing to data work
430
    def test_data_pointers_work(self):
        """Constant tokens must return the data they were built from, and ops must compose on them."""
        const_data0 = data_conversion(np.random.rand())
        const_data1 = data_conversion(np.random.rand())
        my_tokens = Func.make_tokens(op_names="all",
                                     constants={"pi": np.pi, "const1": 1., "data0": const_data0,
                                                "data1": const_data1}, )
        my_tokens_dict = {token.name: token for token in my_tokens}

        # Test that tokens point to data: calling a constant token returns its value
        expected_values = {"pi": np.pi, "const1": 1., "data0": const_data0, "data1": const_data1}
        for name, expected in expected_values.items():
            is_equal = np.array_equal(data_conversion_inv(my_tokens_dict[name]()), expected)
            self.assertEqual(is_equal, True)

        # Test mul(data0, data1) === underlying mul function applied to the raw data
        mul_token = my_tokens_dict["mul"]
        via_tokens   = mul_token(my_tokens_dict["data0"](), my_tokens_dict["data1"]())
        via_function = mul_token.function(const_data0, const_data1)
        is_equal = np.array_equal(data_conversion_inv(via_tokens),
                                  data_conversion_inv(via_function))
        self.assertEqual(is_equal, True)
1✔
450

451
    # Test that behavior objects contain different operation names
452
    # (eg. "add" must only have one unique behavior)
453
    def test_behavior_contain_different_ops(self):
        """No operation name may be listed by more than one unit behavior.

        Token names are gathered from the unprotected set and then the protected set.
        For each name, the behaviors in ``Func.OP_UNIT_BEHAVIORS_DICT`` whose
        ``op_names`` contain it are counted. More than one owner fails the test.
        """
        all_names  = [tok.name for tok in Func.make_tokens(op_names="all", use_protected_ops=False, )]
        all_names += [tok.name for tok in Func.make_tokens(op_names="all", use_protected_ops=True, )]
        behaviors = list(Func.OP_UNIT_BEHAVIORS_DICT.values())
        for name in all_names:
            n_owners = sum(1 for behavior in behaviors if name in behavior.op_names)
            if n_owners > 1:
                self.fail("Token named %s appears in more than one behavior."%(name))
463

464
    # Test that each behavior has a unique identifier
465
    def test_behavior_have_unique_ids(self):
        """Each behavior in ``Func.OP_UNIT_BEHAVIORS_DICT`` must carry a distinct ``behavior_id``."""
        all_ids = [b.behavior_id for b in Func.OP_UNIT_BEHAVIORS_DICT.values()]
        n_distinct = len(np.unique(all_ids))
        if len(all_ids) != n_distinct:
            self.fail("Behaviors ids are not unique, ids = %s"%(str(all_ids)))
469

470
    # Test that tokens encoded with dimensionless behavior id are dimensionless
471
    def test_behavior_dimensionless_are_dimensionless(self):
        """Tokens tagged with the unary-dimensionless behavior id must be unitless.

        The check covers both the unprotected and the protected token sets. Every
        token whose ``behavior_id`` matches ``UNARY_DIMENSIONLESS_OP`` must satisfy:
        - it constrains physical units;
        - its ``phy_units`` vector is all zeros.
        """
        dimless_id = Func.OP_UNIT_BEHAVIORS_DICT["UNARY_DIMENSIONLESS_OP"].behavior_id
        unprotected = Func.make_tokens(op_names="all", use_protected_ops=False, )
        protected   = Func.make_tokens(op_names="all", use_protected_ops=True, )
        candidates = unprotected.tolist() + protected.tolist()
        for tok in (t for t in candidates if t.behavior_id == dimless_id):
            self.assertEqual(tok.is_constraining_phy_units, True)
            is_unitless = np.array_equal(tok.phy_units, np.zeros(Func.Tok.UNITS_VECTOR_SIZE))
            self.assertEqual(is_unitless, True)
480

481
    # Test GroupUnitsBehavior
482
    def test_GroupUnitsBehavior (self):
        """``is_id`` of the ``BINARY_MULTIPLICATIVE_OP`` group must return a boolean membership mask.

        The multiplication and division behavior ids are expected to be members.
        ``999999999999`` is used as an id expected to be outside the group.
        The check covers four inputs:
        - a large NumPy vector of member ids;
        - a short plain list mixing members and the outsider;
        - a scalar member id;
        - a scalar non-member id.
        """
        behavior_id_mul = Func.UNIT_BEHAVIORS_DICT["MULTIPLICATION_OP"].behavior_id
        behavior_id_div = Func.UNIT_BEHAVIORS_DICT["DIVISION_OP"]      .behavior_id
        group = Func.UNIT_BEHAVIORS_DICT["BINARY_MULTIPLICATIVE_OP"]
        unknown_id = 999999999999  # assumed not to be any real behavior id

        def _check_is_id(ids, expected):
            # `group.is_id(ids)` must have a bool dtype and equal `expected` element-wise.
            result = group.is_id(ids)
            self.assertEqual(result.dtype, bool)
            self.assertTrue(np.array_equal(result, expected),
                            msg="is_id returned %s, expected %s" % (str(result), str(expected)))

        # --- Vect 1 : 1e4 mul ids followed by 1e4 div ids -> all True ---
        test_vect_ids = np.array([behavior_id_mul]*int(1e4) + [behavior_id_div]*int(1e4),)
        _check_is_id(test_vect_ids, np.full(shape=test_vect_ids.shape[0], fill_value=True))
        # --- Vect 2 : plain list mixing members and a non-member ---
        _check_is_id([behavior_id_mul, behavior_id_div, unknown_id], [True, True, False])
        # --- Single value : member ---
        _check_is_id(behavior_id_mul, True)
        # --- Single value : non-member ---
        _check_is_id(unknown_id, False)
516

517

518
    def test_protected_functions_plots (self):
        """Optionally plot every protected function over a range that exercises it.

        Plotting is disabled by default. Set ``do_show`` to display each figure
        interactively, and/or ``do_save`` to write one ``<function name>.png`` per
        function to the current working directory. Each figure is closed once
        shown/saved so that figures do not accumulate in memory.
        """
        do_show = False
        do_save = False
        make_plots = do_show or do_save  # if either is true, make plots

        n_plot = int(1e4)

        if not make_plots:
            return None

        def _plot(filename, x, curves):
            # Draw each (y, plot-kwargs) curve against x, then show/save and close the figure.
            fig, ax = plt.subplots(1, 1, figsize=(20, 10))
            for y, plot_kwargs in curves:
                ax.plot(x, y, **plot_kwargs)
            ax.legend()
            if do_show: plt.show()
            if do_save: fig.savefig(filename)
            plt.close(fig)

        # protected_div : positive and negative denominators on the same figure
        x1 = torch.linspace(-1, 1, n_plot)
        _plot("protected_div.png", x1, [
            (Func.protected_div(x1,  4 * x1), dict(label="protected_div",           color="k", linestyle="solid")),
            (Func.protected_div(x1, -4 * x1), dict(label="protected_div (neg num)", color="r", linestyle="dotted")),
        ])

        # Unary protected functions : (name, function, sample points)
        unary_specs = [
            ("protected_exp",    Func.protected_exp,    torch.linspace(0, 1.1*Func.EXP_THRESHOLD, n_plot)),
            ("protected_log",    Func.protected_log,    torch.linspace(-10*Func.EPSILON, 10*Func.EPSILON, n_plot)),
            ("protected_logabs", Func.protected_logabs, torch.linspace(-10*Func.EPSILON, 10*Func.EPSILON, n_plot)),
            ("protected_sqrt",   Func.protected_sqrt,   torch.linspace(-10, 10, n_plot)),
            ("protected_inv",    Func.protected_inv,    torch.linspace(-10*Func.EPSILON, 10*Func.EPSILON, n_plot)),
            ("protected_expneg", Func.protected_expneg, torch.linspace(-1.1*Func.EXP_THRESHOLD, 0, n_plot)),
            ("protected_n2",     Func.protected_n2,     torch.linspace(-2 * Func.INF, 2 * Func.INF, n_plot)),
            ("protected_n3",     Func.protected_n3,     torch.linspace(-2 * Func.INF, 2 * Func.INF, n_plot)),
            ("protected_n4",     Func.protected_n4,     torch.linspace(-2 * Func.INF, 2 * Func.INF, n_plot)),
            ("protected_arcsin", Func.protected_arcsin, torch.linspace(-2, 2, n_plot)),
            ("protected_arccos", Func.protected_arccos, torch.linspace(-2, 2, n_plot)),
        ]
        for name, func, x1 in unary_specs:
            _plot(name + ".png", x1, [(func(x1), dict(label=name, color="k"))])

        # protected_torch_pow : base x1 and exponent x2 sweep together
        x1 = torch.linspace(-0.1*Func.INF, Func.INF, n_plot)
        x2 = torch.linspace(-8, 8, n_plot)
        _plot("protected_torch_pow.png", x1, [
            (Func.protected_torch_pow(x1, x2), dict(label="protected_torch_pow", color="k", linestyle="solid")),
        ])

        return None
634

635
if __name__ == '__main__':
    # Allow running this file directly: execute all of its test cases, reporting each test by name.
    unittest.main(verbosity=2)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc