moeyensj / thor / build 6733525095 (push, github, web-flow)
02 Nov 2023 01:52PM UTC coverage: 40.544% (+0.9%) from 39.595%
Merge pull request #123 from moeyensj/v2.0-link-aims-sample: Link AIMS sample

301 of 301 new or added lines in 12 files covered (100.0%)
1878 of 4632 relevant lines covered (40.54%)
0.41 hits per line

Source File
/thor/orbits/iod.py
import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import concurrent.futures as cf
import logging
import multiprocessing as mp
import time
import uuid
from functools import partial
from itertools import combinations

import numpy as np
import pandas as pd
from astropy.time import Time

from ..backend import PYOORB
from ..utils import (
    _checkParallel,
    _initWorker,
    calcChunkSize,
    identifySubsetLinkages,
    sortLinkages,
    yieldChunks,
)
from .gauss import gaussIOD
from .residuals import calcResiduals

logger = logging.getLogger(__name__)

__all__ = ["selectObservations", "iod", "iod_worker", "initialOrbitDetermination"]


def selectObservations(observations, method="combinations"):
    """
    Selects which three observations to use for IOD depending on the method.

    Methods:
        'first+middle+last' : Grab the first, middle and last observations in time.
        'thirds' : Grab the middle observation in the first third, second third, and final third.
        'combinations' : Return the observation IDs corresponding to every possible combination of three observations with
            non-coinciding observation times.

    Parameters
    ----------
    observations : `~pandas.DataFrame`
        Pandas DataFrame containing observations with at least a column of observation IDs and a column
        of exposure times.
    method : {'first+middle+last', 'thirds', 'combinations'}, optional
        Which method to use to select observations.
        [Default = 'combinations']

    Returns
    -------
    obs_id : `~numpy.ndarray` (N, 3 or 0)
        An array of selected observation IDs. If three unique observations could
        not be selected then returns an empty array.
    """
    obs_ids = observations["obs_id"].values
    if len(obs_ids) < 3:
        return np.array([])

    indexes = np.arange(0, len(obs_ids))
    times = observations["mjd_utc"].values

    if method == "first+middle+last":
        selected_times = np.percentile(times, [0, 50, 100], interpolation="nearest")
        selected_index = np.intersect1d(times, selected_times, return_indices=True)[1]
        selected_index = np.array([selected_index])

    elif method == "thirds":
        selected_times = np.percentile(
            times, [1 / 6 * 100, 50, 5 / 6 * 100], interpolation="nearest"
        )
        selected_index = np.intersect1d(times, selected_times, return_indices=True)[1]
        selected_index = np.array([selected_index])

    elif method == "combinations":
        # Make all possible combinations of 3 observations
        selected_index = np.array(
            [np.array(index) for index in combinations(indexes, 3)]
        )

        # Calculate arc length
        arc_length = times[selected_index][:, 2] - times[selected_index][:, 0]

        # Calculate distance of second observation from middle point (last + first) / 2
        time_from_mid = np.abs(
            (times[selected_index][:, 2] + times[selected_index][:, 0]) / 2
            - times[selected_index][:, 1]
        )

        # Sort by descending arc length and ascending time from midpoint
        sort = np.lexsort((time_from_mid, -arc_length))
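        # np.lexsort uses its last key as the primary sort key, so -arc_length
        # gives descending arc length; time_from_mid breaks ties in ascending order.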
        selected_index = selected_index[sort]

    else:
        raise ValueError(
            "method should be one of {'first+middle+last', 'thirds', 'combinations'}"
        )

    # Make sure each returned combination of observation IDs has at least 3 unique
    # times
    keep = []
    for i, comb in enumerate(times[selected_index]):
        if len(np.unique(comb)) == 3:
            keep.append(i)
    keep = np.array(keep)

    # Return an empty array if no observations satisfy the criteria
    if len(keep) == 0:
        return np.array([])
    else:
        selected_index = selected_index[keep, :]

    return obs_ids[selected_index]
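
# Illustrative usage sketch for selectObservations (not part of the original
# module; the toy values below are hypothetical and only show the required
# columns):
#
#   obs = pd.DataFrame({
#       "obs_id": ["o1", "o2", "o3", "o4"],
#       "mjd_utc": [59000.0, 59000.1, 59001.0, 59002.0],
#   })
#   triplets = selectObservations(obs, method="combinations")
#   # triplets is an (N, 3) array of observation ID triplets, sorted by
#   # descending arc length, then by how close the middle observation is
#   # to the midpoint of the arc.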


def iod_worker(
    observations_list,
    observation_selection_method="combinations",
    min_obs=6,
    min_arc_length=1.0,
    rchi2_threshold=10**3,
    contamination_percentage=0.0,
    iterate=False,
    light_time=True,
    linkage_id_col="cluster_id",
    backend="PYOORB",
    backend_kwargs={},
):
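    """
    Run initial orbit determination on each linkage's observations and
    concatenate the results.

    Each DataFrame in observations_list is expected to be sorted by "mjd_utc"
    and to contain a single linkage ID in the linkage_id_col column; see iod
    for descriptions of the remaining parameters.
    """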
    iod_orbits_dfs = []
    iod_orbit_members_dfs = []
    for observations in observations_list:
        assert np.all(
            sorted(observations["mjd_utc"].values) == observations["mjd_utc"].values
        )

        time_start = time.time()
        linkage_id = observations[linkage_id_col].unique()[0]
        logger.debug(f"Finding initial orbit for linkage {linkage_id}...")

        iod_orbit, iod_orbit_members = iod(
            observations,
            observation_selection_method=observation_selection_method,
            min_obs=min_obs,
            min_arc_length=min_arc_length,
            rchi2_threshold=rchi2_threshold,
            contamination_percentage=contamination_percentage,
            iterate=iterate,
            light_time=light_time,
            backend=backend,
            backend_kwargs=backend_kwargs,
        )
        if len(iod_orbit) > 0:
            iod_orbit.insert(1, linkage_id_col, linkage_id)

        time_end = time.time()
        duration = time_end - time_start
        logger.debug(f"IOD for linkage {linkage_id} completed in {duration:.3f}s.")

        iod_orbits_dfs.append(iod_orbit)
        iod_orbit_members_dfs.append(iod_orbit_members)

    iod_orbits = pd.concat(iod_orbits_dfs, ignore_index=True)
    iod_orbit_members = pd.concat(iod_orbit_members_dfs, ignore_index=True)
    return iod_orbits, iod_orbit_members


def iod(
    observations,
    min_obs=6,
    min_arc_length=1.0,
    contamination_percentage=0.0,
    rchi2_threshold=200,
    observation_selection_method="combinations",
    iterate=False,
    light_time=True,
    backend="PYOORB",
    backend_kwargs={},
):
    """
    Run initial orbit determination on a set of observations believed to belong to a single
    object.

    Parameters
    ----------
    observations : `~pandas.DataFrame`
        Dataframe of observations with at least the following columns:
            "obs_id" : Observation IDs [str],
            "mjd_utc" : Observation time in MJD UTC [float],
            "RA_deg" : equatorial J2000 Right Ascension in degrees [float],
            "Dec_deg" : equatorial J2000 Declination in degrees [float],
            "RA_sigma_deg" : 1-sigma uncertainty in equatorial J2000 RA [float],
            "Dec_sigma_deg" : 1-sigma uncertainty in equatorial J2000 Dec [float],
            "observatory_code" : MPC recognized observatory code [str],
            "obs_x" : Observatory's heliocentric ecliptic J2000 x-position in au [float],
            "obs_y" : Observatory's heliocentric ecliptic J2000 y-position in au [float],
            "obs_z" : Observatory's heliocentric ecliptic J2000 z-position in au [float],
            "obs_vx" [Optional] : Observatory's heliocentric ecliptic J2000 x-velocity in au per day [float],
            "obs_vy" [Optional] : Observatory's heliocentric ecliptic J2000 y-velocity in au per day [float],
            "obs_vz" [Optional] : Observatory's heliocentric ecliptic J2000 z-velocity in au per day [float]
    min_obs : int, optional
        Minimum number of observations that must remain in the linkage. For example, if min_obs is set to 6 and
        a linkage has 8 observations, at most the two worst observations will be flagged as outliers if their individual
        chi2 values exceed the chi2 threshold.
    contamination_percentage : float, optional
        Maximum percent of observations that can be flagged as outliers.
    rchi2_threshold : float, optional
        Maximum reduced chi2 required for an initial orbit to be accepted.
    observation_selection_method : {'first+middle+last', 'thirds', 'combinations'}, optional
        Selects which three observations to use for IOD depending on the method. The available methods are:
            'first+middle+last' : Grab the first, middle and last observations in time.
            'thirds' : Grab the middle observation in the first third, second third, and final third.
            'combinations' : Return the observation IDs corresponding to every possible combination of three observations with
                non-coinciding observation times.
    iterate : bool, optional
        Iterate the preliminary orbit solution using the state transition iterator.
    light_time : bool, optional
        Correct preliminary orbit for light travel time.
    linkage_id_col : str, optional
        Name of linkage_id column in the linkage_members dataframe.
    backend : {'MJOLNIR', 'PYOORB'}, optional
        Which backend to use for ephemeris generation.
    backend_kwargs : dict, optional
        Settings and additional parameters to pass to selected
        backend.

    Returns
    -------
    iod_orbits : `~pandas.DataFrame`
        Dataframe with orbits found in linkages.
            "orbit_id" : Orbit ID, a uuid [str],
            "epoch" : Epoch at which orbit is defined in MJD TDB [float],
            "x" : Orbit's ecliptic J2000 x-position in au [float],
            "y" : Orbit's ecliptic J2000 y-position in au [float],
            "z" : Orbit's ecliptic J2000 z-position in au [float],
            "vx" : Orbit's ecliptic J2000 x-velocity in au per day [float],
            "vy" : Orbit's ecliptic J2000 y-velocity in au per day [float],
            "vz" : Orbit's ecliptic J2000 z-velocity in au per day [float],
            "arc_length" : Arc length in days [float],
            "num_obs" : Number of observations that were within the chi2 threshold
                of the orbit.
            "chi2" : Total chi2 of the orbit calculated using the predicted location of the orbit
                on the sky compared to the constituent observations.

    iod_orbit_members : `~pandas.DataFrame`
        Dataframe of orbit members with the following columns:
            "orbit_id" : Orbit ID, a uuid [str],
            "obs_id" : Observation IDs [str], one ID per row.
            "residual_ra_arcsec" : Residual (observed - expected) equatorial J2000 Right Ascension in arcseconds [float]
            "residual_dec_arcsec" : Residual (observed - expected) equatorial J2000 Declination in arcseconds [float]
            "chi2" : Observation's chi2 [float]
            "gauss_sol" : Flag to indicate which observations were used to calculate the Gauss solution [int]
            "outlier" : Flag to indicate which observations are potential outliers (their chi2 is higher than
                the chi2 threshold) [float]
    """
    processable = True
    if len(observations) == 0:
        processable = False

    # Extract column names
    obs_id_col = "obs_id"
    time_col = "mjd_utc"
    ra_col = "RA_deg"
    dec_col = "Dec_deg"
    ra_err_col = "RA_sigma_deg"
    dec_err_col = "Dec_sigma_deg"
    obs_code_col = "observatory_code"
    obs_x_col = "obs_x"
    obs_y_col = "obs_y"
    obs_z_col = "obs_z"

    # Extract observation IDs, sky-plane positions, sky-plane position uncertainties, times of observation,
    # and the location of the observer at each time
    obs_ids_all = observations[obs_id_col].values
    coords_all = observations[[ra_col, dec_col]].values
    sigmas_all = observations[[ra_err_col, dec_err_col]].values
    coords_obs_all = observations[[obs_x_col, obs_y_col, obs_z_col]].values
    times_all = observations[time_col].values
    times_all = Time(times_all, scale="utc", format="mjd")

    observers = {}
    for code in observations[obs_code_col].unique():
        observers[code] = Time(
            observations[observations[obs_code_col] == code][time_col].values,
            scale="utc",
            format="mjd",
        )

    if backend == "PYOORB":
        if light_time == False:
            err = "PYOORB does not support turning light time correction off."
            raise ValueError(err)

        backend = PYOORB(**backend_kwargs)
    else:
        err = "backend should be 'PYOORB'"
        raise ValueError(err)

    chi2_sol = 1e10
    orbit_sol = None
    obs_ids_sol = None
    arc_length = None
    outliers = np.array([])
    converged = False
    num_obs = len(observations)
    if num_obs < min_obs:
        processable = False
    num_outliers = int(num_obs * contamination_percentage / 100.0)
    num_outliers = np.maximum(np.minimum(num_obs - min_obs, num_outliers), 0)
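    # e.g., with num_obs = 8, min_obs = 6 and contamination_percentage = 20.0:
    # int(8 * 20 / 100) = 1, clipped to the range [0, 8 - 6] -> 1 outlier allowed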

    # Select observation IDs to use for IOD
    obs_ids = selectObservations(
        observations,
        method=observation_selection_method,
    )
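    # Limit how many candidate triplets will be attempted; the cap scales with
    # the number of outliers allowed in the linkage.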
    obs_ids = obs_ids[: (3 * (num_outliers + 1))]

    if len(obs_ids) == 0:
        processable = False

    j = 0
    while not converged and processable:
        if j == len(obs_ids):
            break

        ids = obs_ids[j]
        mask = np.isin(obs_ids_all, ids)

        # Grab sky-plane positions of the selected observations, the heliocentric ecliptic position of the observer,
        # and the times at which the observations occur
        coords = coords_all[mask, :]
        coords_obs = coords_obs_all[mask, :]
        times = times_all[mask]

        # Run IOD
        iod_orbits = gaussIOD(
            coords,
            times.utc.mjd,
            coords_obs,
            light_time=light_time,
            iterate=iterate,
            max_iter=100,
            tol=1e-15,
        )
        if len(iod_orbits) == 0:
            j += 1
            continue

        # Propagate initial orbit to all observation times
        ephemeris = backend._generateEphemeris(iod_orbits, observers)

        # For each unique initial orbit calculate residuals and chi-squared
        # Find the orbit which yields the lowest chi-squared
        orbit_ids = iod_orbits.ids
        for i, orbit_id in enumerate(orbit_ids):
            orbit_name = str(uuid.uuid4().hex)
            iod_orbits.ids[i] = orbit_name

            ephemeris_orbit = ephemeris[ephemeris["orbit_id"] == orbit_id]

            # Calculate residuals and chi2
            residuals, stats = calcResiduals(
                coords_all,
                ephemeris_orbit[["RA_deg", "Dec_deg"]].values,
                sigmas_actual=sigmas_all,
                include_probabilistic=False,
            )
            chi2 = stats[0]
            chi2_total = np.sum(chi2)
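            # Degrees of freedom: two measurements (RA, Dec) per observation
            # minus the six fitted orbital state parameters.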
            rchi2 = chi2_total / (2 * num_obs - 6)

            # If the reduced chi2 is above the threshold and no outliers are
            # allowed, this cannot be improved by outlier rejection,
            # so continue to the next IOD orbit
            if rchi2 > rchi2_threshold and num_outliers == 0:
                # If we have iterated through all iod orbits and no outliers
                # are allowed for this linkage then no other combination of
                # observations will make it acceptable, so exit here.
                if (i + 1) == len(iod_orbits):
                    processable = False
                    break

                continue

            # If the total reduced chi2 is less than the threshold accept the orbit
            elif rchi2 <= rchi2_threshold:
                logger.debug("Potential solution orbit has been found.")
                orbit_sol = iod_orbits[i : i + 1]
                obs_ids_sol = ids
                chi2_total_sol = chi2_total
                chi2_sol = chi2
                rchi2_sol = rchi2
                residuals_sol = residuals
                outliers = np.array([])
                arc_length = times_all.utc.mjd.max() - times_all.utc.mjd.min()
                converged = True
                break

            # Now test whether we can remove some outliers. We anticipate
            # reaching this stage when the three selected observations
            # belonging to one object yield a good initial orbit but the presence of outlier
            # observations is skewing the sum total of the residuals and chi2
            elif num_outliers > 0:

                logger.debug("Attempting to identify possible outliers.")
                for o in range(num_outliers):
                    # Select the (o + 1) observations outside the selected triplet
                    # that contribute the most to chi2 (and thereby the residuals)
                    remove = chi2[~mask].argsort()[-(o + 1) :]

                    # Grab the obs_ids for these outliers
                    obs_id_outlier = obs_ids_all[~mask][remove]
                    logger.debug("Possible outlier(s): {}".format(obs_id_outlier))

                    # Subtract the outlier's chi2 contribution
                    # from the total chi2
                    # Then recalculate the reduced chi2
                    chi2_new = chi2_total - np.sum(chi2[~mask][remove])
                    num_obs_new = len(observations) - len(remove)
                    rchi2_new = chi2_new / (2 * num_obs_new - 6)

                    ids_mask = np.isin(obs_ids_all, obs_id_outlier, invert=True)
                    arc_length = (
                        times_all[ids_mask].utc.mjd.max()
                        - times_all[ids_mask].utc.mjd.min()
                    )

                    # If the updated reduced chi2 total is lower than our desired
                    # threshold, accept the solution. If not, keep going.
                    if rchi2_new <= rchi2_threshold and arc_length >= min_arc_length:
                        orbit_sol = iod_orbits[i : i + 1]
                        obs_ids_sol = ids
                        chi2_total_sol = chi2_new
                        rchi2_sol = rchi2_new
                        residuals_sol = residuals
                        outliers = obs_id_outlier
                        num_obs = num_obs_new
                        ids_mask = np.isin(obs_ids_all, outliers, invert=True)
                        arc_length = (
                            times_all[ids_mask].utc.mjd.max()
                            - times_all[ids_mask].utc.mjd.min()
                        )
                        chi2_sol = chi2
                        converged = True
                        break

            else:
                continue

        j += 1

    if not converged or not processable:

        orbit = pd.DataFrame(
            columns=[
                "orbit_id",
                "mjd_tdb",
                "x",
                "y",
                "z",
                "vx",
                "vy",
                "vz",
                "arc_length",
                "num_obs",
                "chi2",
                "rchi2",
            ]
        )

        orbit_members = pd.DataFrame(
            columns=[
                "orbit_id",
                "obs_id",
                "residual_ra_arcsec",
                "residual_dec_arcsec",
                "chi2",
                "gauss_sol",
                "outlier",
            ]
        )

    else:
        orbit = orbit_sol.to_df(include_units=False)
        orbit["arc_length"] = arc_length
        orbit["num_obs"] = num_obs
        orbit["chi2"] = chi2_total_sol
        orbit["rchi2"] = rchi2_sol

        orbit_members = pd.DataFrame(
            {
                "orbit_id": [orbit_sol.ids[0] for i in range(len(obs_ids_all))],
                "obs_id": obs_ids_all,
                "residual_ra_arcsec": residuals_sol[:, 0] * 3600,
                "residual_dec_arcsec": residuals_sol[:, 1] * 3600,
                "chi2": chi2_sol,
                "gauss_sol": np.zeros(len(obs_ids_all), dtype=int),
                "outlier": np.zeros(len(obs_ids_all), dtype=int),
            }
        )
        orbit_members.loc[orbit_members["obs_id"].isin(outliers), "outlier"] = 1
        orbit_members.loc[orbit_members["obs_id"].isin(obs_ids_sol), "gauss_sol"] = 1

    return orbit, orbit_members


def initialOrbitDetermination(
    observations,
    linkage_members,
    min_obs=6,
    min_arc_length=1.0,
    contamination_percentage=20.0,
    rchi2_threshold=10**3,
    observation_selection_method="combinations",
    iterate=False,
    light_time=True,
    linkage_id_col="cluster_id",
    identify_subsets=True,
    backend="PYOORB",
    backend_kwargs={},
    chunk_size=1,
    num_jobs=1,
    parallel_backend="cf",
):
    """
    Run initial orbit determination on linkages found in observations.

    Parameters
    ----------
    observations : `~pandas.DataFrame`
        Dataframe of observations with at least the following columns:
            "obs_id" : Observation IDs [str],
            "mjd_utc" : Observation time in MJD UTC [float],
            "RA_deg" : equatorial J2000 Right Ascension in degrees [float],
            "Dec_deg" : equatorial J2000 Declination in degrees [float],
            "RA_sigma_deg" : 1-sigma uncertainty in equatorial J2000 RA [float],
            "Dec_sigma_deg" : 1-sigma uncertainty in equatorial J2000 Dec [float],
            "observatory_code" : MPC recognized observatory code [str],
            "obs_x" : Observatory's heliocentric ecliptic J2000 x-position in au [float],
            "obs_y" : Observatory's heliocentric ecliptic J2000 y-position in au [float],
            "obs_z" : Observatory's heliocentric ecliptic J2000 z-position in au [float],
            "obs_vx" [Optional] : Observatory's heliocentric ecliptic J2000 x-velocity in au per day [float],
            "obs_vy" [Optional] : Observatory's heliocentric ecliptic J2000 y-velocity in au per day [float],
            "obs_vz" [Optional] : Observatory's heliocentric ecliptic J2000 z-velocity in au per day [float]
    linkage_members : `~pandas.DataFrame`
        Dataframe of linkages with at least two columns:
            "linkage_id" : Linkage ID [str],
            "obs_id" : Observation IDs [str], one ID per row.
    observation_selection_method : {'first+middle+last', 'thirds', 'combinations'}, optional
        Selects which three observations to use for IOD depending on the method. The available methods are:
            'first+middle+last' : Grab the first, middle and last observations in time.
            'thirds' : Grab the middle observation in the first third, second third, and final third.
            'combinations' : Return the observation IDs corresponding to every possible combination of three observations with
                non-coinciding observation times.
    min_obs : int, optional
        Minimum number of observations that must remain in the linkage. For example, if min_obs is set to 6 and
        a linkage has 8 observations, at most the two worst observations will be flagged as outliers. Only up to
        the contamination percentage of observations will be flagged as outliers, provided that at least min_obs
        observations remain in the linkage.
    rchi2_threshold : float, optional
        Maximum reduced chi2 required for an initial orbit to be accepted.
    contamination_percentage : float, optional
        Maximum percent of observations that can be flagged as outliers.
    iterate : bool, optional
        Iterate the preliminary orbit solution using the state transition iterator.
    light_time : bool, optional
        Correct preliminary orbit for light travel time.
    linkage_id_col : str, optional
        Name of linkage_id column in the linkage_members dataframe.
    backend : {'MJOLNIR', 'PYOORB'}, optional
        Which backend to use for ephemeris generation.
    backend_kwargs : dict, optional
        Settings and additional parameters to pass to selected
        backend.
    chunk_size : int, optional
        Number of linkages to send to each job.
    num_jobs : int, optional
        Number of jobs to launch.
    parallel_backend : str, optional
        Which parallelization backend to use {'ray', 'mp', 'cf'}. Defaults to using Python's concurrent.futures module ('cf').

    Returns
    -------
    iod_orbits : `~pandas.DataFrame`
        Dataframe with orbits found in linkages.
            "orbit_id" : Orbit ID, a uuid [str],
            "epoch" : Epoch at which orbit is defined in MJD TDB [float],
            "x" : Orbit's ecliptic J2000 x-position in au [float],
            "y" : Orbit's ecliptic J2000 y-position in au [float],
            "z" : Orbit's ecliptic J2000 z-position in au [float],
            "vx" : Orbit's ecliptic J2000 x-velocity in au per day [float],
            "vy" : Orbit's ecliptic J2000 y-velocity in au per day [float],
            "vz" : Orbit's ecliptic J2000 z-velocity in au per day [float],
            "arc_length" : Arc length in days [float],
            "num_obs" : Number of observations that were within the chi2 threshold
                of the orbit.
            "chi2" : Total chi2 of the orbit calculated using the predicted location of the orbit
                on the sky compared to the constituent observations.

    iod_orbit_members : `~pandas.DataFrame`
        Dataframe of orbit members with the following columns:
            "orbit_id" : Orbit ID, a uuid [str],
            "obs_id" : Observation IDs [str], one ID per row.
            "residual_ra_arcsec" : Residual (observed - expected) equatorial J2000 Right Ascension in arcseconds [float]
            "residual_dec_arcsec" : Residual (observed - expected) equatorial J2000 Declination in arcseconds [float]
            "chi2" : Observation's chi2 [float]
            "gauss_sol" : Flag to indicate which observations were used to calculate the Gauss solution [int]
            "outlier" : Flag to indicate which observations are potential outliers (their chi2 is higher than
                the chi2 threshold) [float]
    """
    time_start = time.time()
    logger.info("Running initial orbit determination...")

    if len(observations) > 0 and len(linkage_members) > 0:

        iod_orbits_dfs = []
        iod_orbit_members_dfs = []

        start = time.time()
        logger.debug("Merging observations on linkage members...")
        linked_observations = linkage_members.merge(observations, on="obs_id")
        logger.debug("Sorting observations by linkage ID and mjd_utc...")
        linked_observations.sort_values(
            by=[linkage_id_col, "mjd_utc"], inplace=True, ignore_index=True
        )
        duration = time.time() - start
        logger.debug(f"Merging and sorting completed in {duration:.3f}s.")

        start = time.time()
        logger.debug("Grouping observations by linkage ID...")
        grouped_observations = linked_observations.groupby(by=[linkage_id_col])
        logger.debug("Splitting grouped observations by linkage ID...")
        observations_split = [
            grouped_observations.get_group(g).reset_index(drop=True)
            for g in grouped_observations.groups
        ]
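        # observations_split now holds one time-sorted DataFrame per linkage ID.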
        duration = time.time() - start
        logger.debug(f"Grouping and splitting completed in {duration:.3f}s.")

        parallel, num_workers = _checkParallel(num_jobs, parallel_backend)
        if parallel:

            # The number of linkages that need to be fit for an initial orbit
            num_linkages = linkage_members[linkage_id_col].nunique()

            if parallel_backend == "ray":

                import ray

                if not ray.is_initialized():
                    ray.init(address="auto")

                iod_worker_ray = ray.remote(iod_worker)
                iod_worker_ray = iod_worker_ray.options(num_returns=2, num_cpus=1)

                # Send up to chunk_size linkages to each IOD worker for processing
                chunk_size_ = calcChunkSize(
                    num_linkages, num_workers, chunk_size, min_chunk_size=1
                )
                logger.info(
                    f"Distributing linkages in chunks of {chunk_size_} to {num_workers} ray workers."
                )

                # Put the observations into ray's local object storage ("plasma")
                observation_oids = []
                for observations_i in yieldChunks(observations_split, chunk_size_):
                    observation_oids.append(ray.put(observations_i))

                iod_orbits_oids = []
                iod_orbit_members_oids = []
                for observations_oid in observation_oids:

                    iod_orbits_oid, iod_orbit_members_oid = iod_worker_ray.remote(
                        observations_oid,
                        observation_selection_method=observation_selection_method,
                        rchi2_threshold=rchi2_threshold,
                        min_obs=min_obs,
                        min_arc_length=min_arc_length,
                        contamination_percentage=contamination_percentage,
                        iterate=iterate,
                        light_time=light_time,
                        linkage_id_col=linkage_id_col,
                        backend=backend,
                        backend_kwargs=backend_kwargs,
                    )
                    iod_orbits_oids.append(iod_orbits_oid)
                    iod_orbit_members_oids.append(iod_orbit_members_oid)

                iod_orbits_dfs = ray.get(iod_orbits_oids)
                iod_orbit_members_dfs = ray.get(iod_orbit_members_oids)

            elif parallel_backend == "mp":

                chunk_size_ = calcChunkSize(
                    num_linkages, num_workers, chunk_size, min_chunk_size=1
                )
                logger.info(
                    f"Distributing linkages in chunks of {chunk_size_} to {num_workers} workers."
                )

                p = mp.Pool(
                    processes=num_workers,
                    initializer=_initWorker,
                )

                results = p.starmap(
                    partial(
                        iod_worker,
                        observation_selection_method=observation_selection_method,
                        rchi2_threshold=rchi2_threshold,
                        min_obs=min_obs,
                        min_arc_length=min_arc_length,
                        contamination_percentage=contamination_percentage,
                        iterate=iterate,
                        light_time=light_time,
                        linkage_id_col=linkage_id_col,
                        backend=backend,
                        backend_kwargs=backend_kwargs,
                    ),
                    zip(yieldChunks(observations_split, chunk_size_)),
                )
                p.close()

                results = list(zip(*results))
                iod_orbits_dfs = results[0]
                iod_orbit_members_dfs = results[1]

            elif parallel_backend == "cf":
                with cf.ProcessPoolExecutor(
                    max_workers=num_workers, initializer=_initWorker
                ) as executor:
                    futures = []
                    for observations_i in yieldChunks(observations_split, chunk_size):
                        futures.append(
                            executor.submit(
                                iod_worker,
                                observations_i,
                                observation_selection_method=observation_selection_method,
                                rchi2_threshold=rchi2_threshold,
                                min_obs=min_obs,
                                min_arc_length=min_arc_length,
                                contamination_percentage=contamination_percentage,
                                iterate=iterate,
                                light_time=light_time,
                                linkage_id_col=linkage_id_col,
                                backend=backend,
                                backend_kwargs=backend_kwargs,
                            )
                        )

                    iod_orbits_dfs = []
                    iod_orbit_members_dfs = []
                    for f in cf.as_completed(futures):
                        iod_orbits_df, iod_orbit_members_df = f.result()
                        iod_orbits_dfs.append(iod_orbits_df)
                        iod_orbit_members_dfs.append(iod_orbit_members_df)
            else:
                raise ValueError(
                    f"Unknown parallel backend: {parallel_backend}. Must be one of: 'ray', 'mp', 'cf'."
                )

        else:

            for observations_i in yieldChunks(observations_split, chunk_size):
                iod_orbits_df, iod_orbit_members_df = iod_worker(
                    observations_i,
                    observation_selection_method=observation_selection_method,
                    rchi2_threshold=rchi2_threshold,
                    min_obs=min_obs,
                    min_arc_length=min_arc_length,
                    contamination_percentage=contamination_percentage,
                    iterate=iterate,
                    light_time=light_time,
                    linkage_id_col=linkage_id_col,
                    backend=backend,
                    backend_kwargs=backend_kwargs,
                )
                iod_orbits_dfs.append(iod_orbits_df)
                iod_orbit_members_dfs.append(iod_orbit_members_df)

        iod_orbits = pd.concat(iod_orbits_dfs, ignore_index=True)
        iod_orbit_members = pd.concat(iod_orbit_members_dfs, ignore_index=True)

        for col in ["num_obs"]:
            iod_orbits[col] = iod_orbits[col].astype(int)
        for col in ["gauss_sol", "outlier"]:
            iod_orbit_members[col] = iod_orbit_members[col].astype(int)

        logger.info("Found {} initial orbits.".format(len(iod_orbits)))

        if identify_subsets and len(iod_orbits) > 0:
            iod_orbits, iod_orbit_members = identifySubsetLinkages(
                iod_orbits, iod_orbit_members, linkage_id_col="orbit_id"
            )
            logger.info(
                "{} subset orbits identified.".format(
                    len(iod_orbits[~iod_orbits["subset_of"].isna()])
                )
            )

        iod_orbits, iod_orbit_members = sortLinkages(
            iod_orbits, iod_orbit_members, observations, linkage_id_col="orbit_id"
        )

    else:
        iod_orbits = pd.DataFrame(
            columns=[
                "orbit_id",
                "mjd_tdb",
                "x",
                "y",
                "z",
                "vx",
                "vy",
                "vz",
                "arc_length",
                "num_obs",
                "chi2",
                "rchi2",
            ]
        )

        iod_orbit_members = pd.DataFrame(
            columns=[
                "orbit_id",
                "obs_id",
                "residual_ra_arcsec",
                "residual_dec_arcsec",
                "chi2",
                "gauss_sol",
                "outlier",
            ]
        )

    time_end = time.time()
    logger.info(
        "Initial orbit determination completed in {:.3f} seconds.".format(
            time_end - time_start
        )
    )

    return iod_orbits, iod_orbit_members
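

# Illustrative usage sketch (not part of the original module). The DataFrames
# "observations" and "clusters" are hypothetical inputs that follow the column
# schemas documented in initialOrbitDetermination:
#
#   iod_orbits, iod_orbit_members = initialOrbitDetermination(
#       observations,
#       clusters,
#       linkage_id_col="cluster_id",
#       rchi2_threshold=10**3,
#       contamination_percentage=20.0,
#       num_jobs=4,
#       parallel_backend="cf",
#   )
#   # iod_orbits holds one row per accepted initial orbit; iod_orbit_members
#   # maps each orbit to its observations with residuals and outlier flags.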