• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

moeyensj / thor / 6733525095

02 Nov 2023 01:52PM UTC coverage: 40.544% (+0.9%) from 39.595%
6733525095

push

github

web-flow
Merge pull request #123 from moeyensj/v2.0-link-aims-sample

Link AIMS sample

301 of 301 new or added lines in 12 files covered. (100.0%)

1878 of 4632 relevant lines covered (40.54%)

0.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

7.3
/thor/orbits/od.py
1
import os

# Cap the thread counts of the native math libraries (OpenMP, OpenBLAS, MKL,
# Accelerate/vecLib, numexpr) at one thread per process. These variables are
# read when the libraries are first loaded, which is presumably why they are
# set here, before numpy/scipy/pandas are imported below — TODO confirm that
# no consumer imports numpy before this module.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import concurrent.futures as cf
import copy
import logging
import multiprocessing as mp
import time
from functools import partial

import numpy as np
import pandas as pd
from astropy import units as u
from astropy.time import Time
from scipy.linalg import solve

from ..backend import PYOORB
from ..utils import (
    _checkParallel,
    _initWorker,
    calcChunkSize,
    sortLinkages,
    yieldChunks,
)
from .orbits import Orbits
from .residuals import calcResiduals

# Module-level logger per the standard `logging` convention.
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ["od_worker", "od", "differentialCorrection"]
36

37

38
def od_worker(
    orbits_list,
    observations_list,
    rchi2_threshold=100,
    min_obs=5,
    min_arc_length=1.0,
    contamination_percentage=20,
    delta=1e-6,
    max_iter=20,
    method="central",
    fit_epoch=False,
    test_orbit=None,
    backend="PYOORB",
    backend_kwargs=None,
):
    """
    Differentially correct a chunk of orbits, one orbit at a time.

    Each element of `orbits_list` is paired with the corresponding element of
    `observations_list` and passed to `od`; the per-orbit results are
    concatenated into two DataFrames.

    Parameters
    ----------
    orbits_list : list
        Single-orbit `Orbits` objects to correct.
    observations_list : list of `~pandas.DataFrame`
        Observations for each orbit; each frame must belong to exactly the
        matching orbit, be sorted by time, and contain no duplicate times.
    backend_kwargs : dict, optional
        Keyword arguments used to instantiate the ephemeris backend inside
        `od`. Defaults to an empty dict (None is used as the sentinel to
        avoid a shared mutable default argument).

    All remaining keyword arguments are passed through to `od` unchanged.

    Returns
    -------
    od_orbits, od_orbit_members : `~pandas.DataFrame`, `~pandas.DataFrame`
        Concatenated fitted orbits and per-observation members. Empty
        DataFrames are returned when the input lists are empty.

    Raises
    ------
    ValueError
        If any observations frame does not correspond to its orbit, is not
        time-sorted, or contains duplicate observation times.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    od_orbits_dfs = []
    od_orbit_members_dfs = []
    for orbit, observations in zip(orbits_list, observations_list):
        # Validate the pairing invariants before fitting. Catch only
        # AssertionError: a bare except here would also swallow
        # KeyboardInterrupt/SystemExit and unrelated bugs.
        try:
            assert orbit.ids[0] == observations["orbit_id"].unique()[0]
            assert np.all(
                sorted(observations["mjd_utc"].values) == observations["mjd_utc"].values
            )
            assert len(np.unique(observations["mjd_utc"].values)) == len(
                observations["mjd_utc"].values
            )
        except AssertionError as e:
            err = (
                "Invalid observations and orbit have been passed to the OD code.\n"
                "Orbit ID: {}".format(orbit.ids[0])
            )
            # Chain the cause so the failing assertion stays visible.
            raise ValueError(err) from e

        time_start = time.time()
        logger.debug(f"Differentially correcting orbit {orbit.ids[0]}...")
        od_orbit, od_orbit_members = od(
            orbit,
            observations,
            rchi2_threshold=rchi2_threshold,
            min_obs=min_obs,
            min_arc_length=min_arc_length,
            contamination_percentage=contamination_percentage,
            delta=delta,
            max_iter=max_iter,
            method=method,
            fit_epoch=fit_epoch,
            test_orbit=test_orbit,
            backend=backend,
            backend_kwargs=backend_kwargs,
        )
        time_end = time.time()
        duration = time_end - time_start
        logger.debug(f"OD for orbit {orbit.ids[0]} completed in {duration:.3f}s.")

        od_orbits_dfs.append(od_orbit)
        od_orbit_members_dfs.append(od_orbit_members)

    # pd.concat raises ValueError on an empty list; an empty chunk should
    # simply produce empty results instead of crashing the worker.
    if not od_orbits_dfs:
        return pd.DataFrame(), pd.DataFrame()

    od_orbits = pd.concat(od_orbits_dfs, ignore_index=True)
    od_orbit_members = pd.concat(od_orbit_members_dfs, ignore_index=True)
    return od_orbits, od_orbit_members
98

99

100
def od(
    orbit,
    observations,
    rchi2_threshold=100,
    min_obs=5,
    min_arc_length=1.0,
    contamination_percentage=0.0,
    delta=1e-6,
    max_iter=20,
    method="central",
    fit_epoch=False,
    test_orbit=None,
    backend="PYOORB",
    backend_kwargs={},
):
    """
    Differentially correct a single orbit against its observations via
    finite/central differencing of the ephemerides.

    The state (and optionally the epoch) is iteratively adjusted by solving
    the weighted normal equations built from numerically-differenced
    residual partials, until the reduced chi2 drops below `rchi2_threshold`
    or the iteration budget is exhausted. If no improvement is found,
    observations may be dropped as outliers (up to the limit implied by
    `contamination_percentage`) and the fit restarted.

    Parameters
    ----------
    orbit : `Orbits`
        Single orbit to correct (cartesian state + epoch).
    observations : `~pandas.DataFrame`
        Observations with at least the columns 'obs_id', 'mjd_utc',
        'observatory_code', 'RA_deg', 'Dec_deg', 'RA_sigma_deg',
        'Dec_sigma_deg'.
    rchi2_threshold : float, optional
        Reduced chi2 at or below which the fit is accepted and iteration stops.
    min_obs : int, optional
        Minimum number of observations required to attempt a fit.
    min_arc_length : float, optional
        Minimum arc length (days) the fitted observations must span.
    contamination_percentage : float, optional
        Percentage of observations allowed to be flagged as outliers.
    delta : float, optional
        Relative finite-difference step applied to each state component.
    max_iter : int, optional
        Maximum iterations per outlier attempt; total budget is
        max_iter * (num_outliers + 1).
    method : {'central', 'finite'}, optional
        Central (two-sided) or forward (one-sided) differencing.
    fit_epoch : bool, optional
        If True (and more than 6 observations), also fit the epoch
        (7 parameters instead of 6).
    test_orbit : optional
        NOTE(review): accepted but never referenced in this function —
        presumably kept for signature compatibility with callers.
    backend : str, optional
        Ephemeris backend name; only 'PYOORB' is supported.
    backend_kwargs : dict, optional
        Keyword arguments used to instantiate the backend.
        NOTE(review): mutable default argument — benign as long as it is
        never mutated, but a None sentinel would be safer.

    Returns
    -------
    od_orbit, od_orbit_members : `~pandas.DataFrame`, `~pandas.DataFrame`
        Fitted orbit summary and per-observation residuals/outlier flags.
        Both are empty (columns only) when no acceptable solution is found.

    Raises
    ------
    ValueError
        If `backend` is not 'PYOORB' or `method` is not recognized.
    """
    if backend == "PYOORB":
        backend = PYOORB(**backend_kwargs)
    else:
        err = "backend should be 'PYOORB'"
        raise ValueError(err)

    if method not in ["central", "finite"]:
        err = "method should be one of 'central' or 'finite'."
        raise ValueError(err)

    # The two measured quantities per observation (degrees).
    observables = ["RA_deg", "Dec_deg"]

    obs_ids_all = observations["obs_id"].values
    coords = observations[observables].values
    coords_sigma = observations[["RA_sigma_deg", "Dec_sigma_deg"]].values

    # Group observation times per observatory; the backend generates
    # ephemerides per observer.
    observers = {}
    for observatory_code in observations["observatory_code"].unique():
        observatory_mask = observations["observatory_code"].isin([observatory_code])
        observers[observatory_code] = Time(
            observations[observatory_mask]["mjd_utc"].unique(),
            format="mjd",
            scale="utc",
        )

    # FLAG: can we stop iterating to find a solution?
    converged = False
    # FLAG: has an orbit with reduced chi2 less than the reduced chi2 of the input orbit been found?
    improved = False
    # FLAG: has an orbit with reduced chi2 less than the rchi2_threshold been found?
    solution_found = False
    # FLAG: is this orbit processable? Does it have at least min_obs observations?
    processable = True
    # FLAG: is this the first iteration with a successful differential correction (this solution is always stored as the solution
    # which needs to be improved.. input orbits may not have been previously corrected with current set of observations so this
    # forces at least one succesful iteration to have been taken.)
    first_solution = True

    num_obs = len(observations)
    if num_obs < min_obs:
        logger.debug("This orbit has fewer than {} observations.".format(min_obs))
        processable = False
    else:
        # Outlier budget: bounded below by 0 and above by the number of
        # observations that can be removed while keeping min_obs.
        num_outliers = int(num_obs * contamination_percentage / 100.0)
        num_outliers = np.maximum(np.minimum(num_obs - min_obs, num_outliers), 0)
        logger.debug("Maximum number of outliers allowed: {}".format(num_outliers))
        outliers_tried = 0

        # Calculate chi2 for residuals on the given observations
        # for the current orbit, the goal is for the orbit to improve
        # such that the chi2 improves
        orbit_prev_ = copy.deepcopy(orbit)

        ephemeris_prev_ = backend._generateEphemeris(orbit_prev_, observers)
        residuals_prev_, stats_prev_ = calcResiduals(
            coords,
            ephemeris_prev_[observables].values,
            sigmas_actual=coords_sigma,
            include_probabilistic=False,
        )
        num_obs_ = len(observations)
        chi2_prev_ = stats_prev_[0]
        chi2_total_prev_ = np.sum(chi2_prev_)
        # Reduced chi2: 2 measurements per observation minus 6 fitted
        # parameters. NOTE(review): hard-codes 6 even though num_params may
        # later be 7 when fit_epoch is used — confirm intent.
        rchi2_prev_ = np.sum(chi2_prev_) / (2 * num_obs - 6)

        # Save the initial orbit in case we need to reset
        # to it later
        orbit_prev = orbit_prev_
        ephemeris_prev = ephemeris_prev_
        residuals_prev = residuals_prev_
        num_obs = num_obs_
        chi2_prev = chi2_prev_
        chi2_total_prev = chi2_total_prev_
        rchi2_prev = rchi2_prev_

        # Mask of observations currently included in the fit (True = keep).
        ids_mask = np.array([True for i in range(num_obs)])
        times_all = ephemeris_prev["mjd_utc"].values
        obs_id_outlier = []
        delta_prev = delta
        iterations = 0

        # Factors by which the finite-difference step is adapted when a
        # solve fails or a solution is rejected.
        DELTA_INCREASE_FACTOR = 5
        DELTA_DECREASE_FACTOR = 100

        max_iter_i = max_iter
        max_iter_outliers = max_iter * (num_outliers + 1)

    while not converged and processable:
        iterations += 1

        # We add 1 here because the iterations are counted as they start, not
        # as they finish. There are a lot of 'continue' statements down below that
        # will exit the current iteration if something fails which makes accounting for
        # iterations at the start of an iteration easier.
        if iterations == max_iter_outliers + 1:
            logger.debug(f"Maximum number of iterations completed.")
            break
        if iterations == max_iter_i + 1 and (
            solution_found or (num_outliers == outliers_tried)
        ):
            logger.debug(f"Maximum number of iterations completed.")
            break
        logger.debug(f"Starting iteration number: {iterations}/{max_iter_outliers}")

        # Make sure delta is well bounded
        if delta_prev < 1e-14:
            delta_prev *= DELTA_INCREASE_FACTOR
            logger.debug("Delta is too small, increasing.")
        elif delta_prev > 1e-2:
            delta_prev /= DELTA_DECREASE_FACTOR
            logger.debug("Delta is too large, decreasing.")
        else:
            pass

        delta_iter = delta_prev
        logger.debug(f"Starting iteration {iterations} with delta {delta_iter}.")

        # Initialize the partials derivatives matrix
        # (7 parameters only when the epoch is also fitted and the system
        # remains over-determined).
        if num_obs > 6 and fit_epoch:
            num_params = 7
        else:
            num_params = 6

        A = np.zeros((coords.shape[1], num_params, num_obs))
        ATWA = np.zeros((num_params, num_params, num_obs))
        ATWb = np.zeros((num_params, 1, num_obs))

        # Generate ephemeris with current nominal orbit
        ephemeris_nom = backend._generateEphemeris(orbit_prev, observers)
        coords_nom = ephemeris_nom[observables].values

        # Modify each component of the state by a small delta
        d = np.zeros((1, 7))
        for i in range(num_params):

            # zero the delta vector
            d *= 0.0

            # x, y, z [au]: 0, 1, 2
            # vx, vy, vz [au per day]: 3, 4, 5
            # time [days] : 6
            if i < 3:
                delta_iter = delta_prev

                d[0, i] = orbit_prev.cartesian[0, i] * delta_iter
            elif i < 6:
                delta_iter = delta_prev

                d[0, i] = orbit_prev.cartesian[0, i] * delta_iter
            else:
                # Epoch perturbation uses a much smaller absolute step.
                delta_iter = delta_prev / 100000

                d[0, i] = delta_iter

            # Modify component i of the orbit by a small delta
            orbit_iter_p = Orbits(
                orbit_prev.cartesian + d[0, :6],
                orbit_prev.epochs + d[0, 6] * u.day,
                orbit_type="cartesian",
            )

            # Calculate the modified ephemerides
            ephemeris_mod_p = backend._generateEphemeris(orbit_iter_p, observers)
            coords_mod_p = ephemeris_mod_p[observables].values

            delta_denom = d[0, i]
            if method == "central":

                # Modify component i of the orbit by a small delta
                orbit_iter_n = Orbits(
                    orbit_prev.cartesian - d[0, :6],
                    orbit_prev.epochs - d[0, 6] * u.day,
                    orbit_type="cartesian",
                )

                # Calculate the modified ephemerides
                ephemeris_mod_n = backend._generateEphemeris(orbit_iter_n, observers)
                coords_mod_n = ephemeris_mod_n[observables].values

                # Central difference divides by 2*delta.
                delta_denom *= 2

            else:
                # Forward difference: compare against the nominal ephemeris.
                coords_mod_n = coords_nom

            residuals_mod, _ = calcResiduals(
                coords_mod_p,
                coords_mod_n,
                sigmas_actual=None,
                include_probabilistic=False,
            )

            # Column i of the design matrix: d(residual)/d(parameter i).
            for n in range(num_obs):
                try:
                    A[:, i : i + 1, n] = (
                        residuals_mod[ids_mask][n : n + 1].T / delta_denom
                    )
                except RuntimeError:
                    print(orbit_prev.ids)

        # Accumulate the weighted normal equations per observation
        # (W = inverse measurement variances).
        for n in range(num_obs):
            W = np.diag(1 / coords_sigma[n] ** 2)
            ATWA[:, :, n] = A[:, :, n].T @ W @ A[:, :, n]
            ATWb[:, :, n] = A[:, :, n].T @ W @ residuals_prev[n : n + 1].T

        ATWA = np.sum(ATWA, axis=2)
        ATWb = np.sum(ATWb, axis=2)

        ATWA_condition = np.linalg.cond(ATWA)
        ATWb_condition = np.linalg.cond(ATWb)

        # Reject ill-conditioned or NaN systems and adapt the step size.
        if (ATWA_condition > 1e15) or (ATWb_condition > 1e15):
            delta_prev /= DELTA_DECREASE_FACTOR
            continue
        if np.any(np.isnan(ATWA)) or np.any(np.isnan(ATWb)):
            delta_prev *= DELTA_INCREASE_FACTOR
            continue
        else:
            try:
                delta_state = solve(
                    ATWA,
                    ATWb,
                ).T
                covariance_matrix = np.linalg.inv(ATWA)
                variances = np.diag(covariance_matrix)
                if np.any(variances <= 0) or np.any(np.isnan(variances)):
                    delta_prev /= DELTA_DECREASE_FACTOR
                    logger.debug(
                        "Variances are negative, 0.0, or NaN. Discarding solution."
                    )
                    continue

                # Reject fits whose position uncertainty exceeds the
                # heliocentric distance itself.
                r_variances = variances[0:3]
                r_sigma = np.sqrt(np.sum(r_variances))
                r = np.linalg.norm(orbit_prev.cartesian[0, :3])
                if (r_sigma / r) > 1:
                    delta_prev /= DELTA_DECREASE_FACTOR
                    logger.debug(
                        "Covariance matrix is largely unconstrained. Discarding solution."
                    )
                    continue

                if np.any(np.isnan(covariance_matrix)):
                    delta_prev *= DELTA_INCREASE_FACTOR
                    logger.debug(
                        "Covariance matrix contains NaNs. Discarding solution."
                    )
                    continue

            except np.linalg.LinAlgError:
                delta_prev *= DELTA_INCREASE_FACTOR
                continue

        # NOTE(review): when num_params == 6, delta_state has shape (1, 6),
        # so d_state[:3] below selects the entire row (positions AND
        # velocities) rather than the position components — compare with the
        # 7-parameter branch which indexes delta_state[0, :6]. Confirm intent.
        if num_params == 6:
            d_state = delta_state
            d_time = 0
        else:
            d_state = delta_state[0, :6]
            d_time = delta_state[0, 6]

        if np.linalg.norm(d_state[:3]) < 1e-16:
            logger.debug("Change in state is less than 1e-16 au, discarding solution.")
            # NOTE(review): multiplies by DELTA_DECREASE_FACTOR (100) here,
            # unlike the other "too small" paths which use
            # DELTA_INCREASE_FACTOR (5) — possibly a factor mix-up; confirm.
            delta_prev *= DELTA_DECREASE_FACTOR
            continue
        if np.linalg.norm(d_state[:3]) > 100:
            delta_prev /= DELTA_DECREASE_FACTOR
            logger.debug("Change in state is more than 100 au, discarding solution.")
            continue

        orbit_iter = Orbits(
            orbit_prev.cartesian + d_state,
            orbit_prev.epochs + d_time * u.day,
            orbit_type="cartesian",
            ids=orbit_prev.ids,
            covariance=[covariance_matrix],
        )
        # Sanity check on the velocity magnitude (au/day).
        if np.linalg.norm(orbit_iter.cartesian[0, 3:]) > 1:
            delta_prev *= DELTA_INCREASE_FACTOR
            logger.debug("Orbit is moving extraordinarily fast, discarding solution.")
            continue

        # Generate ephemeris with current nominal orbit
        ephemeris_iter = backend._generateEphemeris(orbit_iter, observers)
        coords_iter = ephemeris_iter[observables].values

        residuals, stats = calcResiduals(
            coords, coords_iter, sigmas_actual=coords_sigma, include_probabilistic=False
        )
        chi2_iter = stats[0]
        chi2_total_iter = np.sum(chi2_iter[ids_mask])
        rchi2_iter = chi2_total_iter / (2 * num_obs - num_params)
        arc_length = times_all[ids_mask].max() - times_all[ids_mask].min()

        # If the new orbit has lower residuals than the previous,
        # accept the orbit and continue iterating until max iterations has been
        # reached. Once max iterations have been reached and the orbit still has not converged
        # to an acceptable solution, try removing an observation as an outlier and iterate again.
        if (
            (rchi2_iter < rchi2_prev) or first_solution
        ) and arc_length >= min_arc_length:

            if first_solution:
                logger.debug(
                    "Storing first successful differential correction iteration for these observations."
                )
                first_solution = False
            else:
                logger.debug("Potential improvement orbit has been found.")
            orbit_prev = orbit_iter
            residuals_prev = residuals
            chi2_prev = chi2_iter
            chi2_total_prev = chi2_total_iter
            rchi2_prev = rchi2_iter

            if rchi2_prev <= rchi2_prev_:
                improved = True

            if rchi2_prev <= rchi2_threshold:
                logger.debug("Potential solution orbit has been found.")
                solution_found = True
                converged = True

        elif (
            num_outliers > 0
            and outliers_tried <= num_outliers
            and iterations > max_iter_i
            and not solution_found
        ):

            logger.debug("Attempting to identify possible outliers.")
            # Previous fits have failed, lets reset the current best fit orbit back to its original
            # state and re-run fitting, this time removing outliers
            orbit_prev = orbit_prev_
            ephemeris_prev = ephemeris_prev_
            residuals_prev = residuals_prev_
            num_obs = num_obs_
            chi2_prev = chi2_prev_
            chi2_total_prev = chi2_total_prev_
            rchi2_prev = rchi2_prev_
            delta_prev = delta

            # Select i highest observations that contribute to
            # chi2 (and thereby the residuals)
            remove = chi2_prev.argsort()[-(outliers_tried + 1) :]

            # Grab the obs_ids for these outliers
            obs_id_outlier = obs_ids_all[remove]
            num_obs = len(observations) - len(obs_id_outlier)
            ids_mask = np.isin(obs_ids_all, obs_id_outlier, invert=True)
            arc_length = times_all[ids_mask].max() - times_all[ids_mask].min()

            logger.debug("Possible outlier(s): {}".format(obs_id_outlier))
            outliers_tried += 1
            if arc_length >= min_arc_length:
                max_iter_i = max_iter * (outliers_tried + 1)
            else:
                logger.debug(
                    "Removing the outlier will cause the arc length to go below the minimum."
                )

        # If the new orbit does not have lower residuals, try changing
        # delta to see if we get an improvement
        else:
            # logger.debug("Orbit did not improve since previous iteration, decrease delta and continue.")
            # delta_prev /= DELTA_DECREASE_FACTOR
            pass

        logger.debug(
            "Current r-chi2: {}, Previous r-chi2: {}, Max Iterations: {}, Outliers Tried: {}".format(
                rchi2_iter, rchi2_prev, max_iter_i, outliers_tried
            )
        )

    logger.debug("Solution found: {}".format(solution_found))
    logger.debug("First solution: {}".format(first_solution))

    # No acceptable solution (or orbit not processable, or no successful
    # correction at all): return empty, correctly-labeled DataFrames.
    if not solution_found or not processable or first_solution:

        od_orbit = pd.DataFrame(
            columns=[
                "orbit_id",
                "mjd_tdb",
                "x",
                "y",
                "z",
                "vx",
                "vy",
                "vz",
                "covariance",
                "r",
                "r_sigma",
                "v",
                "v_sigma",
                "arc_length",
                "num_obs",
                "num_params",
                "num_iterations",
                "chi2",
                "rchi2",
                "improved",
            ]
        )

        od_orbit_members = pd.DataFrame(
            columns=[
                "orbit_id",
                "obs_id",
                "residual_ra_arcsec",
                "residual_dec_arcsec",
                "chi2",
                "outlier",
            ]
        )

    else:
        # Summarize the best-fit orbit and its uncertainties.
        variances = np.diag(orbit_prev.cartesian_covariance[0])
        r_variances = variances[0:3]
        v_variances = variances[3:6]

        obs_times = observations["mjd_utc"].values[ids_mask]
        od_orbit = orbit_prev.to_df(include_units=False)
        od_orbit["r"] = np.linalg.norm(orbit_prev.cartesian[0, :3])
        od_orbit["r_sigma"] = np.sqrt(np.sum(r_variances))
        od_orbit["v"] = np.linalg.norm(orbit_prev.cartesian[0, 3:])
        od_orbit["v_sigma"] = np.sqrt(np.sum(v_variances))
        od_orbit["arc_length"] = np.max(obs_times) - np.min(obs_times)
        od_orbit["num_obs"] = num_obs
        od_orbit["num_params"] = num_params
        od_orbit["num_iterations"] = iterations
        od_orbit["chi2"] = chi2_total_prev
        od_orbit["rchi2"] = rchi2_prev
        od_orbit["improved"] = improved

        # Residuals are converted from degrees to arcseconds (* 3600).
        od_orbit_members = pd.DataFrame(
            {
                "orbit_id": [orbit_prev.ids[0] for i in range(len(obs_ids_all))],
                "obs_id": obs_ids_all,
                "residual_ra_arcsec": residuals_prev[:, 0] * 3600,
                "residual_dec_arcsec": residuals_prev[:, 1] * 3600,
                "chi2": chi2_prev,
                "outlier": np.zeros(len(obs_ids_all), dtype=int),
            }
        )
        # Flag any observations that were removed as outliers.
        od_orbit_members.loc[
            od_orbit_members["obs_id"].isin(obs_id_outlier), "outlier"
        ] = 1

    return od_orbit, od_orbit_members
562

563

564
def differentialCorrection(
    orbits,
    orbit_members,
    observations,
    min_obs=5,
    min_arc_length=1.0,
    contamination_percentage=20,
    rchi2_threshold=100,
    delta=1e-8,
    max_iter=20,
    method="central",
    fit_epoch=False,
    test_orbit=None,
    backend="PYOORB",
    backend_kwargs=None,
    chunk_size=10,
    num_jobs=60,
    parallel_backend="cf",
):
    """
    Differentially correct (via finite/central differencing).

    Links each orbit to its observations, splits the work into chunks of
    single orbits, and dispatches them to `od_worker` either serially or via
    the requested parallel backend.

    Parameters
    ----------
    orbits : `~pandas.DataFrame`
        Orbits to correct; must contain an 'orbit_id' column.
    orbit_members : `~pandas.DataFrame`
        Linkage table mapping 'orbit_id' to 'obs_id'.
    observations : `~pandas.DataFrame`
        Observations table keyed on 'obs_id'.
    min_obs : int, optional
        Minimum number of observations required to fit an orbit.
    min_arc_length : float, optional
        Minimum arc length (days) the fitted observations must span.
    contamination_percentage : float, optional
        Percentage of observations allowed to be flagged as outliers.
    rchi2_threshold : float, optional
        Reduced chi2 at or below which a fit is accepted.
    delta : float, optional
        Relative finite-difference step used by `od`.
    max_iter : int, optional
        Maximum correction iterations per outlier attempt.
    method : {'central', 'finite'}, optional
        Differencing scheme passed through to `od`.
    fit_epoch : bool, optional
        If True, also fit the epoch (7 parameters instead of 6).
    test_orbit : optional
        Passed through to the workers unchanged.
    backend : str, optional
        Ephemeris backend name; only 'PYOORB' is supported by `od`.
    backend_kwargs : dict, optional
        Keyword arguments used to instantiate the backend. Defaults to an
        empty dict (None is the sentinel, avoiding a shared mutable default).
    chunk_size : int, optional
        Number of orbits to send to each job.
    num_jobs : int, optional
        Number of jobs to launch.
    parallel_backend : str, optional
        Which parallelization backend to use {'ray', 'mp', 'cf'}. Defaults to using Python's concurrent.futures
        module ('cf').

    Returns
    -------
    od_orbits, od_orbit_members : `~pandas.DataFrame`, `~pandas.DataFrame`
        Corrected orbits and per-observation residuals/outlier flags; empty
        (columns only) when `orbits` or `orbit_members` is empty.

    Raises
    ------
    ValueError
        If `parallel_backend` is not one of 'ray', 'mp', 'cf' (when more
        than one worker is requested).
    """
    # None sentinel instead of a mutable default argument.
    if backend_kwargs is None:
        backend_kwargs = {}

    logger.info("Running differential correction...")

    time_start = time.time()

    if len(orbits) > 0 and len(orbit_members) > 0:

        orbits_, orbit_members_ = sortLinkages(orbits, orbit_members, observations)

        start = time.time()
        logger.debug("Merging observations on linkage members...")
        linked_observations = orbit_members_[
            orbit_members_[["orbit_id", "obs_id"]]["orbit_id"].isin(
                orbits_["orbit_id"].values
            )
        ].merge(observations, on="obs_id", how="left")
        duration = time.time() - start
        logger.debug(f"Merging completed in {duration:.3f}s.")

        start = time.time()
        logger.debug("Grouping observations by orbit ID...")
        grouped_observations = linked_observations.groupby(by=["orbit_id"])
        logger.debug("Splitting grouped observations by orbit ID...")
        observations_split = [
            grouped_observations.get_group(g).reset_index(drop=True)
            for g in grouped_observations.groups
        ]
        duration = time.time() - start
        logger.debug(f"Grouping and splitting completed in {duration:.3f}s.")

        orbits_initial = Orbits.from_df(orbits_)
        # Split into single-orbit objects aligned with observations_split.
        orbits_split = orbits_initial.split(1)
        num_orbits = len(orbits)

        parallel, num_workers = _checkParallel(num_jobs, parallel_backend)
        if num_workers > 1:

            if parallel_backend == "ray":
                import ray

                if not ray.is_initialized():
                    ray.init(address="auto")

                od_worker_ray = ray.remote(od_worker)
                od_worker_ray = od_worker_ray.options(num_returns=2, num_cpus=1)

                # Send up to chunk_size orbits to each OD worker for processing
                chunk_size_ = calcChunkSize(
                    num_orbits, num_workers, chunk_size, min_chunk_size=1
                )
                logger.info(
                    f"Distributing linkages in chunks of {chunk_size_} to {num_workers} ray workers."
                )

                # Put the observations and orbits into ray's local object storage ("plasma")
                orbit_oids = []
                observation_oids = []
                for orbits_i, observations_i in zip(
                    yieldChunks(orbits_split, chunk_size_),
                    yieldChunks(observations_split, chunk_size_),
                ):
                    orbit_oids.append(ray.put(orbits_i))
                    observation_oids.append(ray.put(observations_i))

                od_orbits_oids = []
                od_orbit_members_oids = []
                for orbits_oid, observations_oid in zip(orbit_oids, observation_oids):

                    od_orbits_oid, od_orbit_members_oid = od_worker_ray.remote(
                        orbits_oid,
                        observations_oid,
                        rchi2_threshold=rchi2_threshold,
                        min_obs=min_obs,
                        min_arc_length=min_arc_length,
                        contamination_percentage=contamination_percentage,
                        delta=delta,
                        max_iter=max_iter,
                        method=method,
                        fit_epoch=fit_epoch,
                        test_orbit=test_orbit,
                        backend=backend,
                        backend_kwargs=backend_kwargs,
                    )
                    od_orbits_oids.append(od_orbits_oid)
                    od_orbit_members_oids.append(od_orbit_members_oid)

                od_orbits_dfs = ray.get(od_orbits_oids)
                od_orbit_members_dfs = ray.get(od_orbit_members_oids)

            elif parallel_backend == "mp":

                chunk_size_ = calcChunkSize(
                    num_orbits, num_workers, chunk_size, min_chunk_size=1
                )
                logger.info(
                    f"Distributing linkages in chunks of {chunk_size_} to {num_workers} workers."
                )

                p = mp.Pool(
                    processes=num_workers,
                    initializer=_initWorker,
                )
                results = p.starmap(
                    partial(
                        od_worker,
                        rchi2_threshold=rchi2_threshold,
                        min_obs=min_obs,
                        min_arc_length=min_arc_length,
                        contamination_percentage=contamination_percentage,
                        delta=delta,
                        max_iter=max_iter,
                        method=method,
                        fit_epoch=fit_epoch,
                        test_orbit=test_orbit,
                        backend=backend,
                        backend_kwargs=backend_kwargs,
                    ),
                    zip(
                        yieldChunks(orbits_split, chunk_size_),
                        yieldChunks(observations_split, chunk_size_),
                    ),
                )
                p.close()

                # Transpose [(orbits, members), ...] into two parallel lists.
                results = list(zip(*results))
                od_orbits_dfs = results[0]
                od_orbit_members_dfs = results[1]

            elif parallel_backend == "cf":
                with cf.ProcessPoolExecutor(
                    max_workers=num_workers, initializer=_initWorker
                ) as executor:
                    futures = []
                    # NOTE(review): this branch uses the raw chunk_size instead
                    # of calcChunkSize(...) like the 'ray' and 'mp' branches —
                    # confirm whether the asymmetry is intentional.
                    for orbits_i, observations_i in zip(
                        yieldChunks(orbits_split, chunk_size),
                        yieldChunks(observations_split, chunk_size),
                    ):
                        futures.append(
                            executor.submit(
                                od_worker,
                                orbits_i,
                                observations_i,
                                rchi2_threshold=rchi2_threshold,
                                min_obs=min_obs,
                                min_arc_length=min_arc_length,
                                contamination_percentage=contamination_percentage,
                                delta=delta,
                                max_iter=max_iter,
                                method=method,
                                fit_epoch=fit_epoch,
                                test_orbit=test_orbit,
                                backend=backend,
                                backend_kwargs=backend_kwargs,
                            )
                        )
                    od_orbits_dfs = []
                    od_orbit_members_dfs = []
                    for future in cf.as_completed(futures):
                        od_orbits_df, od_orbit_members_df = future.result()
                        od_orbits_dfs.append(od_orbits_df)
                        od_orbit_members_dfs.append(od_orbit_members_df)

            else:
                raise ValueError(
                    f"Unknown parallel backend: {parallel_backend}. Must be one of: 'ray', 'mp', 'cf'."
                )

        else:

            # Serial path: process the chunks in this process.
            od_orbits_dfs = []
            od_orbit_members_dfs = []
            for orbits_i, observations_i in zip(
                yieldChunks(orbits_split, chunk_size),
                yieldChunks(observations_split, chunk_size),
            ):

                od_orbits_df, od_orbit_members_df = od_worker(
                    orbits_i,
                    observations_i,
                    rchi2_threshold=rchi2_threshold,
                    min_obs=min_obs,
                    min_arc_length=min_arc_length,
                    contamination_percentage=contamination_percentage,
                    delta=delta,
                    max_iter=max_iter,
                    method=method,
                    fit_epoch=fit_epoch,
                    test_orbit=test_orbit,
                    backend=backend,
                    backend_kwargs=backend_kwargs,
                )
                od_orbits_dfs.append(od_orbits_df)
                od_orbit_members_dfs.append(od_orbit_members_df)

        od_orbits = pd.concat(od_orbits_dfs, ignore_index=True)
        od_orbit_members = pd.concat(od_orbit_members_dfs, ignore_index=True)

        # Concatenation of mixed empty/non-empty frames can upcast these
        # integer columns; force them back to int.
        for col in ["num_obs"]:
            od_orbits[col] = od_orbits[col].astype(int)
        for col in ["outlier"]:
            od_orbit_members[col] = od_orbit_members[col].astype(int)

        od_orbits, od_orbit_members = sortLinkages(
            od_orbits, od_orbit_members, observations, linkage_id_col="orbit_id"
        )

    else:
        # Nothing to correct: return empty, correctly-labeled DataFrames.
        od_orbits = pd.DataFrame(
            columns=[
                "orbit_id",
                "mjd_tdb",
                "x",
                "y",
                "z",
                "vx",
                "vy",
                "vz",
                "covariance",
                "r",
                "r_sigma",
                "v",
                "v_sigma",
                "arc_length",
                "num_obs",
                "chi2",
                "rchi2",
            ]
        )

        od_orbit_members = pd.DataFrame(
            columns=[
                "orbit_id",
                "obs_id",
                "residual_ra_arcsec",
                "residual_dec_arcsec",
                "chi2",
                "outlier",
            ]
        )

    time_end = time.time()
    logger.info("Differentially corrected {} orbits.".format(len(od_orbits)))
    logger.info(
        "Differential correction completed in {:.3f} seconds.".format(
            time_end - time_start
        )
    )

    return od_orbits, od_orbit_members
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc