• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

moeyensj / thor / 7849403328

09 Feb 2024 08:43PM UTC coverage: 73.806%. First build
7849403328

Pull #161

github

web-flow
Merge bef0351aa into 4d093c0d7
Pull Request #161: Remove fit epoch

2750 of 3726 relevant lines covered (73.81%)

0.74 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

29.53
/thor/orbit_selection.py
1
import logging
1✔
2
import multiprocessing as mp
1✔
3
import time
1✔
4
from dataclasses import dataclass
1✔
5
from typing import Optional, Union
1✔
6

7
import numpy as np
1✔
8
import pyarrow as pa
1✔
9
import pyarrow.compute as pc
1✔
10
import pyarrow.parquet as pq
1✔
11
import quivr as qv
1✔
12
import ray
1✔
13
from adam_core.coordinates import KeplerianCoordinates
1✔
14
from adam_core.observers import Observers
1✔
15
from adam_core.orbits import Ephemeris, Orbits
1✔
16
from adam_core.propagator import PYOORB, Propagator
1✔
17
from adam_core.propagator.utils import _iterate_chunks
1✔
18
from adam_core.ray_cluster import initialize_use_ray
1✔
19
from adam_core.time import Timestamp
1✔
20

21
from thor.observations import Observations
1✔
22
from thor.orbit import TestOrbits
1✔
23

24
from .observations.utils import calculate_healpixels
1✔
25

26
# Module-level logger named after this module so its verbosity can be tuned
# independently of the rest of the package.
logger = logging.getLogger(name=__name__)

# Public API of this module: only the top-level generator is exported.
__all__ = [
    "generate_test_orbits",
]
1✔
29

30

31
@dataclass
class KeplerianPhaseSpace:
    """
    Box-shaped region of Keplerian phase space bounded in semi-major axis (a),
    eccentricity (e) and inclination (i).

    select_test_orbits treats every bound pair as half-open
    (min <= value < max). The defaults are wide enough that, unless a bound is
    overridden, it places practically no restriction on the orbits.
    """

    # Semi-major axis limits (presumably au -- TODO confirm against the catalog)
    a_min: float = -1.0e6
    a_max: float = 1.0e6
    # Eccentricity limits
    e_min: float = 0.0
    e_max: float = 1.0e3
    # Inclination limits (presumably degrees, given the default maximum of 180)
    i_min: float = 0.0
    i_max: float = 1.8e2
1✔
39

40

41
def select_average_within_region(coordinates: KeplerianCoordinates) -> int:
    """
    Select the Keplerian coordinate as close to the median in semi-major axis,
    eccentricity, and inclination.

    The first three columns of ``coordinates.values`` (semi-major axis,
    eccentricity, inclination) are each compared against their median. The
    absolute deviations are divided by the magnitude of the median, so each
    element contributes a fractional difference. The row with the smallest
    summed difference is selected.

    Parameters
    ----------
    coordinates
        Keplerian coordinates to select from. Must contain at least one row.

    Returns
    -------
    index
        Row index of the selected coordinates. On ties the lowest index wins.

    Raises
    ------
    ValueError
        If there are no rows, or if every summed difference is NaN (for
        example because a, e or i contains NaN).
    """
    keplerian = coordinates.values
    aei = keplerian[:, 0:3]

    median = np.median(aei, axis=0)

    # Normalise by |median| so a, e and i are weighted comparably. A zero
    # median (e.g. every orbit in the region has e == 0 or i == 0) used to give
    # 0/0 = NaN, which never equals the minimum and made the selection fail with
    # an IndexError. For such columns fall back to the absolute deviation.
    scale = np.abs(median)
    scale = np.where(scale == 0.0, 1.0, scale)
    percent_diff = np.abs(aei - median) / scale

    # Sum the fractional differences across a, e and i
    summed_diff = np.sum(percent_diff, axis=1)

    # The minimum summed difference marks the "average" object. nanargmin returns
    # the first minimum (same tie-break as before) and raises instead of
    # silently picking a NaN row.
    return int(np.nanargmin(summed_diff))
1✔
69

70

71
def select_test_orbits(ephemeris: Ephemeris, orbits: Orbits) -> Orbits:
    """
    Pick representative test orbits, at most one per populated region of
    Keplerian phase space, from the orbits that appear in the given ephemeris.

    Nine regions are searched, each with half-open bounds (min <= x < max):
    - 3 in the Hungarias: 1.7 <= a < 2.06, split in e at 0.1, 0.2 and 0.4
    - 5 in the main belt: 2.06 <= a < 5.0, split in a at 2.5, 2.82, 2.95
      and 3.27, with e < 0.5
    - 1 in the outer solar system: 5.0 <= a < 50.0, with e < 0.5

    Every region accepts inclinations in [0, 180).

    Parameters
    ----------
    ephemeris
        Ephemeris for the orbits. Only orbits whose orbit_id appears here are
        candidates.
    orbits
        Orbits to select from.

    Returns
    -------
    test_orbits
        The orbit closest to the median a, e, i of each non-empty region
        (as chosen by select_average_within_region), in region order. The
        result is empty if no region contains any candidate.
    """
    # Restrict the candidates to orbits that actually have an ephemeris here
    candidates = orbits.apply_mask(pc.is_in(orbits.orbit_id, ephemeris.orbit_id))
    elements = candidates.coordinates.to_keplerian()

    # Hungarias: one band in semi-major axis, sliced into three eccentricity bins
    hungaria_e_edges = [0.0, 0.1, 0.2, 0.4]
    regions = [
        KeplerianPhaseSpace(a_min=1.7, a_max=2.06, e_min=e_lo, e_max=e_hi)
        for e_lo, e_hi in zip(hungaria_e_edges[:-1], hungaria_e_edges[1:])
    ]

    # Main belt (five bins) followed by trojans, TNOs, etc. (one bin). These
    # bins are contiguous in a and all limited to e < 0.5.
    a_edges = [2.06, 2.5, 2.82, 2.95, 3.27, 5.0, 50.0]
    regions.extend(
        KeplerianPhaseSpace(a_min=a_lo, a_max=a_hi, e_max=0.5)
        for a_lo, a_hi in zip(a_edges[:-1], a_edges[1:])
    )

    picks = []
    for region in regions:
        # Half-open limits on a, e and i, AND-ed together left to right
        conditions = [
            pc.greater_equal(elements.a, region.a_min),
            pc.less(elements.a, region.a_max),
            pc.greater_equal(elements.e, region.e_min),
            pc.less(elements.e, region.e_max),
            pc.greater_equal(elements.i, region.i_min),
            pc.less(elements.i, region.i_max),
        ]
        in_region = conditions[0]
        for condition in conditions[1:]:
            in_region = pc.and_(in_region, condition)

        region_elements = elements.apply_mask(in_region)
        if len(region_elements) == 0:
            continue

        region_orbits = candidates.apply_mask(in_region)
        best = select_average_within_region(region_elements)
        picks.append(region_orbits[int(best)])

    if not picks:
        return Orbits.empty()
    return qv.concatenate(picks)
×
193

194

195
def generate_test_orbits_worker(
    healpixel_chunk: pa.Array,
    ephemeris_healpixels: pa.Array,
    propagated_orbits: Union[Orbits, ray.ObjectRef],
    ephemeris: Union[Ephemeris, ray.ObjectRef],
) -> TestOrbits:
    """
    Select test orbits for every healpixel in one chunk of healpixels.

    For each healpixel in ``healpixel_chunk``, the ephemerides that fall in it
    are handed to select_test_orbits. Every orbit selected there is tagged with
    that healpixel as its bundle_id.

    Parameters
    ----------
    healpixel_chunk
        Healpixels to generate test orbits for.
    ephemeris_healpixels
        Healpixel of each ephemeris row (same length and order as ``ephemeris``).
    propagated_orbits
        Propagated orbits to select test orbits from.
    ephemeris
        Ephemeris for the propagated orbits.

    Returns
    -------
    test_orbits
        Test orbits for all healpixels in the chunk, or an empty TestOrbits if
        none were selected.
    """
    # Keep only the ephemerides whose healpixel belongs to this chunk
    in_chunk = pc.is_in(ephemeris_healpixels, healpixel_chunk)
    chunk_ephemeris = ephemeris.apply_mask(in_chunk)
    chunk_healpixels = pc.filter(ephemeris_healpixels, in_chunk)
    logger.info(
        f"{len(chunk_ephemeris)} orbit ephemerides overlap with the observations."
    )

    # Keep only the orbits that still have at least one of those ephemerides
    chunk_orbits = propagated_orbits.apply_mask(
        pc.is_in(propagated_orbits.orbit_id, chunk_ephemeris.orbit_id)
    )

    logger.info("Selecting test orbits from the orbit catalog...")
    per_healpixel = []
    for healpixel in healpixel_chunk:
        pixel_ephemeris = chunk_ephemeris.apply_mask(
            pc.equal(chunk_healpixels, healpixel)
        )
        if len(pixel_ephemeris) == 0:
            logger.debug(f"No ephemerides in healpixel {healpixel}.")
            continue

        selected = select_test_orbits(pixel_ephemeris, chunk_orbits)
        n_selected = len(selected)
        if n_selected == 0:
            logger.debug(f"No orbits in healpixel {healpixel}.")
            continue

        per_healpixel.append(
            TestOrbits.from_kwargs(
                orbit_id=selected.orbit_id,
                object_id=selected.object_id,
                coordinates=selected.coordinates,
                bundle_id=[healpixel] * n_selected,
            )
        )

    return qv.concatenate(per_healpixel) if per_healpixel else TestOrbits.empty()
×
264

265

266
# Ray task wrapper around generate_test_orbits_worker (one CPU, one return
# value per task). The options are given to ray.remote directly. Calling
# `.options(...)` on the remote function returns a new handle rather than
# modifying it in place, so the previous bare `.options(...)` call had no
# effect.
generate_test_orbits_worker_remote = ray.remote(num_cpus=1, num_returns=1)(
    generate_test_orbits_worker
)
1✔
268

269

270
def _observations_start_time(observations: Union[str, Observations]) -> Timestamp:
    """
    Return the propagation start time: the start of the day (nanos=0)
    containing the earliest observation, labelled as UTC.

    Raises
    ------
    ValueError
        If observations is neither a path string nor an Observations object.
    """
    # NOTE(review): the day number is used as-is and labelled UTC. If the
    # observation times are stored in another time scale, this may be off near
    # day boundaries -- confirm.
    if isinstance(observations, str):
        # Only read the days column from the parquet file to find the minimum time.
        # NOTE(review): assumes the nested "coordinates.time.days" path is
        # exposed as a "days" column by the pyarrow version in use.
        table = pq.read_table(
            observations, columns=["coordinates.time.days"], memory_map=True
        )
        min_day = pc.min(table["days"]).as_py()
        return Timestamp.from_kwargs(days=[min_day], nanos=[0], scale="utc")

    if isinstance(observations, Observations):
        earliest_time = observations.coordinates.time.min()
        return Timestamp.from_kwargs(
            days=earliest_time.days, nanos=[0], scale="utc"
        )

    raise ValueError(
        f"observations must be a path to a parquet file or an Observations object. Got {type(observations)}."
    )


def _observations_lon_lat(observations: Union[str, Observations]) -> tuple:
    """Return the observations' (lon, lat) coordinates as numpy arrays."""
    if isinstance(observations, str):
        # Only read the two coordinate columns from the parquet file
        table = pq.read_table(
            observations,
            columns=["coordinates.lon", "coordinates.lat"],
            memory_map=True,
        )
        lon = table["lon"].to_numpy(zero_copy_only=False)
        lat = table["lat"].to_numpy(zero_copy_only=False)
        return lon, lat

    lon = observations.coordinates.lon.to_numpy(zero_copy_only=False)
    lat = observations.coordinates.lat.to_numpy(zero_copy_only=False)
    return lon, lat


def generate_test_orbits(
    observations: Union[str, Observations],
    catalog: Orbits,
    nside: int = 32,
    propagator: Optional[Propagator] = None,
    max_processes: Optional[int] = None,
    chunk_size: int = 100,
) -> TestOrbits:
    """
    Given observations and a catalog of known orbits, generate test orbits
    from the catalog. The observations are divided into healpixels (with size
    determined by the nside parameter). For each healpixel in the observations,
    select up to 9 orbits from the catalog that are in the same healpixel as the
    observations. The orbits are selected in bins of semi-major axis,
    eccentricity, and inclination.

    The catalog is propagated to the start time of the observations using the
    propagator, and ephemerides are generated for the propagated orbits
    (assuming a geocentric observer).

    Parameters
    ----------
    observations
        Observations for which to generate test orbits. These observations can
        be an in-memory Observations object or a path to a parquet file containing the
        observations.
    catalog
        Catalog of known orbits.
    nside
        Healpixel size.
    propagator
        Propagator to use to propagate the orbits and generate ephemerides.
        If None (the default), a new PYOORB propagator is created for this call.
    max_processes
        Maximum number of processes to use while propagating orbits and
        generating ephemerides.
    chunk_size
        The maximum number of unique healpixels for which to generate test orbits per
        process. This function will dynamically compute the chunk size based on the
        number of unique healpixels and the number of processes. The dynamic chunk
        size will never exceed the given value and is always at least 1.

    Returns
    -------
    test_orbits
        Test orbits generated from the catalog.

    Raises
    ------
    ValueError
        If observations is neither a path string nor an Observations object.
    """
    time_start = time.perf_counter()
    logger.info("Generating test orbits...")

    # Validate the input and find the midnight of the first day of observations
    start_time = _observations_start_time(observations)

    # Build the default propagator here rather than as a default argument, so
    # no instance is created at import time or shared between calls
    if propagator is None:
        propagator = PYOORB()

    # Propagate the orbits to the start time
    logger.info("Propagating orbits to the start time of the observations...")
    propagation_start_time = time.perf_counter()
    propagated_orbits = propagator.propagate_orbits(
        catalog,
        start_time,
        max_processes=max_processes,
        parallel_backend="ray",
        chunk_size=500,
    )
    propagation_end_time = time.perf_counter()
    logger.info(
        f"Propagation completed in {propagation_end_time - propagation_start_time:.3f} seconds."
    )

    # Generate ephemerides for the propagated orbits as seen by a geocentric
    # observer (code "500") at the start time
    logger.info("Generating ephemerides for the propagated orbits...")
    ephemeris_start_time = time.perf_counter()
    observers = Observers.from_code("500", start_time)
    ephemeris = propagator.generate_ephemeris(
        propagated_orbits,
        observers,
        start_time,
        max_processes=max_processes,
        parallel_backend="ray",
        chunk_size=1000,
    )
    ephemeris_end_time = time.perf_counter()
    logger.info(
        f"Ephemeris generation completed in {ephemeris_end_time - ephemeris_start_time:.3f} seconds."
    )

    lon, lat = _observations_lon_lat(observations)

    # Unique healpixels of the observations, so we can cross match against the
    # catalog's predicted ephemeris
    observations_healpixels = calculate_healpixels(
        lon,
        lat,
        nside=nside,
    )
    observations_healpixels = pc.unique(pa.array(observations_healpixels))
    logger.info(
        f"Observations occur in {len(observations_healpixels)} unique healpixels."
    )

    # Healpixel of every ephemeris row (deliberately not unique), so that each
    # ephemeris can be matched to the observation healpixel it falls in
    ephemeris_healpixels = calculate_healpixels(
        ephemeris.coordinates.lon.to_numpy(zero_copy_only=False),
        ephemeris.coordinates.lat.to_numpy(zero_copy_only=False),
        nside=nside,
    )
    ephemeris_healpixels = pa.array(ephemeris_healpixels)

    # Dynamically compute the chunk size based on the number of healpixels
    # and the number of processes
    if max_processes is None:
        max_processes = mp.cpu_count()

    # Never exceed the caller's chunk_size, but keep at least 1: a chunk size
    # of 0 (no healpixels found, or chunk_size=0 passed) cannot be used to step
    # through the healpixels.
    chunk_size = max(
        1,
        int(min(np.ceil(len(observations_healpixels) / max_processes), chunk_size)),
    )
    logger.info(f"Generating test orbits with a chunk size of {chunk_size} healpixels.")

    test_orbits = TestOrbits.empty()
    use_ray = initialize_use_ray(num_cpus=max_processes)
    if use_ray:
        # Place the large shared inputs in the object store once, rather than
        # once per task
        ephemeris_ref = ray.put(ephemeris)
        ephemeris_healpixels_ref = ray.put(ephemeris_healpixels)
        propagated_orbits_ref = ray.put(propagated_orbits)

        futures = []
        for healpixel_chunk in _iterate_chunks(observations_healpixels, chunk_size):
            futures.append(
                generate_test_orbits_worker_remote.remote(
                    healpixel_chunk,
                    ephemeris_healpixels_ref,
                    propagated_orbits_ref,
                    ephemeris_ref,
                )
            )

        # Collect results as tasks finish, defragmenting as the table grows
        while futures:
            finished, futures = ray.wait(futures, num_returns=1)
            test_orbits = qv.concatenate([test_orbits, ray.get(finished[0])])
            if test_orbits.fragmented():
                test_orbits = qv.defragment(test_orbits)

    else:
        for healpixel_chunk in _iterate_chunks(observations_healpixels, chunk_size):
            test_orbits_chunk = generate_test_orbits_worker(
                healpixel_chunk,
                ephemeris_healpixels,
                propagated_orbits,
                ephemeris,
            )
            test_orbits = qv.concatenate([test_orbits, test_orbits_chunk])
            if test_orbits.fragmented():
                test_orbits = qv.defragment(test_orbits)

    time_end = time.perf_counter()
    logger.info(f"Selected {len(test_orbits)} test orbits.")
    logger.info(
        f"Test orbit generation completed in {time_end - time_start:.3f} seconds."
    )
    return test_orbits
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc