• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

moeyensj / thor / 6733525095

02 Nov 2023 01:52PM UTC coverage: 40.544% (+0.9%) from 39.595%
6733525095

push

github

web-flow
Merge pull request #123 from moeyensj/v2.0-link-aims-sample

Link AIMS sample

301 of 301 new or added lines in 12 files covered. (100.0%)

1878 of 4632 relevant lines covered (40.54%)

0.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

75.95
/thor/observations/filters.py
1
import abc
1✔
2
import logging
1✔
3
import time
1✔
4
from typing import TYPE_CHECKING, Optional, Union
1✔
5

6
import numpy as np
1✔
7
import quivr as qv
1✔
8
import ray
1✔
9
from adam_core.observations import PointSourceDetections
1✔
10

11
from ..orbit import TestOrbit, TestOrbitEphemeris
1✔
12

13
if TYPE_CHECKING:
1✔
14
    from .observations import Observations
×
15

16

17
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
1✔
18

19

20
def TestOrbitRadiusObservationFilter_worker(
    observations: "Observations",
    ephemeris: TestOrbitEphemeris,
    state_id: int,
    radius: float,
) -> "Observations":
    """
    Keep the observations of one state that lie within ``radius`` of the
    test orbit's ephemeris position for that same state.

    Parameters
    ----------
    observations : `~thor.observations.Observations`
        The observations to filter. Only rows whose ``state_id`` matches
        the given state are considered.
    ephemeris : `~thor.orbit.TestOrbitEphemeris`
        The test orbit ephemeris. It must contain exactly one row whose
        ``id`` equals ``state_id``.
    state_id : int
        The state ID to process.
    radius : float
        The search radius in degrees.

    Returns
    -------
    filtered_observations : `~thor.observations.Observations`
        The observations of this state that fall inside the radius.

    Raises
    ------
    AssertionError
        If the number of ephemeris rows matching ``state_id`` is not one.
    """
    # Restrict both tables to the requested state.
    matching_ephemeris = ephemeris.select("id", state_id)
    state_observations = observations.select("state_id", state_id)
    state_detections = state_observations.detections

    assert (
        len(matching_ephemeris) == 1
    ), "there should be exactly one ephemeris per exposure"

    # The single ephemeris row provides the centre of the search cone.
    predicted = matching_ephemeris.ephemeris.coordinates
    center_ra = predicted.lon[0].as_py()
    center_dec = predicted.lat[0].as_py()

    # Build the boolean mask, then apply it to this state's observations.
    inside = _within_radius(state_detections, center_ra, center_dec, radius)
    return state_observations.apply_mask(inside)
61

62

63
# Ray remote wrapper around the per-state worker; it is invoked with
# ``.remote(...)`` when the filter runs with more than one process.
TestOrbitRadiusObservationFilter_remote = ray.remote(
    TestOrbitRadiusObservationFilter_worker
)
66

67

68
class ObservationFilter(abc.ABC):
    """Abstract interface for filters that reduce a collection of
    observations to a subset of those observations.

    Subclasses implement `apply`.
    """

    @abc.abstractmethod
    def apply(
        self,
        observations: "Observations",
        test_orbit: TestOrbit,
        max_processes: Optional[int] = 1,
    ) -> "Observations":
        """
        Reduce a collection of observations to the subset selected by
        this filter.

        Parameters
        ----------
        observations : `~thor.observations.Observations`
            The observations to filter.
        test_orbit : `~thor.orbit.TestOrbit`
            The test orbit to use for filtering.
        max_processes : int, optional
            Maximum number of processes to use for parallelization. If
            an existing ray cluster is already running, this parameter
            will be ignored if larger than 1 or not None.

        Returns
        -------
        filtered_observations : `~thor.observations.Observations`
            The filtered observations.
        """
        ...
×
101

102

103
class TestOrbitRadiusObservationFilter(ObservationFilter):
    """A TestOrbitRadiusObservationFilter is an ObservationFilter that
    gathers observations within a fixed radius of the test orbit's
    ephemeris at each exposure time within a collection of exposures.

    """

    def __init__(self, radius: float):
        """
        Parameters
        ----------
        radius : float
            The radius in degrees.
        """
        self.radius = radius

    def apply(
        self,
        observations: Union["Observations", ray.ObjectRef],
        test_orbit: TestOrbit,
        max_processes: Optional[int] = 1,
    ) -> "Observations":
        """
        Apply the filter to a collection of observations.

        An ephemeris is generated for the test orbit from the observations,
        then each unique ``state_id`` is filtered independently (serially,
        or as ray tasks when ``max_processes`` is None or greater than 1).
        The per-state results are concatenated and sorted by detection
        time and observatory code.

        Parameters
        ----------
        observations : `~thor.observations.Observations` or `ray.ObjectRef`
            The observations to filter, or a ray object reference to them.
            A reference is resolved with ``ray.get`` in both the serial
            and the parallel code path.
        test_orbit : `~thor.orbit.TestOrbit`
            The test orbit to use for filtering.
        max_processes : int, optional
            Maximum number of processes to use for parallelization. If
            an existing ray cluster is already running, this parameter
            will be ignored if larger than 1 or not None.

        Returns
        -------
        filtered_observations : `~thor.observations.Observations`
            The filtered observations. This will return a copy of the original
            observations.
        """
        time_start = time.perf_counter()
        logger.info("Applying TestOrbitRadiusObservationFilter...")
        logger.info(f"Using radius = {self.radius:.5f} deg")

        # Generate an ephemeris for every observer time/location in the dataset
        ephemeris = test_orbit.generate_ephemeris_from_observations(observations)

        filtered_observations_list = []
        if max_processes is None or max_processes > 1:

            if not ray.is_initialized():
                logger.debug(
                    f"Ray is not initialized. Initializing with {max_processes}..."
                )
                ray.init(num_cpus=max_processes)

            # Workers receive object references so the tables are placed in
            # the object store once rather than once per task. The local
            # copy is still needed to enumerate the state IDs.
            if isinstance(observations, ray.ObjectRef):
                observations_ref = observations
                observations = ray.get(observations_ref)
            else:
                observations_ref = ray.put(observations)

            if isinstance(ephemeris, ray.ObjectRef):
                ephemeris_ref = ephemeris
            else:
                ephemeris_ref = ray.put(ephemeris)

            state_ids = observations.state_id.unique().sort()
            futures = [
                TestOrbitRadiusObservationFilter_remote.remote(
                    observations_ref,
                    ephemeris_ref,
                    state_id,
                    self.radius,
                )
                for state_id in state_ids
            ]

            # Collect results as soon as each task finishes.
            while futures:
                finished, futures = ray.wait(futures, num_returns=1)
                filtered_observations_list.append(ray.get(finished[0]))

        else:

            # The signature accepts an ObjectRef regardless of
            # max_processes, so the serial path must resolve references
            # before touching table attributes such as ``state_id``.
            if isinstance(observations, ray.ObjectRef):
                observations = ray.get(observations)
            if isinstance(ephemeris, ray.ObjectRef):
                ephemeris = ray.get(ephemeris)

            state_ids = observations.state_id.unique().sort()
            for state_id in state_ids:
                filtered_observations = TestOrbitRadiusObservationFilter_worker(
                    observations,
                    ephemeris,
                    state_id,
                    self.radius,
                )
                filtered_observations_list.append(filtered_observations)

        observations_filtered = qv.concatenate(filtered_observations_list)
        observations_filtered = observations_filtered.sort_by(
            ["detections.time.days", "detections.time.nanos", "observatory_code"]
        )

        time_end = time.perf_counter()
        logger.info(
            f"Filtered {len(observations)} observations to {len(observations_filtered)} observations."
        )
        logger.info(
            f"TestOrbitRadiusObservationFilter completed in {time_end - time_start:.3f} seconds."
        )
        return observations_filtered
1✔
213

214

215
def _within_radius(
1✔
216
    detections: PointSourceDetections,
217
    ra: float,
218
    dec: float,
219
    radius: float,
220
) -> np.array:
221
    """
222
    Return a boolean mask that identifies which of
223
    the detections are within a given radius of a given ra and dec.
224

225
    Parameters
226
    ----------
227
    detections : `~adam_core.observations.detections.PointSourceDetections`
228
        The detections to filter.
229
    ra : float
230
        The right ascension of the center of the radius in degrees.
231
    dec : float
232
        The declination of the center of the radius in degrees.
233
    radius : float
234
        The radius in degrees.
235

236
    Returns
237
    -------
238
    mask : `~numpy.ndarray`
239
        A boolean mask that identifies which of the detections are within
240
        the radius.
241
    """
242
    det_ra = np.deg2rad(detections.ra.to_numpy())
1✔
243
    det_dec = np.deg2rad(detections.dec.to_numpy())
1✔
244

245
    center_ra = np.deg2rad(ra)
1✔
246
    center_dec = np.deg2rad(dec)
1✔
247

248
    dist_lon = det_ra - center_ra
1✔
249
    sin_dist_lon = np.sin(dist_lon)
1✔
250
    cos_dist_lon = np.cos(dist_lon)
1✔
251

252
    sin_center_lat = np.sin(center_dec)
1✔
253
    sin_det_lat = np.sin(det_dec)
1✔
254
    cos_center_lat = np.cos(center_dec)
1✔
255
    cos_det_lat = np.cos(det_dec)
1✔
256

257
    num1 = cos_det_lat * sin_dist_lon
1✔
258
    num2 = cos_center_lat * sin_det_lat - sin_center_lat * cos_det_lat * cos_dist_lon
1✔
259
    denominator = (
1✔
260
        sin_center_lat * sin_det_lat + cos_center_lat * cos_det_lat * cos_dist_lon
261
    )
262

263
    distances = np.arctan2(np.hypot(num1, num2), denominator)
1✔
264
    return distances <= np.deg2rad(radius)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc