• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

duncaneddy / astrojax / 21806998846

08 Feb 2026 11:01PM UTC coverage: 97.432%. First build
21806998846

push

github

duncaneddy
Add precommit hooks

137 of 140 new or added lines in 24 files covered. (97.86%)

2201 of 2259 relevant lines covered (97.43%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.47
/src/astrojax/orbit_dynamics/gravity.py
1
"""Gravity force models: point-mass and spherical harmonics.
2

3
Provides gravitational acceleration due to point-mass central bodies and
4
spherical harmonic gravity field models (e.g. EGM2008, GGM05S, JGM3).
5
All inputs and outputs use SI base units (metres, metres/second squared).
6

7
References:
8
    1. O. Montenbruck and E. Gill, *Satellite Orbits: Models, Methods
9
       and Applications*, 2012, p. 56-68.
10
"""
11

12
from __future__ import annotations
1✔
13

14
import importlib.resources
1✔
15
import math
1✔
16
from pathlib import Path
1✔
17

18
import jax.numpy as jnp
1✔
19
import numpy as np
1✔
20
from jax import Array
1✔
21
from jax.typing import ArrayLike
1✔
22

23
from astrojax.config import get_dtype
1✔
24
from astrojax.constants import GM_EARTH
1✔
25

26
# ---------------------------------------------------------------------------
27
# Point-mass gravity
28
# ---------------------------------------------------------------------------
29

30

31
def accel_point_mass(
    r_object: ArrayLike,
    r_body: ArrayLike,
    gm: float,
) -> Array:
    """Acceleration due to point-mass gravity.

    Computes the gravitational acceleration on *r_object* due to a body
    at *r_body* with gravitational parameter *gm*.  When the central body
    is at the origin (``r_body = [0, 0, 0]``), the standard two-body
    expression ``-gm * r / |r|^3`` is used.  Otherwise the indirect
    (third-body) form is applied.

    Both forms are evaluated and one is selected with ``jnp.where`` so the
    function stays traceable under ``jit``/``vmap``; the third-body branch
    is computed with a guarded norm so it never produces NaN (and hence
    never leaks NaN into gradients) when ``r_body`` is the origin.

    Args:
        r_object: Position of the object [m].  Shape ``(3,)`` or ``(6,)``
            (only first 3 elements used).
        r_body: Position of the attracting body [m].  Shape ``(3,)``.
        gm: Gravitational parameter of the attracting body [m^3/s^2].

    Returns:
        Acceleration vector [m/s^2], shape ``(3,)``.

    Examples:
        ```python
        import jax.numpy as jnp
        from astrojax.constants import R_EARTH, GM_EARTH
        from astrojax.orbit_dynamics import accel_point_mass
        r = jnp.array([R_EARTH, 0.0, 0.0])
        a = accel_point_mass(r, jnp.zeros(3), GM_EARTH)
        ```
    """
    dtype = get_dtype()
    r_obj = jnp.asarray(r_object, dtype=dtype)[:3]
    r_cb = jnp.asarray(r_body, dtype=dtype)

    d = r_obj - r_cb
    d_norm = jnp.linalg.norm(d)

    # |r_body|^2 == 0 exactly when the attracting body sits at the origin.
    r_cb_sqr = jnp.sum(r_cb * r_cb)
    is_third_body = r_cb_sqr > 0.0

    # "Double where": jnp.where evaluates BOTH branches, so the third-body
    # term must stay finite even when r_body = 0.  Substituting 1.0 for
    # |r_body|^2 in that case avoids computing 0/0 = NaN, which would trip
    # jax_debug_nans and poison gradients w.r.t. r_body at the origin.
    safe_r_cb_norm = jnp.sqrt(jnp.where(is_third_body, r_cb_sqr, 1.0))

    # Third-body form (r_body != 0): -gm * (d/|d|^3 + r_body/|r_body|^3)
    # Central-body form (r_body = 0): -gm * d/|d|^3
    a_third = -gm * (d / d_norm**3 + r_cb / safe_r_cb_norm**3)
    a_central = -gm * d / d_norm**3

    return jnp.where(is_third_body, a_third, a_central)
76

77

78
def accel_gravity(r_object: ArrayLike) -> Array:
    """Point-mass gravitational acceleration of Earth acting on *r_object*.

    Delegates to :func:`accel_point_mass` with the attracting body fixed
    at the ECI origin and the gravitational parameter set to
    ``GM_EARTH``, i.e. the central-body (two-body) form of that function.

    Args:
        r_object: Position of the object in ECI [m].  Shape ``(3,)`` or
            ``(6,)`` (only first 3 elements used).

    Returns:
        Acceleration vector [m/s^2], shape ``(3,)``.

    Examples:
        ```python
        import jax.numpy as jnp
        from astrojax.constants import R_EARTH
        from astrojax.orbit_dynamics import accel_gravity
        r = jnp.array([R_EARTH, 0.0, 0.0])
        a = accel_gravity(r)
        ```
    """
    # Earth's centre coincides with the ECI origin.
    earth_position = jnp.zeros(3, dtype=get_dtype())
    return accel_point_mass(r_object, earth_position, GM_EARTH)
101

102

103
# ---------------------------------------------------------------------------
104
# Gravity model data types
105
# ---------------------------------------------------------------------------
106

107
_PACKAGED_MODELS = {
1✔
108
    "EGM2008_360": "EGM2008_360.gfc",
109
    "GGM05S": "GGM05S.gfc",
110
    "JGM3": "JGM3.gfc",
111
}
112

113

114
class GravityModel:
    """Spherical harmonic gravity field model.

    Stores Stokes coefficients (C_nm, S_nm) parsed from ICGEM GFC format
    files.  The coefficient matrix layout follows the Montenbruck & Gill
    convention:

    - ``data[n, m]`` stores the C coefficient for degree *n*, order *m*
    - ``data[m-1, n]`` stores the S coefficient for *m* > 0

    C values therefore occupy the lower triangle (including the diagonal)
    and S values the strict upper triangle.  Because S_nm lives in column
    *n*, ``data`` needs at least ``n_max + 1`` columns; models built by the
    GFC parser use a square ``(n_max + 1, n_max + 1)`` matrix.

    This is a plain Python class (not a JAX pytree) since it holds static
    configuration data that does not participate in differentiation.

    Args:
        model_name: Human-readable name of the gravity model.
        gm: Gravitational parameter [m^3/s^2].
        radius: Reference radius [m].
        n_max: Maximum degree of the model.
        m_max: Maximum order of the model.
        data: Coefficient matrix in the layout above, shape
            ``(n_max+1, n_max+1)``.
        tide_system: Tide system convention (e.g. ``"tide_free"``).
        normalization: Normalization convention (e.g. ``"fully_normalized"``).

    Examples:
        ```python
        from astrojax.orbit_dynamics.gravity import GravityModel
        model = GravityModel.from_type("JGM3")
        c20, s20 = model.get(2, 0)
        ```
    """

    def __init__(
        self,
        model_name: str,
        gm: float,
        radius: float,
        n_max: int,
        m_max: int,
        data: np.ndarray,
        tide_system: str = "unknown",
        normalization: str = "fully_normalized",
    ):
        self.model_name = model_name
        self.gm = gm
        self.radius = radius
        self.n_max = n_max
        self.m_max = m_max
        self.data = data
        self.tide_system = tide_system
        self.normalization = normalization

    @property
    def is_normalized(self) -> bool:
        """Whether the coefficients are fully normalized."""
        return self.normalization == "fully_normalized"

    # ------------------------------------------------------------------
    # Factory methods
    # ------------------------------------------------------------------

    @classmethod
    def from_file(cls, filepath: str | Path) -> GravityModel:
        """Load a gravity model from a GFC format file.

        Args:
            filepath: Path to the ``.gfc`` file.

        Returns:
            GravityModel: Loaded gravity model.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If required header fields are missing or a
                coefficient line is malformed.
        """
        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"Gravity model file not found: {filepath}")
        # Decode explicitly instead of relying on the platform's locale
        # encoding.  Undecodable bytes (e.g. in free-text header comments)
        # are replaced rather than aborting the load.
        with open(filepath, encoding="utf-8", errors="replace") as f:
            return cls._parse_gfc(f)

    @classmethod
    def from_type(cls, model_type: str) -> GravityModel:
        """Load a packaged gravity model by name.

        Available models:

        - ``"EGM2008_360"`` — truncated 360x360 EGM2008
        - ``"GGM05S"`` — full 180x180 GGM05S
        - ``"JGM3"`` — full 70x70 JGM3

        Args:
            model_type: One of the packaged model names.

        Returns:
            GravityModel: Loaded gravity model.

        Raises:
            ValueError: If the model type is not recognized.
        """
        if model_type not in _PACKAGED_MODELS:
            raise ValueError(
                f"Unknown gravity model type: {model_type!r}. "
                f"Available: {list(_PACKAGED_MODELS.keys())}"
            )
        filename = _PACKAGED_MODELS[model_type]
        data_pkg = importlib.resources.files("astrojax.data.gravity_models")
        resource = data_pkg.joinpath(filename)
        # as_file yields a real filesystem path even for zipped packages.
        with importlib.resources.as_file(resource) as path:
            return cls.from_file(path)

    # ------------------------------------------------------------------
    # Coefficient access
    # ------------------------------------------------------------------

    def get(self, n: int, m: int) -> tuple[float, float]:
        """Retrieve the (C_nm, S_nm) coefficients for degree *n*, order *m*.

        Args:
            n: Degree of the harmonic.
            m: Order of the harmonic (``0 <= m <= n``).

        Returns:
            tuple[float, float]: (C_nm, S_nm) coefficient pair.  S_n0 is
            always 0.0.

        Raises:
            ValueError: If (n, m) is not a valid degree/order pair or
                exceeds the model bounds.
        """
        # Without this check, m > n or negative indices would silently read
        # another coefficient's slot (or wrap around) in the packed matrix.
        if n < 0 or m < 0 or m > n:
            raise ValueError(
                f"Invalid degree/order (n={n}, m={m}): require 0 <= m <= n."
            )
        if n > self.n_max or m > self.m_max:
            raise ValueError(
                f"Requested (n={n}, m={m}) exceeds model bounds "
                f"(n_max={self.n_max}, m_max={self.m_max})."
            )
        if m == 0:
            # Zonal terms have no S coefficient.
            return float(self.data[n, 0]), 0.0
        return float(self.data[n, m]), float(self.data[m - 1, n])

    # ------------------------------------------------------------------
    # Model truncation
    # ------------------------------------------------------------------

    def set_max_degree_order(self, n: int, m: int) -> None:
        """Truncate the model to a smaller degree and order.

        Coefficients of degree above *n* are discarded.  The matrix is
        cropped to ``(n+1, n+1)`` so the S coefficients stored in column *n*
        survive.  Coefficients of order above *m* (but degree <= *n*) remain
        in ``data`` but become inaccessible through :meth:`get`.  This is
        irreversible.

        Args:
            n: New maximum degree (must be >= 0 and <= current ``n_max``).
            m: New maximum order (must be >= 0, <= *n* and <= current
                ``m_max``).

        Raises:
            ValueError: If validation fails.
        """
        if n < 0 or m < 0:
            raise ValueError(
                f"Degree and order must be non-negative (got n={n}, m={m})."
            )
        if m > n:
            raise ValueError(f"Maximum order (m={m}) cannot exceed maximum degree (n={n}).")
        if n > self.n_max:
            raise ValueError(f"Requested degree (n={n}) exceeds model's n_max={self.n_max}.")
        if m > self.m_max:
            raise ValueError(f"Requested order (m={m}) exceeds model's m_max={self.m_max}.")
        if n == self.n_max and m == self.m_max:
            return  # Nothing to truncate.

        # Square crop: S_nm lives at data[m-1, n], i.e. in column n.
        new_size = n + 1
        self.data = self.data[:new_size, :new_size].copy()
        self.n_max = n
        self.m_max = m

    # ------------------------------------------------------------------
    # GFC parser
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_float(token: str) -> float:
        """Parse a float that may use Fortran ``D``/``d`` exponent notation."""
        return float(token.replace("D", "e").replace("d", "e"))

    @classmethod
    def _parse_gfc(cls, fileobj) -> GravityModel:
        """Parse an ICGEM GFC format file.

        The header (everything up to the ``end_of_head`` line) is scanned
        for the ``modelname``, ``earth_gravity_constant``, ``radius``,
        ``max_degree``, ``tide_system`` and ``norm``/``normalization`` keys.
        For each key, the last whitespace-separated token on the line is
        taken as the value.  Every following line starting with ``gfc`` is
        read as ``gfc n m C S [sigma_C sigma_S ...]``.  Coefficients with
        degree above ``max_degree`` are ignored.

        Args:
            fileobj: Iterable of text lines (e.g. an open file) with GFC
                content.

        Returns:
            GravityModel: Parsed gravity model.

        Raises:
            ValueError: If the header is incomplete or a coefficient line is
                malformed.
        """
        model_name = "Unknown"
        gm = 0.0
        radius = 0.0
        n_max = 0
        tide_system = "unknown"
        normalization = "fully_normalized"

        # Read the header.  ``lines`` is a single iterator shared with the
        # coefficient loop below, so that loop resumes after ``end_of_head``.
        in_header = True
        lines = iter(fileobj)
        for line in lines:
            line = line.strip()
            if line.startswith("end_of_head"):
                in_header = False
                break

            parts = line.split()
            if len(parts) < 2:
                continue

            key = parts[0].lower()
            value = parts[-1]

            if key == "modelname":
                model_name = value
            elif key == "earth_gravity_constant":
                gm = cls._parse_float(value)
            elif key == "radius":
                radius = cls._parse_float(value)
            elif key == "max_degree":
                n_max = int(value)
            elif key == "tide_system":
                tide_system = value
            elif key in ("norm", "normalization"):
                normalization = value

        if in_header:
            raise ValueError("GFC file missing 'end_of_head' marker.")
        if gm == 0.0:
            raise ValueError("GFC header missing 'earth_gravity_constant'.")
        if radius == 0.0:
            raise ValueError("GFC header missing 'radius'.")
        if n_max == 0:
            raise ValueError("GFC header missing 'max_degree'.")

        # Full model: every order up to max_degree is stored, and the packed
        # C/S layout needs a square matrix.
        m_max = n_max
        data = np.zeros((n_max + 1, m_max + 1), dtype=np.float64)

        for line in lines:
            line = line.strip()
            if not line.startswith("gfc"):
                continue

            # gfc  n  m  C  S  [sig_C  sig_S]
            parts = line.split()
            if len(parts) < 5:
                raise ValueError(f"Malformed GFC coefficient line: {line!r}")
            n = int(parts[1])
            m = int(parts[2])
            if n < 0 or m < 0 or m > n:
                # Would otherwise overwrite an unrelated slot of the packed
                # C/S matrix (or wrap around with negative indices).
                raise ValueError(f"Invalid degree/order in GFC line: {line!r}")
            c = cls._parse_float(parts[3])
            s = cls._parse_float(parts[4])

            if n <= n_max and m <= m_max:
                data[n, m] = c
                if m > 0:
                    data[m - 1, n] = s

        return cls(
            model_name=model_name,
            gm=gm,
            radius=radius,
            n_max=n_max,
            m_max=m_max,
            data=data,
            tide_system=tide_system,
            normalization=normalization,
        )

    def __repr__(self) -> str:
        return (
            f"GravityModel(name={self.model_name!r}, "
            f"n_max={self.n_max}, m_max={self.m_max}, "
            f"gm={self.gm:.6e}, radius={self.radius:.1f})"
        )
384

385

386
# ---------------------------------------------------------------------------
387
# Spherical harmonic gravity acceleration
388
# ---------------------------------------------------------------------------
389

390

391
def _factorial_product(n: int, m: int) -> float:
1✔
392
    """Compute (n-m)!/(n+m)! efficiently without full factorials.
393

394
    Args:
395
        n: Degree.
396
        m: Order.
397

398
    Returns:
399
        float: The factorial ratio.
400
    """
401
    p = 1.0
1✔
402
    for i in range(n - m + 1, n + m + 1):
1✔
403
        p /= i
1✔
404
    return p
1✔
405

406

407
def accel_gravity_spherical_harmonics(
    r_eci: ArrayLike,
    R_eci_to_ecef: ArrayLike,
    gravity_model: GravityModel,
    n_max: int,
    m_max: int,
) -> Array:
    """Acceleration from spherical harmonic gravity field expansion.

    Computes the gravitational acceleration using recursively-computed
    associated Legendre functions (V/W matrix method).  The position is
    transformed to the body-fixed frame, the acceleration is computed
    there, and transformed back to ECI.

    Args:
        r_eci: Position in ECI frame [m].  Shape ``(3,)`` or ``(6,)``
            (only first 3 elements used).
        R_eci_to_ecef: Rotation matrix from ECI to ECEF, shape ``(3, 3)``.
        gravity_model: Loaded gravity model with Stokes coefficients.
        n_max: Maximum degree for evaluation (must be <= model's n_max).
        m_max: Maximum order for evaluation (must be <= n_max and
            <= model's m_max).

    Returns:
        Acceleration in ECI frame [m/s^2], shape ``(3,)``.

    Raises:
        ValueError: If ``n_max``/``m_max`` are negative, ``m_max > n_max``,
            or either exceeds the corresponding bound of *gravity_model*.

    Examples:
        ```python
        import jax.numpy as jnp
        from astrojax.orbit_dynamics.gravity import (
            GravityModel, accel_gravity_spherical_harmonics,
        )
        model = GravityModel.from_type("JGM3")
        r = jnp.array([6878e3, 0.0, 0.0])
        R = jnp.eye(3)
        a = accel_gravity_spherical_harmonics(r, R, model, 20, 20)
        ```
    """
    # n_max/m_max are static Python ints that drive Python-level loops, so
    # they can be validated eagerly (also under jit).  Out-of-range values
    # would index outside the coefficient / V-W matrices.  JAX's
    # out-of-bounds semantics (clamped gathers, dropped scatters) could then
    # yield silently wrong accelerations instead of an error.
    if n_max < 0 or m_max < 0:
        raise ValueError(
            f"Degree and order must be non-negative (got n_max={n_max}, m_max={m_max})."
        )
    if m_max > n_max:
        raise ValueError(
            f"Maximum order (m={m_max}) cannot exceed maximum degree (n={n_max})."
        )
    if n_max > gravity_model.n_max:
        raise ValueError(
            f"Requested degree (n={n_max}) exceeds model's n_max={gravity_model.n_max}."
        )
    if m_max > gravity_model.m_max:
        raise ValueError(
            f"Requested order (m={m_max}) exceeds model's m_max={gravity_model.m_max}."
        )

    _float = get_dtype()
    r = jnp.asarray(r_eci, dtype=_float)[:3]
    R = jnp.asarray(R_eci_to_ecef, dtype=_float)

    # Transform to body-fixed frame
    r_bf = R @ r

    # Convert model data to JAX array
    CS = jnp.asarray(gravity_model.data, dtype=_float)

    # Compute acceleration in body-fixed frame
    a_bf = _compute_spherical_harmonics(
        r_bf,
        CS,
        n_max,
        m_max,
        gravity_model.radius,
        gravity_model.gm,
        gravity_model.is_normalized,
    )

    # Transform back to ECI (R is a rotation, so its inverse is R.T)
    return R.T @ a_bf
468

469

470
def _compute_spherical_harmonics(
1✔
471
    r_bf: Array,
472
    CS: Array,
473
    n_max: int,
474
    m_max: int,
475
    r_ref: float,
476
    gm: float,
477
    is_normalized: bool,
478
) -> Array:
479
    """Core V/W recursion for spherical harmonic gravity.
480

481
    Implements the algorithm from Montenbruck & Gill (2012), p. 56-68.
482
    Uses Python loops (traced by JAX) rather than ``jax.lax.fori_loop``
483
    for clarity.  The ``n_max`` and ``m_max`` parameters are static
484
    Python ints, so changing them triggers recompilation.
485

486
    Args:
487
        r_bf: Position in body-fixed frame [m], shape ``(3,)``.
488
        CS: Coefficient matrix from GravityModel, shape ``(N+1, M+1)``.
489
        n_max: Maximum degree for evaluation.
490
        m_max: Maximum order for evaluation.
491
        r_ref: Reference radius [m].
492
        gm: Gravitational parameter [m^3/s^2].
493
        is_normalized: Whether coefficients are fully normalized.
494

495
    Returns:
496
        Acceleration in body-fixed frame [m/s^2], shape ``(3,)``.
497
    """
498
    # Auxiliary quantities
499
    r_sqr = jnp.dot(r_bf, r_bf)
1✔
500
    rho = r_ref * r_ref / r_sqr
1✔
501

502
    # Normalized coordinates
503
    x0 = r_ref * r_bf[0] / r_sqr
1✔
504
    y0 = r_ref * r_bf[1] / r_sqr
1✔
505
    z0 = r_ref * r_bf[2] / r_sqr
1✔
506

507
    # V and W intermediary matrices
508
    size = n_max + 2
1✔
509
    V = jnp.zeros((size, size), dtype=r_bf.dtype)
1✔
510
    W = jnp.zeros((size, size), dtype=r_bf.dtype)
1✔
511

512
    # Zonal terms V(n,0); W(n,0) = 0
513
    V = V.at[0, 0].set(r_ref / jnp.sqrt(r_sqr))
1✔
514
    V = V.at[1, 0].set(z0 * V[0, 0])
1✔
515

516
    for n in range(2, n_max + 2):
1✔
517
        nf = float(n)
1✔
518
        V = V.at[n, 0].set(
1✔
519
            ((2.0 * nf - 1.0) * z0 * V[n - 1, 0] - (nf - 1.0) * rho * V[n - 2, 0]) / nf
520
        )
521

522
    # Tesseral and sectorial terms
523
    for m in range(1, m_max + 2):
1✔
524
        mf = float(m)
1✔
525
        V = V.at[m, m].set((2.0 * mf - 1.0) * (x0 * V[m - 1, m - 1] - y0 * W[m - 1, m - 1]))
1✔
526
        W = W.at[m, m].set((2.0 * mf - 1.0) * (x0 * W[m - 1, m - 1] + y0 * V[m - 1, m - 1]))
1✔
527

528
        if m <= n_max:
1✔
529
            V = V.at[m + 1, m].set((2.0 * mf + 1.0) * z0 * V[m, m])
1✔
530
            W = W.at[m + 1, m].set((2.0 * mf + 1.0) * z0 * W[m, m])
1✔
531

532
        for n in range(m + 2, n_max + 2):
1✔
533
            nf = float(n)
1✔
534
            V = V.at[n, m].set(
1✔
535
                ((2.0 * nf - 1.0) * z0 * V[n - 1, m] - (nf + mf - 1.0) * rho * V[n - 2, m])
536
                / (nf - mf)
537
            )
538
            W = W.at[n, m].set(
1✔
539
                ((2.0 * nf - 1.0) * z0 * W[n - 1, m] - (nf + mf - 1.0) * rho * W[n - 2, m])
540
                / (nf - mf)
541
            )
542

543
    # Accumulate accelerations
544
    ax = jnp.float64(0.0) if r_bf.dtype == jnp.float64 else jnp.float32(0.0)
1✔
545
    ay = ax
1✔
546
    az = ax
1✔
547

548
    for m in range(m_max + 1):
1✔
549
        mf = float(m)
1✔
550
        for n in range(m, n_max + 1):
1✔
551
            nf = float(n)
1✔
552
            if m == 0:
1✔
553
                # Denormalize if needed
554
                if is_normalized:
1✔
555
                    N = math.sqrt(2.0 * nf + 1.0)
1✔
556
                    C = N * CS[n, 0]
1✔
557
                else:
558
                    C = CS[n, 0]
×
559

560
                ax = ax - C * V[n + 1, 1]
1✔
561
                ay = ay - C * W[n + 1, 1]
1✔
562
                az = az - (nf + 1.0) * C * V[n + 1, 0]
1✔
563
            else:
564
                # Denormalize if needed
565
                if is_normalized:
1✔
566
                    kron = 0.0 if m != 0 else 1.0
1✔
567
                    N = math.sqrt((2.0 - kron) * (2.0 * nf + 1.0) * _factorial_product(n, m))
1✔
568
                    C = N * CS[n, m]
1✔
569
                    S = N * CS[m - 1, n]
1✔
570
                else:
571
                    C = CS[n, m]
×
572
                    S = CS[m - 1, n]
×
573

574
                Fac = 0.5 * (nf - mf + 1.0) * (nf - mf + 2.0)
1✔
575
                ax = ax + (
1✔
576
                    0.5 * (-C * V[n + 1, m + 1] - S * W[n + 1, m + 1])
577
                    + Fac * (C * V[n + 1, m - 1] + S * W[n + 1, m - 1])
578
                )
579
                ay = ay + (
1✔
580
                    0.5 * (-C * W[n + 1, m + 1] + S * V[n + 1, m + 1])
581
                    + Fac * (-C * W[n + 1, m - 1] + S * V[n + 1, m - 1])
582
                )
583
                az = az + (nf - mf + 1.0) * (-C * V[n + 1, m] - S * W[n + 1, m])
1✔
584

585
    # Scale by GM/R_ref^2
586
    scale = gm / (r_ref * r_ref)
1✔
587
    return scale * jnp.array([ax, ay, az])
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc