• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WassimTenachi / PhySO / #16

28 Jul 2025 07:07AM UTC coverage: 70.145% (-10.8%) from 80.984%
#16

push

coveralls-python

WassimTenachi
fix

5963 of 8501 relevant lines covered (70.14%)

0.7 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

42.64
/physo/task/args_handler.py
1
import numpy as np
1✔
2
import torch
1✔
3
import warnings
1✔
4

5
# Internal imports
6
import physo.learn.monitoring as monitoring
1✔
7
import physo.physym.dataset as Dataset
1✔
8
import physo
1✔
9

# DEFAULT MONITORING CONFIG TO USE
# PEP 8 (E731): named callables are defined with `def`, not assigned lambdas.
def get_default_run_logger():
    """Return a fresh default RunLogger that saves its log to 'SR.log'."""
    return monitoring.RunLogger(save_path='SR.log',
                                do_save=True)

def get_default_run_visualiser():
    """Return a fresh default RunVisualiser refreshing every epoch and saving curves to 'SR_curves.png'."""
    return monitoring.RunVisualiser(epoch_refresh_rate=1,
                                    save_path='SR_curves.png',
                                    do_show=False,
                                    do_prints=True,
                                    do_save=True, )

# DEFAULT ALLOWED OPERATIONS
default_op_names = ["mul", "add", "sub", "div", "inv", "n2", "sqrt", "neg", "exp", "log", "sin", "cos"]
# Default early-stopping patience (epochs); not used in this block — presumably consumed by callers. TODO confirm.
default_stop_after_n_epochs = 10
24

25

26
def check_args_and_build_run_config(multi_X, multi_y, multi_y_weights,
1✔
27
            # X
28
            X_names,
29
            X_units,
30
            # y
31
            y_name ,
32
            y_units,
33
            # Fixed constants
34
            fixed_consts,
35
            fixed_consts_units,
36
            # Class free constants
37
            class_free_consts_names   ,
38
            class_free_consts_units   ,
39
            class_free_consts_init_val,
40
            # Spe Free constants
41
            spe_free_consts_names   ,
42
            spe_free_consts_units   ,
43
            spe_free_consts_init_val,
44
            # Operations to use
45
            op_names,
46
            use_protected_ops,
47
            # Stopping
48
            epochs,
49
            # Candidate wrapper
50
            candidate_wrapper,
51
            # Default run config to use
52
            run_config,
53
            # Default run monitoring
54
            get_run_logger,
55
            get_run_visualiser,
56
            # Parallel mode
57
            parallel_mode,
58
            n_cpus,
59
            device,
60
    ):
61
    """
62
    Checks arguments of SR and ClassSR functions and builds run_config for physo.task.fit.
63
    """
64

65
    # ------------------------------- DATASETS -------------------------------
66

67
    # Data checking and conversion to torch if necessary is now handled by Dataset class which is called by Batch class.
68
    # We use it here to infer n_dim (this will also run most other assertions unrelated to the library which is unknown
69
    # here and extra time) and sending data to device.
70
    dataset = Dataset.Dataset(multi_X=multi_X, multi_y=multi_y, multi_y_weights=multi_y_weights)
1✔
71
    # Getting number of input variables
72
    n_dim   = dataset.n_dim
1✔
73
    # Getting number of realizations
74
    n_realizations = dataset.n_realizations
1✔
75
    # Sending data to device and using sent data
76
    dataset.to(device)
1✔
77
    multi_X         = dataset.multi_X
1✔
78
    multi_y         = dataset.multi_y
×
79
    multi_y_weights = dataset.multi_y_weights
×
80

81
    # ------------------------------- LIBRARY ARGS -------------------------------
82

83
    # -- X_names --
84
    # Handling input variables names
85
    if X_names is None:
×
86
        # If None use x00, x01... names
87
        X_names = ["x%s"%(str(i).zfill(2)) for i in range(n_dim)]
1✔
88
    X_names = np.array(X_names)
1✔
89
    assert X_names.dtype.char == "U", "Input variables names should be strings."
×
90
    assert X_names.shape == (n_dim,), "There should be one input variable name per dimension in X."
×
91

92
    # -- X_units --
93
    # Handling input variables units
94
    if X_units is None:
×
95
        warnings.warn("No units given for input variables, assuming dimensionless units.")
1✔
96
        X_units = [[0,0,0] for _ in range(n_dim)]
1✔
97
    X_units = np.array(X_units).astype(float)
×
98
    assert X_units.shape[0] == n_dim, "There should be one input variable units per dimension in X."
×
99

100
    # --- y_name ---
101
    if y_name is None:
1✔
102
        y_name = "y"
×
103
    y_name = str(y_name)
×
104

105
    # --- y_units ---
106
    if y_units is None:
×
107
        warnings.warn("No units given for root variable, assuming dimensionless units.")
1✔
108
        y_units = [0,0,0]
1✔
109
    y_units = np.array(y_units).astype(float)
×
110
    assert len(y_units.shape) == 1, "y_units must be a 1D units vector"
×
111

112
    # --- n_fixed_consts ---
113
    if fixed_consts is not None:
×
114
        n_fixed_consts = len(fixed_consts)
×
115
    else:
116
        n_fixed_consts = 0
×
117
        fixed_consts = []
×
118
        warnings.warn("No information about fixed constants, not using any.")
×
119

120
    # --- fixed_consts_names ---
121
    # Rounding name to avoid using too long names (eg. for np.pi)
122
    fixed_consts_names = np.array([str(round(c, 4)) for c in fixed_consts])
×
123
    fixed_consts       = np.array(fixed_consts).astype(float)
×
124

125
    # --- fixed_consts_units ---
126
    if fixed_consts_units is None:
×
127
        warnings.warn("No units given for fixed constants, assuming dimensionless units.")
1✔
128
        fixed_consts_units = [[0,0,0] for _ in range(n_fixed_consts)]
1✔
129
    fixed_consts_units = np.array(fixed_consts_units).astype(float)
×
130
    assert fixed_consts_units.shape[0] == n_fixed_consts, "There should be one fixed constant units vector per fixed constant in fixed_consts_names"
×
131

132
    # --- n_class_free_consts ---
133
    if class_free_consts_names is not None:
×
134
        n_class_free_consts = len(class_free_consts_names)
×
135
    elif class_free_consts_units is not None:
×
136
        n_class_free_consts = len(class_free_consts_units)
×
137
    else:
138
        n_class_free_consts = 0
×
139
        warnings.warn("No information about class free constants, not using any.")
×
140

141
    # --- class_free_consts_names ---
142
    if class_free_consts_names is None:
×
143
        # If None use c00, c01... names
144
        class_free_consts_names = ["c%s"%(str(i).zfill(2)) for i in range(n_class_free_consts)]
1✔
145
    # Convert to strings (this helps pass str assert in case array is empty)
146
    class_free_consts_names = np.array(class_free_consts_names).astype(str)
1✔
147
    assert class_free_consts_names.dtype.char == "U", "class_free_consts_names should be strings."
×
148
    assert class_free_consts_names.shape == (n_class_free_consts,), \
×
149
        "There should be one class free constant name per units in class_free_consts_units"
150

151
    # --- class_free_consts_units ---
152
    if class_free_consts_units is None:
×
153
        if n_class_free_consts > 0:
×
154
            warnings.warn("No units given for class free constants, assuming dimensionless units.")
1✔
155
        class_free_consts_units = [[0,0,0] for _ in range(n_class_free_consts)]
1✔
156
    class_free_consts_units = np.array(class_free_consts_units).astype(float)
×
157
    assert class_free_consts_units.shape[0] == n_class_free_consts, \
×
158
        "There should be one class free constant units vector per free constant in class_free_consts_names"
159

160
    # --- class_free_consts_init_val ---
161
    if class_free_consts_init_val is None:
1✔
162
        class_free_consts_init_val = np.ones(n_class_free_consts)
1✔
163
    class_free_consts_init_val = np.array(class_free_consts_init_val).astype(float)
×
164
    assert class_free_consts_init_val.shape[0] == n_class_free_consts, \
×
165
        "There should be one class free constant initial value per free constant in class_free_consts_names"
166

167
    # --- n_spe_free_consts ---
168
    if spe_free_consts_names is not None:
1✔
169
        n_spe_free_consts = len(spe_free_consts_names)
×
170
    elif spe_free_consts_units is not None:
×
171
        n_spe_free_consts = len(spe_free_consts_units)
1✔
172
    else:
173
        n_spe_free_consts = 0
1✔
174
        # Only warning if there are multiple realizations
175
        if n_realizations > 1:
×
176
            warnings.warn("No information about spe free constants, not using any.")
×
177

178
    # --- spe_free_consts_names ---
179
    if spe_free_consts_names is None:
1✔
180
        # If None use c00, c01... names
181
        spe_free_consts_names = ["k%s"%(str(i).zfill(2)) for i in range(n_spe_free_consts)]
1✔
182
    # Convert to strings (this helps pass str assert in case array is empty)
183
    spe_free_consts_names = np.array(spe_free_consts_names).astype(str)
1✔
184
    assert spe_free_consts_names.dtype.char == "U", "spe_free_consts_names should be strings."
×
185
    assert spe_free_consts_names.shape == (n_spe_free_consts,), \
×
186
        "There should be one spe free constant name per units in spe_free_consts_units"
187

188
    # --- spe_free_consts_units ---
189
    if spe_free_consts_units is None:
×
190
        if n_spe_free_consts > 0:
1✔
191
            warnings.warn("No units given for spe free constants, assuming dimensionless units.")
1✔
192
        spe_free_consts_units = [[0,0,0] for _ in range(n_spe_free_consts)]
1✔
193
    spe_free_consts_units = np.array(spe_free_consts_units).astype(float)
×
194
    assert spe_free_consts_units.shape[0] == n_spe_free_consts, \
×
195
        "There should be one spe free constant units vector per free constant in spe_free_consts_names"
196

197
    # --- spe_free_consts_init_val ---
198
    if spe_free_consts_init_val is None:
×
199
        spe_free_consts_init_val = np.ones(n_spe_free_consts)
1✔
200
    # Do not convert to array as user may use a mix of single floats and (n_realizations,) arrays
201
    assert len(spe_free_consts_init_val) == n_spe_free_consts, \
×
202
        "There should be one spe free constant initial value per free constant in spe_free_consts_names"
203

204
    # --- op_names ---
205
    if op_names is None:
×
206
        op_names = default_op_names
×
207

208
    # ------------------------------- WRAPPING LIBRARY -------------------------------
209

210
    # Converting fixed constants to torch and sending to device
211
    fixed_consts = torch.tensor(fixed_consts).to(device)
×
212

213
    # Embedding wrapping
214
    args_make_tokens = {
1✔
215
                    # operations
216
                    "op_names"             : op_names,
217
                    "use_protected_ops"    : use_protected_ops,
218
                    # input variables
219
                    "input_var_ids"        : {X_names[i]: i          for i in range(n_dim)},
220
                    "input_var_units"      : {X_names[i]: X_units[i] for i in range(n_dim)},
221
                    # constants
222
                    "constants"            : {fixed_consts_names[i] : fixed_consts[i]       for i in range(n_fixed_consts)},
223
                    "constants_units"      : {fixed_consts_names[i] : fixed_consts_units[i] for i in range(n_fixed_consts)},
224
                    # class_free_constants
225
                    "class_free_constants"          : {class_free_consts_names[i]                                 for i in range(n_class_free_consts)},
226
                    "class_free_constants_units"    : {class_free_consts_names[i] : class_free_consts_units   [i] for i in range(n_class_free_consts)},
227
                    "class_free_constants_init_val" : {class_free_consts_names[i] : class_free_consts_init_val[i] for i in range(n_class_free_consts)},
228
                    # spe_free_constants
229
                    "spe_free_constants"          : {spe_free_consts_names[i]                               for i in range(n_spe_free_consts)},
230
                    "spe_free_constants_units"    : {spe_free_consts_names[i] : spe_free_consts_units   [i] for i in range(n_spe_free_consts)},
231
                    "spe_free_constants_init_val" : {spe_free_consts_names[i] : spe_free_consts_init_val[i] for i in range(n_spe_free_consts)},
232
                        }
233

234
    library_config = {"args_make_tokens"  : args_make_tokens,
1✔
235
                      "superparent_units" : y_units,
236
                      "superparent_name"  : y_name,
237
                    }
238

239
    # Updating config
240
    run_config.update({
×
241
        "library_config" : library_config,
242
    })
243

244
    # ------------------------------- MONITORING -------------------------------
245
    run_logger     = get_run_logger()
×
246
    run_visualiser = get_run_visualiser()
×
247

248
    # Updating config
249
    run_config.update({
1✔
250
        "run_logger"           : run_logger,
251
        "run_visualiser"       : run_visualiser,
252
    })
253

254
    # ------------------------------- PARALLEL CONFIG AND BUILDING RewardsComputer -------------------------------
255

256
    # Update reward_config
257
    run_config["reward_config"].update({
1✔
258
        # with parallel config
259
        "parallel_mode" : parallel_mode,
260
        "n_cpus"        : n_cpus,
261
        })
262
    #  Updating reward config for parallel mode
263
    reward_config = run_config["reward_config"]
×
264
    run_config["learning_config"]["rewards_computer"] = physo.physym.reward.make_RewardsComputer(**reward_config)
×
265

266
    # ------------------------------- EPOCHS -------------------------------
267

268
    # Number of epochs (using epochs args in run_config if it was given).
269
    if epochs is not None:
×
270
        run_config["learning_config"]["n_epochs"] = epochs
×
271

272
    # ------------------------------- MAX_TIME_STEP ASSERTIONS -------------------------------
273

274
    # Asserting that max_time_step is >= HardLengthPrior's max_length
275
    for prior_config in run_config["priors_config"]:
1✔
276
        if prior_config[0] == "HardLengthPrior":
×
277
            assert run_config["learning_config"]["max_time_step"] >= prior_config[1]["max_length"], \
×
278
                "max_time_step should be greater than or equal to HardLengthPrior's max_length."
279

280
    # ------------------------------- LEARNING HYPERPARAMS ASSERTIONS -------------------------------
281
    # risk_factor should be a float >= 0 and <= 1
282
    risk_factor = run_config["learning_config"]["risk_factor"]
1✔
283
    try:
×
284
        risk_factor = float(risk_factor)
×
285
    except:
1✔
286
        raise ValueError("risk_factor should be castable to a float.")
1✔
287
    assert isinstance(risk_factor, float), "risk_factor should be a float."
×
288
    assert 0 <= risk_factor <= 1, "risk_factor should be >= 0 and <= 1."
×
289

290
    # gamma_decay should be a float
291
    gamma_decay = run_config["learning_config"]["gamma_decay"]
1✔
292
    try:
×
293
        gamma_decay = float(gamma_decay)
×
294
    except:
1✔
295
        raise ValueError("gamma_decay should be castable to a float.")
×
296
    assert isinstance(gamma_decay, float), "gamma_decay should be a float."
×
297

298
    # entropy_weight should be a float
299
    entropy_weight = run_config["learning_config"]["entropy_weight"]
1✔
300
    try:
×
301
        entropy_weight = float(entropy_weight)
×
302
    except:
1✔
303
        raise ValueError("entropy_weight should be castable to a float.")
×
304
    assert isinstance(entropy_weight, float), "entropy_weight should be a float."
×
305

306
    # ------------------------------- CANDIDATE_WRAPPER -------------------------------
307
    # candidate_wrapper should be callable or None
308
    assert candidate_wrapper is None or callable(candidate_wrapper), "candidate_wrapper should be callable or None."
1✔
309

310
    # ------------------------------- RETURN -------------------------------
311
    # Returning
312
    handled_args = {
1✔
313
        "multi_X"         : multi_X,
314
        "multi_y"         : multi_y,
315
        "multi_y_weights" : multi_y_weights,
316
        "run_config"      : run_config,
317
    }
318
    return handled_args
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc