• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

moeyensj / thor / 12057966667

27 Nov 2024 08:43PM UTC coverage: 73.794% (-0.2%) from 73.99%
12057966667

Pull #166

github

web-flow
Merge 491d82a51 into 3172ac203
Pull Request #166: Swap to pdm

213 of 288 new or added lines in 37 files covered. (73.96%)

2768 of 3751 relevant lines covered (73.79%)

0.74 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

30.0
/src/thor/orbit_selection.py
1
import logging
1✔
2
import multiprocessing as mp
1✔
3
import time
1✔
4
from dataclasses import dataclass
1✔
5
from typing import Optional, Union
1✔
6

7
import numpy as np
1✔
8
import pyarrow as pa
1✔
9
import pyarrow.compute as pc
1✔
10
import pyarrow.parquet as pq
1✔
11
import quivr as qv
1✔
12
import ray
1✔
13
from adam_core.coordinates import KeplerianCoordinates
1✔
14
from adam_core.observers import Observers
1✔
15
from adam_core.orbits import Ephemeris, Orbits
1✔
16
from adam_core.propagator import Propagator
1✔
17
from adam_core.propagator.adam_pyoorb import PYOORBPropagator
1✔
18
from adam_core.propagator.utils import _iterate_chunks
1✔
19
from adam_core.ray_cluster import initialize_use_ray
1✔
20
from adam_core.time import Timestamp
1✔
21

22
from thor.observations import Observations
1✔
23
from thor.orbit import TestOrbits
1✔
24

25
from .observations.utils import calculate_healpixels
1✔
26

27
logger = logging.getLogger(__name__)
1✔
28

29
__all__ = ["generate_test_orbits"]
1✔
30

31

32
@dataclass
class KeplerianPhaseSpace:
    """
    Bounds of a region in Keplerian (a, e, i) phase space.

    Every bound has a default, so a region only needs to specify the
    limits it wants to tighten.
    """

    # Semi-major axis bounds
    a_min: float = -1.0e6
    a_max: float = 1.0e6
    # Eccentricity bounds
    e_min: float = 0.0
    e_max: float = 1.0e3
    # Inclination bounds
    i_min: float = 0.0
    i_max: float = 180.0
1✔
40

41

42
def select_average_within_region(coordinates: "KeplerianCoordinates") -> int:
    """
    Select the Keplerian coordinate as close to the median in semi-major axis,
    eccentricity, and inclination.

    For each of the first three columns of ``coordinates.values`` the
    fractional difference from the column median is computed; the row with
    the smallest summed difference is selected. If several rows tie, the
    first one is returned.

    Parameters
    ----------
    coordinates
        Keplerian coordinates to select from.

    Returns
    -------
    index
        Index of the selected coordinates.

    Raises
    ------
    ValueError
        If ``coordinates`` contains no rows.
    """
    aei = np.asarray(coordinates.values)[:, 0:3]
    median = np.median(aei, axis=0)

    # A median of exactly zero (e.g. all-circular orbits) would turn the
    # fractional difference into NaN/inf and make the selection fail. For
    # such a column fall back to the absolute difference instead.
    scale = np.where(median == 0.0, 1.0, np.abs(median))
    percent_diff = np.abs(aei - median) / scale

    # Sum the percent differences and call the row with the smallest
    # total the average object
    summed_diff = np.sum(percent_diff, axis=1)
    return int(np.argmin(summed_diff))
1✔
70

71

72
def select_test_orbits(ephemeris: Ephemeris, orbits: Orbits) -> Orbits:
    """
    Pick one representative orbit per region of Keplerian phase space,
    considering only the orbits that appear in the given ephemeris.

    The regions are:
    - 3 in the Hungarias (split by eccentricity)
    - 5 in the main belt (split by semi-major axis)
    - 1 in the outer solar system

    Regions that contain no orbits contribute nothing to the result.

    Parameters
    ----------
    ephemeris
        Ephemeris for the orbits.
    orbits
        Orbits to select from.

    Returns
    -------
    test_orbits
        Test orbits selected from the orbits (empty if no region
        contains an orbit).
    """
    # Restrict the candidates to orbits that have an ephemeris entry
    has_ephemeris = pc.is_in(orbits.orbit_id, ephemeris.orbit_id)
    orbits_patch = orbits.apply_mask(has_ephemeris)

    # Convert to keplerian coordinates
    keplerian = orbits_patch.coordinates.to_keplerian()

    # 3 Hungaria regions: one semi-major axis range, three eccentricity bins
    hungaria_a_min, hungaria_a_max = 1.7, 2.06
    hungaria_e_edges = (0.0, 0.1, 0.2, 0.4)
    phase_space_regions = [
        KeplerianPhaseSpace(
            a_min=hungaria_a_min,
            a_max=hungaria_a_max,
            e_min=e_low,
            e_max=e_high,
        )
        for e_low, e_high in zip(hungaria_e_edges[:-1], hungaria_e_edges[1:])
    ]

    # 5 main belt regions followed by 1 region for trojans, TNOs, etc.:
    # consecutive semi-major axis bins that share the same eccentricity cap
    outer_a_edges = (hungaria_a_max, 2.5, 2.82, 2.95, 3.27, 5.0, 50.0)
    phase_space_regions.extend(
        KeplerianPhaseSpace(a_min=a_low, a_max=a_high, e_max=0.5)
        for a_low, a_high in zip(outer_a_edges[:-1], outer_a_edges[1:])
    )

    selected = []
    for region in phase_space_regions:
        # Half-open bounds [min, max) on a, e and i, combined in order
        conditions = [
            pc.greater_equal(keplerian.a, region.a_min),
            pc.less(keplerian.a, region.a_max),
            pc.greater_equal(keplerian.e, region.e_min),
            pc.less(keplerian.e, region.e_max),
            pc.greater_equal(keplerian.i, region.i_min),
            pc.less(keplerian.i, region.i_max),
        ]
        mask = conditions[0]
        for condition in conditions[1:]:
            mask = pc.and_(mask, condition)

        keplerian_region = keplerian.apply_mask(mask)
        orbits_region = orbits_patch.apply_mask(mask)

        if len(keplerian_region) == 0:
            continue

        index = select_average_within_region(keplerian_region)
        selected.append(orbits_region[int(index)])

    if not selected:
        return Orbits.empty()
    return qv.concatenate(selected)
×
194

195

196
def generate_test_orbits_worker(
    healpixel_chunk: pa.Array,
    ephemeris_healpixels: pa.Array,
    propagated_orbits: Union[Orbits, ray.ObjectRef],
    ephemeris: Union[Ephemeris, ray.ObjectRef],
) -> TestOrbits:
    """
    Select test orbits for every healpixel in a chunk.

    For each healpixel, the ephemerides that fall in it are passed to
    ``select_test_orbits`` and the selected orbits are tagged with the
    healpixel as their ``bundle_id``.

    Parameters
    ----------
    healpixel_chunk
        Healpixels to generate test orbits for.
    ephemeris_healpixels
        Healpixels for the ephemeris.
    propagated_orbits
        Propagated orbits.
    ephemeris
        Ephemeris for the propagated orbits.

    Returns
    -------
    test_orbits
        Test orbits generated from the propagated orbits (empty if nothing
        was selected for any healpixel in the chunk).
    """
    # Filter the ephemerides (and their healpixels) to those in this chunk
    in_chunk = pc.is_in(ephemeris_healpixels, healpixel_chunk)
    ephemeris_filtered = ephemeris.apply_mask(in_chunk)
    ephemeris_healpixels = pc.filter(ephemeris_healpixels, in_chunk)
    logger.info(f"{len(ephemeris_filtered)} orbit ephemerides overlap with the observations.")

    # Filter the orbits to only those in the ephemeris
    has_ephemeris = pc.is_in(propagated_orbits.orbit_id, ephemeris_filtered.orbit_id)
    orbits_filtered = propagated_orbits.apply_mask(has_ephemeris)

    logger.info("Selecting test orbits from the orbit catalog...")
    selected = []
    for healpixel in healpixel_chunk:
        ephemeris_healpixel = ephemeris_filtered.apply_mask(pc.equal(ephemeris_healpixels, healpixel))
        if len(ephemeris_healpixel) == 0:
            logger.debug(f"No ephemerides in healpixel {healpixel}.")
            continue

        candidates = select_test_orbits(ephemeris_healpixel, orbits_filtered)
        num_candidates = len(candidates)
        if num_candidates == 0:
            logger.debug(f"No orbits in healpixel {healpixel}.")
            continue

        selected.append(
            TestOrbits.from_kwargs(
                orbit_id=candidates.orbit_id,
                object_id=candidates.object_id,
                coordinates=candidates.coordinates,
                # Every orbit selected here belongs to the same healpixel
                bundle_id=[healpixel] * num_candidates,
            )
        )

    return qv.concatenate(selected) if selected else TestOrbits.empty()
×
263

264

265
# Ray task wrapper around generate_test_orbits_worker. The task options are
# given to ray.remote directly: calling .options(...) on an existing remote
# function returns a new handle instead of modifying it, so calling it and
# discarding the result has no effect.
generate_test_orbits_worker_remote = ray.remote(num_cpus=1, num_returns=1)(generate_test_orbits_worker)
1✔
267

268

269
def generate_test_orbits(
    observations: Union[str, Observations],
    catalog: Orbits,
    nside: int = 32,
    propagator: Optional[Propagator] = None,
    max_processes: Optional[int] = None,
    chunk_size: int = 100,
) -> TestOrbits:
    """
    Given observations and a catalog of known orbits generate test orbits
    from the catalog. The observations are divided into healpixels (with size determined
    by the nside parameter). For each healpixel in observations, select up to 9 orbits from
    the catalog that are in the same healpixel as the observations. The orbits are selected
    in bins of semi-major axis, eccentricity, and inclination.

    The catalog will be propagated to start time of the observations using the propagator
    and ephemerides will be generated for the propagated orbits (assuming a geocentric observer).

    Parameters
    ----------
    observations
        Observations for which to generate test orbits. These observations can
        be an in-memory Observations object or a path to a parquet file containing the
        observations.
    catalog
        Catalog of known orbits.
    nside
        Healpixel size.
    propagator
        Propagator to use to propagate the orbits. If None (the default), a
        PYOORBPropagator is created when the function is called.
    max_processes
        Maximum number of processes to use while propagating orbits and
        generating ephemerides. If None, the number of CPUs reported by
        multiprocessing is used.
    chunk_size
        The maximum number of unique healpixels for which to generate test orbits per
        process. This function will dynamically compute the chunk size based on the
        number of unique healpixels and the number of processes. The dynamic chunk
        size will never exceed the given value and is never smaller than 1.

    Returns
    -------
    test_orbits
        Test orbits generated from the catalog.

    Raises
    ------
    ValueError
        If observations is neither a string path nor an Observations object.
    """
    time_start = time.perf_counter()
    logger.info("Generating test orbits...")

    # If the input file is a string, read in the days column to
    # extract the minimum time
    if isinstance(observations, str):
        table = pq.read_table(observations, columns=["coordinates.time.days"], memory_map=True)

        # NOTE(review): the column is requested by its nested path but looked
        # up as "days" -- confirm this matches the layout of the parquet file.
        min_day = pc.min(table["days"]).as_py()
        # Set the start time to the midnight of the first night of observations
        start_time = Timestamp.from_kwargs(days=[min_day], nanos=[0], scale="utc")
        del table
    elif isinstance(observations, Observations):
        # Extract the minimum time from the observations
        earliest_time = observations.coordinates.time.min()

        # Set the start time to the midnight of the first night of observations
        start_time = Timestamp.from_kwargs(days=earliest_time.days, nanos=[0], scale="utc")
    else:
        raise ValueError(
            f"observations must be a path to a parquet file or an Observations object. Got {type(observations)}."
        )

    # Create the default propagator lazily rather than as a default argument,
    # which would instantiate it at import time and share it between calls
    if propagator is None:
        propagator = PYOORBPropagator()

    # Propagate the orbits to the minimum time
    logger.info("Propagating orbits to the start time of the observations...")
    propagation_start_time = time.perf_counter()
    propagated_orbits = propagator.propagate_orbits(
        catalog,
        start_time,
        max_processes=max_processes,
        chunk_size=500,
    )
    propagation_end_time = time.perf_counter()
    logger.info(f"Propagation completed in {propagation_end_time - propagation_start_time:.3f} seconds.")

    # Create a geocentric observer for the observations
    logger.info("Generating ephemerides for the propagated orbits...")
    ephemeris_start_time = time.perf_counter()
    observers = Observers.from_code("500", start_time)

    # Generate ephemerides for the propagated orbits
    ephemeris = propagator.generate_ephemeris(
        propagated_orbits,
        observers,
        start_time,
        max_processes=max_processes,
        chunk_size=1000,
    )
    ephemeris_end_time = time.perf_counter()
    logger.info(f"Ephemeris generation completed in {ephemeris_end_time - ephemeris_start_time:.3f} seconds.")

    if isinstance(observations, str):
        table = pq.read_table(
            observations,
            columns=["coordinates.lon", "coordinates.lat"],
            memory_map=True,
        )
        # NOTE(review): same nested-path vs. leaf-name lookup as for "days"
        # above -- confirm against the parquet file layout.
        lon = table["lon"].to_numpy(zero_copy_only=False)
        lat = table["lat"].to_numpy(zero_copy_only=False)
        del table

    else:
        lon = observations.coordinates.lon.to_numpy(zero_copy_only=False)
        lat = observations.coordinates.lat.to_numpy(zero_copy_only=False)

    # Calculate the healpixels for observations and ephemerides
    # Here we want the unique healpixels so we can cross match against our
    # catalog's predicted ephemeris
    observations_healpixels = calculate_healpixels(
        lon,
        lat,
        nside=nside,
    )
    observations_healpixels = pc.unique(pa.array(observations_healpixels))
    logger.info(f"Observations occur in {len(observations_healpixels)} unique healpixels.")

    # Calculate the healpixels for each ephemeris
    # We do not want unique healpixels here because we want to
    # select orbits from the same healpixel as the observations
    ephemeris_healpixels = calculate_healpixels(
        ephemeris.coordinates.lon.to_numpy(zero_copy_only=False),
        ephemeris.coordinates.lat.to_numpy(zero_copy_only=False),
        nside=nside,
    )
    ephemeris_healpixels = pa.array(ephemeris_healpixels)

    # Dynamically compute the chunk size based on the number of healpixels
    # and the number of processes
    if max_processes is None:
        max_processes = mp.cpu_count()

    # Clamp to at least 1: with no observation healpixels the computed size
    # would be 0, which is not a usable chunk size
    healpixels_per_process = int(np.ceil(len(observations_healpixels) / max_processes))
    chunk_size = max(1, min(healpixels_per_process, chunk_size))
    logger.info(f"Generating test orbits with a chunk size of {chunk_size} healpixels.")

    # Collect the per-chunk results and concatenate them once at the end.
    # Starting from an empty table keeps the result well-defined when there
    # are no chunks at all.
    test_orbits_chunks = [TestOrbits.empty()]
    use_ray = initialize_use_ray(num_cpus=max_processes)
    if use_ray:
        # Put the shared inputs in the object store once instead of
        # passing them by value with every task
        ephemeris_ref = ray.put(ephemeris)
        ephemeris_healpixels_ref = ray.put(ephemeris_healpixels)
        propagated_orbits_ref = ray.put(propagated_orbits)

        futures = [
            generate_test_orbits_worker_remote.remote(
                healpixel_chunk,
                ephemeris_healpixels_ref,
                propagated_orbits_ref,
                ephemeris_ref,
            )
            for healpixel_chunk in _iterate_chunks(observations_healpixels, chunk_size)
        ]

        # Gather results in completion order
        while futures:
            finished, futures = ray.wait(futures, num_returns=1)
            test_orbits_chunks.append(ray.get(finished[0]))

    else:
        for healpixel_chunk in _iterate_chunks(observations_healpixels, chunk_size):
            test_orbits_chunks.append(
                generate_test_orbits_worker(
                    healpixel_chunk,
                    ephemeris_healpixels,
                    propagated_orbits,
                    ephemeris,
                )
            )

    test_orbits = qv.concatenate(test_orbits_chunks)
    if test_orbits.fragmented():
        test_orbits = qv.defragment(test_orbits)

    time_end = time.perf_counter()
    logger.info(f"Selected {len(test_orbits)} test orbits.")
    logger.info(f"Test orbit generation completed in {time_end - time_start:.3f} seconds.")
    return test_orbits
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc