• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

moeyensj / thor / 12280237166

11 Dec 2024 04:04PM UTC coverage: 75.265% (+1.5%) from 73.794%
12280237166

Pull #167

github

web-flow
Merge e0b85b5dd into 597f246f5
Pull Request #167: Use generic propagator, new adam-core, tests with ASSIST

51 of 54 new or added lines in 14 files covered. (94.44%)

10 existing lines in 2 files now uncovered.

2839 of 3772 relevant lines covered (75.27%)

0.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.93
/src/thor/orbit.py
1
import logging
1✔
2
import multiprocessing as mp
1✔
3
import uuid
1✔
4
from typing import Optional, Type, TypeVar, Union
1✔
5

6
import numpy as np
1✔
7
import pyarrow as pa
1✔
8
import pyarrow.compute as pc
1✔
9
import quivr as qv
1✔
10
import ray
1✔
11
from adam_core.coordinates import (
1✔
12
    CartesianCoordinates,
13
    CometaryCoordinates,
14
    KeplerianCoordinates,
15
    OriginCodes,
16
    SphericalCoordinates,
17
    transform_coordinates,
18
)
19
from adam_core.observers import Observers
1✔
20
from adam_core.orbits import Ephemeris, Orbits
1✔
21
from adam_core.propagator import Propagator
1✔
22
from adam_core.ray_cluster import initialize_use_ray
1✔
23
from adam_core.time import Timestamp
1✔
24

25
from .observations import Observations
1✔
26

27
# Generic type variable over the adam_core coordinate representations imported above.
CoordinateType = TypeVar(
    "CoordinateType",
    bound=Union[CartesianCoordinates, SphericalCoordinates, KeplerianCoordinates, CometaryCoordinates],
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
39

40

41
class RangedPointSourceDetections(qv.Table):
    """
    Detections whose missing topocentric distances have been filled in by assuming they lie at
    the test orbit's heliocentric distance (see ``range_observations_worker``).
    """

    # Observation ID (copied from ``Observations.id``).
    id = qv.LargeStringColumn()
    # Exposure the detection belongs to (copied from ``Observations.exposure_id``).
    exposure_id = qv.LargeStringColumn()
    # Ranged spherical coordinates, as returned by ``assume_heliocentric_distance``.
    coordinates = SphericalCoordinates.as_column()
    # ID of the state (unique observation time and observatory code) the detection belongs to.
    state_id = qv.LargeStringColumn()
46

47

48
class TestOrbitEphemeris(qv.Table):
    """
    The test orbit's ephemeris for each state, together with the observer for that state.

    Built by ``TestOrbits.generate_ephemeris_from_observations``; rows are matched to
    observations through ``id``, which holds the state ID.
    """

    # State ID (unique observation time and observatory code).
    id = qv.LargeStringColumn()
    # Ephemeris of the test orbit for this state; its aberrated coordinates are used for ranging.
    ephemeris = Ephemeris.as_column()
    # Observer for this state; its coordinates are the origin of the topocentric detections.
    observer = Observers.as_column()
52

53

54
def range_observations_worker(
    observations: Observations, ephemeris: TestOrbitEphemeris, state_id: str
) -> RangedPointSourceDetections:
    """
    Range the observations of a single state using the test orbit's ephemeris for that state.

    Observations without a topocentric distance are placed at the same heliocentric distance
    as the test orbit, and the corresponding topocentric distance is computed from the
    observer's heliocentric position.

    Parameters
    ----------
    observations : `~thor.observations.observations.Observations`
        Observations to range. Only rows whose ``state_id`` equals ``state_id`` are used.
    ephemeris : `~thor.orbit.TestOrbitEphemeris`
        Test orbit ephemerides keyed by state ID. Exactly one row must match ``state_id``.
        Its aberrated state supplies the test orbit's heliocentric position, and its observer
        supplies the origin of the topocentric coordinates.
    state_id : str
        The ID of the state to range.

    Returns
    -------
    ranged_point_source_detections : `~thor.orbit.RangedPointSourceDetections`
        The detections assuming they are located at the same heliocentric distance
        as the test orbit.

    Raises
    ------
    AssertionError
        If ``ephemeris`` does not contain exactly one row for ``state_id``.
    """
    observations_state = observations.select("state_id", state_id)
    ephemeris_state = ephemeris.select("id", state_id)
    assert len(ephemeris_state) == 1

    # assume_heliocentric_distance builds ecliptic line-of-sight vectors and dots them with both
    # the test orbit's position and the observer's position. Both vectors must therefore be
    # heliocentric *and* ecliptic. Checking only the origin would let a Sun-centred equatorial
    # vector through and silently mix frames in those dot products.
    aberrated_coordinates = ephemeris_state.ephemeris.aberrated_coordinates
    if (
        aberrated_coordinates.origin.code.to_pylist()[0] != "SUN"
        or aberrated_coordinates.frame != "ecliptic"
    ):
        aberrated_coordinates = transform_coordinates(
            aberrated_coordinates,
            CartesianCoordinates,
            frame_out="ecliptic",
            origin_out=OriginCodes.SUN,
        )

    # Heliocentric position vector of the test orbit at the time of the exposure
    r = aberrated_coordinates.r[0]

    # Heliocentric ecliptic coordinates of the observer
    observer_coordinates = ephemeris_state.observer.coordinates
    if (
        observer_coordinates.origin.code.to_pylist()[0] != "SUN"
        or observer_coordinates.frame != "ecliptic"
    ):
        observer_coordinates = transform_coordinates(
            observer_coordinates,
            CartesianCoordinates,
            frame_out="ecliptic",
            origin_out=OriginCodes.SUN,
        )

    return RangedPointSourceDetections.from_kwargs(
        id=observations_state.id,
        exposure_id=observations_state.exposure_id,
        coordinates=assume_heliocentric_distance(r, observations_state.coordinates, observer_coordinates),
        state_id=observations_state.state_id,
    )
101

102

103
# Ray remote version of range_observations_worker, used by TestOrbits.range_observations
# to range several states in parallel.
range_observations_remote = ray.remote(range_observations_worker)
104

105

106
class TestOrbits(qv.Table):
    """
    Table of test orbits.

    Each row holds an ``orbit_id`` (a random UUID hex string unless given), an optional
    ``object_id`` and ``bundle_id``, and the orbit's Cartesian state.
    """

    orbit_id = qv.LargeStringColumn(default=lambda: uuid.uuid4().hex)
    object_id = qv.LargeStringColumn(nullable=True)
    bundle_id = qv.Int64Column(nullable=True)
    coordinates = CartesianCoordinates.as_column()

    @classmethod
    def from_orbits(cls, orbits: Orbits) -> "TestOrbits":
        """
        Create test orbits from an adam_core ``Orbits`` table.

        ``orbit_id``, ``object_id`` and ``coordinates`` are copied; ``bundle_id`` is not set here.
        """
        return cls.from_kwargs(
            orbit_id=orbits.orbit_id,
            object_id=orbits.object_id,
            coordinates=orbits.coordinates,
        )

    def to_orbits(self) -> Orbits:
        """
        Convert these test orbits to an adam_core ``Orbits`` table.

        ``coordinates``, ``orbit_id`` and ``object_id`` are copied; ``bundle_id`` is dropped.
        """
        return Orbits.from_kwargs(
            coordinates=self.coordinates,
            orbit_id=self.orbit_id,
            object_id=self.object_id,
        )

    def _is_cache_fresh(self, observations: Observations) -> bool:
        """
        Check if the cached ephemeris is fresh.

        If every observation ID is contained within the cached observation IDs, the cache is
        fresh. Otherwise, it is stale. This permits observations to be filtered out without
        having to regenerate the ephemeris.

        Parameters
        ----------
        observations : `~thor.observations.observations.Observations`
            Observations to check against the cached ephemerides.

        Returns
        -------
        is_fresh : bool
            True if the cache is fresh, False otherwise.
        """
        cached_ephemeris = getattr(self, "_cached_ephemeris", None)
        cached_observation_ids = getattr(self, "_cached_observation_ids", None)

        if cached_ephemeris is None and cached_observation_ids is None:
            # Nothing cached yet: create the attributes so later reads find them.
            self._cached_ephemeris: Optional[TestOrbitEphemeris] = None
            self._cached_observation_ids: Optional[pa.Array] = None
            return False

        # A half-populated cache is treated as stale. Use getattr defaults above rather than
        # attribute access, so a missing attribute cannot raise AttributeError.
        if cached_ephemeris is None or cached_observation_ids is None:
            return False

        # Membership does not depend on order, so neither array needs sorting.
        return bool(pc.all(pc.is_in(observations.id, value_set=cached_observation_ids)).as_py())

    def _cache_ephemeris(self, ephemeris: TestOrbitEphemeris, observations: Observations):
        """
        Cache the ephemeris and observation IDs.

        Parameters
        ----------
        ephemeris : `~thor.orbit.TestOrbitEphemeris`
            States to cache.
        observations : `~thor.observations.observations.Observations`
            Observations to cache. Only observation IDs will be cached.

        Returns
        -------
        None
        """
        self._cached_ephemeris = ephemeris
        self._cached_observation_ids = observations.id

    def propagate(
        self,
        times: Timestamp,
        propagator_class: Type[Propagator],
        max_processes: Optional[int] = 1,
    ) -> Orbits:
        """
        Propagate this test orbit to the given times.

        Parameters
        ----------
        times : `~adam_core.time.time.Timestamp`
            Times to which to propagate the orbit.
        propagator_class : type of `~adam_core.propagator.propagator.Propagator`
            Propagator class; it is instantiated with no arguments and used to propagate the orbit.
        max_processes : int, optional
            Maximum number of processes to use to propagate the orbit. Defaults to 1.

        Returns
        -------
        propagated_orbit : `~adam_core.orbits.orbits.Orbits`
            The test orbit propagated to the given times.
        """
        propagator = propagator_class()
        return propagator.propagate_orbits(
            self.to_orbits(),
            times,
            max_processes=max_processes,
            chunk_size=1,
        )

    def generate_ephemeris(
        self,
        observers: Observers,
        propagator_class: Type[Propagator],
        max_processes: Optional[int] = 1,
    ) -> Ephemeris:
        """
        Generate ephemeris for this test orbit at the given observers.

        Parameters
        ----------
        observers : `~adam_core.observers.Observers`
            Observers from which to generate ephemeris.
        propagator_class : type of `~adam_core.propagator.propagator.Propagator`
            Propagator class; it is instantiated with no arguments and used to propagate the orbit.
        max_processes : int, optional
            Maximum number of processes to use to propagate the orbit. Defaults to 1.

        Returns
        -------
        ephemeris : `~adam_core.orbits.ephemeris.Ephemeris`
            The ephemeris of the test orbit at the given observers.
        """
        propagator = propagator_class()
        return propagator.generate_ephemeris(
            self.to_orbits(),
            observers,
            max_processes=max_processes,
            chunk_size=1,
        )

    def generate_ephemeris_from_observations(
        self,
        observations: Union[Observations, ray.ObjectRef],
        propagator_class: Type[Propagator],
        max_processes: Optional[int] = 1,
    ) -> TestOrbitEphemeris:
        """
        Generate this test orbit's ephemeris for every state in the observations.

        A state is a unique observation time and observatory code. For each state, an ephemeris
        is generated and stored in a ``TestOrbitEphemeris`` table. The observer's coordinates
        are stored too, and can be referenced throughout the THOR pipeline.

        Caching of these ephemerides (see ``_is_cache_fresh`` and ``_cache_ephemeris``) is
        currently disabled, so they are regenerated on every call.

        Parameters
        ----------
        observations : `~thor.observations.observations.Observations` or `ray.ObjectRef`
            Observations to compute test orbit ephemerides for.
        propagator_class : type of `~adam_core.propagator.propagator.Propagator`
            Propagator class used to generate the ephemeris.
        max_processes : int, optional
            Maximum number of processes to use to propagate the orbit. Defaults to 1.

        Returns
        -------
        states : `~thor.orbit.TestOrbitEphemeris`
            Table containing the ephemeris of the test orbit, its aberrated state vector, and the
            observer coordinates at each unique state of the observations.

        Raises
        ------
        ValueError
            If the observations are empty.
        """
        if isinstance(observations, ray.ObjectRef):
            observations = ray.get(observations)

        if len(observations) == 0:
            raise ValueError("Observations must not be empty.")

        # NOTE: ephemeris caching is disabled. To re-enable it, restore this lookup and the
        # _cache_ephemeris call before the return.
        # if self._is_cache_fresh(observations):
        #     logger.debug("Test orbit ephemeris cache is fresh. Returning cached states.")
        #     return self._cached_ephemeris

        logger.debug("Test orbit ephemeris cache is stale. Regenerating.")

        observers_with_states = observations.get_observers()

        # Sort observers and ephemerides by the same key (time, then observatory code), so that
        # row i of the ephemeris lines up with row i of the observers and their state IDs.
        # Assumes the ephemeris origin code is the observer's code -- TODO confirm.
        observers_with_states = observers_with_states.sort_by(
            by=[
                "observers.coordinates.time.days",
                "observers.coordinates.time.nanos",
                "observers.code",
            ]
        )
        ephemeris = self.generate_ephemeris(
            observers_with_states.observers,
            propagator_class=propagator_class,
            max_processes=max_processes,
        )
        ephemeris = ephemeris.sort_by(
            by=[
                "coordinates.time.days",
                "coordinates.time.nanos",
                "coordinates.origin.code",
            ]
        )

        test_orbit_ephemeris = TestOrbitEphemeris.from_kwargs(
            id=observers_with_states.state_id,
            ephemeris=ephemeris,
            observer=observers_with_states.observers,
        )
        # self._cache_ephemeris(test_orbit_ephemeris, observations)
        return test_orbit_ephemeris

    def range_observations(
        self,
        observations: Union[Observations, ray.ObjectRef],
        propagator_class: Type[Propagator],
        max_processes: Optional[int] = 1,
    ) -> RangedPointSourceDetections:
        """
        Range observations using this test orbit.

        The test orbit is propagated to the times of the observations. For each observation,
        the topocentric distance (range) is calculated assuming it lies at the same
        heliocentric distance as the test orbit.

        Parameters
        ----------
        observations : `~thor.observations.observations.Observations` or `ray.ObjectRef`
            Observations to range.
        propagator_class : type of `~adam_core.propagator.propagator.Propagator`
            Propagator class used to generate the test orbit's ephemeris.
        max_processes : int, optional
            Maximum number of processes to use. ``None`` means ``mp.cpu_count()``.
            Defaults to 1.

        Returns
        -------
        ranged_point_source_detections : `~thor.orbit.RangedPointSourceDetections`
            The ranged detections, sorted by state ID.
        """
        # Resolve an object reference once, so both the Ray and the serial path can read the
        # table directly. Keep the reference so the Ray path does not need to put the
        # observations into the object store a second time.
        observations_ref: Optional[ray.ObjectRef] = None
        if isinstance(observations, ray.ObjectRef):
            observations_ref = observations
            observations = ray.get(observations_ref)

        # Generate an ephemeris for each unique observation time and observatory
        # code combination
        ephemeris = self.generate_ephemeris_from_observations(
            observations, propagator_class=propagator_class, max_processes=max_processes
        )

        if max_processes is None:
            max_processes = mp.cpu_count()

        def _append(
            detections: RangedPointSourceDetections, chunk: RangedPointSourceDetections
        ) -> RangedPointSourceDetections:
            # Concatenate a chunk of ranged detections, defragmenting the result when needed.
            combined = qv.concatenate([detections, chunk])
            if combined.fragmented():
                combined = qv.defragment(combined)
            return combined

        ranged_detections = RangedPointSourceDetections.empty()
        state_ids = observations.state_id.unique()

        use_ray = initialize_use_ray(num_cpus=max_processes)
        if use_ray:
            if observations_ref is None:
                observations_ref = ray.put(observations)
            ephemeris_ref = ray.put(ephemeris)

            # Bound the number of in-flight tasks to about 1.5x max_processes, collecting
            # results one at a time as they complete.
            futures = []
            for state_id in state_ids:
                futures.append(range_observations_remote.remote(observations_ref, ephemeris_ref, state_id))

                if len(futures) >= max_processes * 1.5:
                    finished, futures = ray.wait(futures, num_returns=1)
                    ranged_detections = _append(ranged_detections, ray.get(finished[0]))

            while futures:
                finished, futures = ray.wait(futures, num_returns=1)
                ranged_detections = _append(ranged_detections, ray.get(finished[0]))

        else:
            for state_id in state_ids:
                ranged_detections_chunk = range_observations_worker(
                    observations.select("state_id", state_id),
                    ephemeris.select("id", state_id),
                    state_id,
                )
                ranged_detections = _append(ranged_detections, ranged_detections_chunk)

        # Ray returns chunks in completion order; sorting makes the output deterministic.
        return ranged_detections.sort_by(by=["state_id"])
405

406

407
def assume_heliocentric_distance(
    r: np.ndarray, coords: SphericalCoordinates, origin_coords: CartesianCoordinates
) -> SphericalCoordinates:
    """
    Fill in missing topocentric distances assuming a fixed heliocentric distance.

    For every coordinate without a topocentric distance (``rho`` is NaN), calculate the
    topocentric distance such that the coordinate lies at the heliocentric distance ``|r|``.
    Coordinates that already have ``rho`` keep their value.

    Parameters
    ----------
    r : `~numpy.ndarray` (3)
        Heliocentric position vector whose magnitude is the heliocentric distance assumed for
        every coordinate. It is dotted with ecliptic position vectors, so it is expected in the
        ecliptic frame. When ``|r|`` is less than the heliocentric distance of the origin,
        there are two candidate solutions. The larger root is kept unless its position vector
        has a negative dot product with ``r``; in that case the smaller root is used.
    coords : `~adam_core.coordinates.spherical.SphericalCoordinates`
        Coordinates to assume the heliocentric distance for.
    origin_coords : `~adam_core.coordinates.cartesian.CartesianCoordinates`
        Heliocentric coordinates of the single origin of the topocentric coordinates. They are
        dotted with ecliptic line-of-sight vectors, so they are expected in the ecliptic frame.

    Returns
    -------
    coords : `~adam_core.coordinates.spherical.SphericalCoordinates`
        Coordinates transformed to the ecliptic frame. Missing topocentric distances are
        replaced with the calculated ones, and ``vrho`` is carried over from the input.
        Where the line of sight never reaches ``|r|``, the computed distance is NaN.

    Raises
    ------
    AssertionError
        If ``origin_coords`` does not have exactly one row or its origin is not the Sun.
    """
    assert len(origin_coords) == 1
    assert np.all(origin_coords.origin == OriginCodes.SUN)

    r_mag = np.linalg.norm(r)

    # Extract the topocentric distance and topocentric radial velocity from the coordinates
    rho = coords.rho.to_numpy(zero_copy_only=False)
    vrho = coords.vrho.to_numpy(zero_copy_only=False)

    # Transform the coordinates to the ecliptic frame by assuming they lie on a unit sphere
    # (this assumption will only be applied to coordinates with missing rho values)
    coords_eq_unit = coords.to_unit_sphere(only_missing=True)
    coords_ec = transform_coordinates(coords_eq_unit, SphericalCoordinates, frame_out="ecliptic")

    # Transform the coordinates to cartesian and calculate the unit vectors pointing
    # from the origin to the coordinates
    coords_ec_xyz = coords_ec.to_cartesian()
    unit_vectors = coords_ec_xyz.r_hat

    # Solve |origin + rho * u| = r_mag for rho, where u is each unit vector. The quadratic
    # gives rho = -(u . origin) +/- sqrt((u . origin)^2 + r_mag^2 - |origin|^2). A negative
    # discriminant (the line of sight never reaches r_mag) makes np.sqrt return NaN.
    dotprod = np.sum(unit_vectors * origin_coords.r, axis=1)
    sqrt = np.sqrt(dotprod**2 + r_mag**2 - origin_coords.r_mag**2)
    delta_p = -dotprod + sqrt
    delta_n = -dotprod - sqrt

    # Where rho was not defined, replace it with the calculated topocentric distance
    # By default we take the positive solution which applies for all orbits exterior to the
    # observer's orbit
    coords_ec = coords_ec.set_column("rho", np.where(np.isnan(rho), delta_p, rho))

    # For cases where the orbit is interior to the observer's orbit there are two valid solutions
    # for the topocentric distance. In this case, we take the dot product of the heliocentric position
    # vector with the calculated topocentric position vector. If the dot product is positive, then
    # that solution is closest to the heliocentric position vector and we take that solution.
    if np.any(r_mag < origin_coords.r_mag):
        coords_ec_xyz_p = coords_ec.to_cartesian()
        dotprod_p = np.sum(coords_ec_xyz_p.r * r, axis=1)
        coords_ec = coords_ec.set_column(
            "rho",
            np.where(np.isnan(rho), np.where(dotprod_p < 0, delta_n, delta_p), rho),
        )

    # Carry the input topocentric radial velocity over unchanged
    coords_ec = coords_ec.set_column("vrho", vrho)

    return coords_ec
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc