• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

duncaneddy / astrojax / 22021608966

14 Feb 2026 05:43PM UTC coverage: 95.439% (-2.3%) from 97.722%
22021608966

push

github

duncaneddy
Update logo, test additional platforms

5085 of 5328 relevant lines covered (95.44%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.5
/src/astrojax/orbit_dynamics/gravity.py
1
"""Gravity force models: point-mass and spherical harmonics.
2

3
Provides gravitational acceleration due to point-mass central bodies and
4
spherical harmonic gravity field models (e.g. EGM2008, GGM05S, JGM3).
5
All inputs and outputs use SI base units (metres, metres/second squared).
6

7
References:
8
    1. O. Montenbruck and E. Gill, *Satellite Orbits: Models, Methods
9
       and Applications*, 2012, p. 56-68.
10
"""
11

12
from __future__ import annotations
1✔
13

14
import importlib.resources
1✔
15
import math
1✔
16
from pathlib import Path
1✔
17

18
import jax
1✔
19
import jax.numpy as jnp
1✔
20
import numpy as np
1✔
21
from jax import Array
1✔
22
from jax.typing import ArrayLike
1✔
23

24
from astrojax.config import get_dtype
1✔
25
from astrojax.constants import GM_EARTH
1✔
26

27
# ---------------------------------------------------------------------------
28
# Point-mass gravity
29
# ---------------------------------------------------------------------------
30

31

32
def accel_point_mass(
    r_object: ArrayLike,
    r_body: ArrayLike,
    gm: float,
) -> Array:
    """Acceleration due to point-mass gravity.

    Computes the gravitational acceleration on *r_object* due to a body
    at *r_body* with gravitational parameter *gm*.  When the central body
    is at the origin (``r_body = [0, 0, 0]``), the standard two-body
    expression ``-gm * r / |r|^3`` is used.  Otherwise the indirect
    (third-body) form is applied.

    The indirect term ``r_body / |r_body|^3`` is evaluated with a guarded
    denominator, so the central-body case never forms ``0 / 0``
    internally.  This keeps the function usable with ``jax_debug_nans``
    enabled and stops NaNs from an unselected branch from leaking into
    derivatives.

    Args:
        r_object: Position of the object [m].  Shape ``(3,)`` or ``(6,)``
            (only first 3 elements used).
        r_body: Position of the attracting body [m].  Shape ``(3,)``.
        gm: Gravitational parameter of the attracting body [m^3/s^2].

    Returns:
        Acceleration vector [m/s^2], shape ``(3,)``.

    Examples:
        ```python
        import jax.numpy as jnp
        from astrojax.constants import R_EARTH, GM_EARTH
        from astrojax.orbit_dynamics import accel_point_mass
        r = jnp.array([R_EARTH, 0.0, 0.0])
        a = accel_point_mass(r, jnp.zeros(3), GM_EARTH)
        ```
    """
    _float = get_dtype()
    r_obj = jnp.asarray(r_object, dtype=_float)[:3]
    r_cb = jnp.asarray(r_body, dtype=_float)

    d = r_obj - r_cb
    d_norm = jnp.linalg.norm(d)
    r_cb_norm = jnp.linalg.norm(r_cb)

    # Direct term, present in both forms: -gm * d / |d|^3.
    a_direct = -gm * d / d_norm**3

    # Indirect (third-body) term: -gm * r_body / |r_body|^3.  When the body
    # sits at the origin, substitute a harmless denominator so 0/0 is never
    # evaluated, then zero the term exactly with the outer ``where``.
    has_offset = r_cb_norm > _float(0.0)
    safe_norm = jnp.where(has_offset, r_cb_norm, _float(1.0))
    a_indirect = jnp.where(has_offset, -gm * r_cb / safe_norm**3, _float(0.0))

    return a_direct + a_indirect
1✔
77

78

79
def accel_gravity(r_object: ArrayLike) -> Array:
    """Acceleration due to Earth's point-mass gravity.

    Thin wrapper around :func:`accel_point_mass`.  It places the attracting
    body at the origin and uses ``GM_EARTH`` as the gravitational
    parameter, so the plain two-body expression applies.

    Args:
        r_object: ECI position of the object [m].  Shape ``(3,)`` or
            ``(6,)``; only the first three entries are read.

    Returns:
        Acceleration vector [m/s^2], shape ``(3,)``.

    Examples:
        ```python
        import jax.numpy as jnp
        from astrojax.constants import R_EARTH
        from astrojax.orbit_dynamics import accel_gravity
        r = jnp.array([R_EARTH, 0.0, 0.0])
        a = accel_gravity(r)
        ```
    """
    earth_center = jnp.zeros(3, dtype=get_dtype())
    return accel_point_mass(r_object, earth_center, GM_EARTH)
1✔
102

103

104
# ---------------------------------------------------------------------------
105
# Gravity model data types
106
# ---------------------------------------------------------------------------
107

108
_PACKAGED_MODELS = {
1✔
109
    "EGM2008_360": "EGM2008_360.gfc",
110
    "GGM05S": "GGM05S.gfc",
111
    "JGM3": "JGM3.gfc",
112
}
113

114

115
class GravityModel:
    """Spherical harmonic gravity field model.

    Stores Stokes coefficients (C_nm, S_nm) parsed from ICGEM GFC format
    files.  The coefficient matrix layout follows the Montenbruck & Gill
    convention:

    - ``data[n, m]`` stores the C coefficient for degree *n*, order *m*
      (lower triangle, ``m <= n``)
    - ``data[m-1, n]`` stores the S coefficient for *m* > 0 (strictly
      upper triangle)

    Because S_nm is stored in column *n*, ``data`` needs ``n_max + 1``
    columns even when ``m_max < n_max``.  The GFC parser and
    :meth:`set_max_degree_order` both keep it square.

    This is a plain Python class (not a JAX pytree) since it holds static
    configuration data that does not participate in differentiation.

    Args:
        model_name: Human-readable name of the gravity model.
        gm: Gravitational parameter [m^3/s^2].
        radius: Reference radius [m].
        n_max: Maximum degree of the model.
        m_max: Maximum order of the model.
        data: Square coefficient matrix, shape ``(n_max+1, n_max+1)``,
            in the layout described above.
        tide_system: Tide system convention (e.g. ``"tide_free"``).
        normalization: Normalization convention (e.g. ``"fully_normalized"``).

    Examples:
        ```python
        from astrojax.orbit_dynamics.gravity import GravityModel
        model = GravityModel.from_type("JGM3")
        c20, s20 = model.get(2, 0)
        ```
    """

    def __init__(
        self,
        model_name: str,
        gm: float,
        radius: float,
        n_max: int,
        m_max: int,
        data: np.ndarray,
        tide_system: str = "unknown",
        normalization: str = "fully_normalized",
    ):
        self.model_name = model_name
        self.gm = gm
        self.radius = radius
        self.n_max = n_max
        self.m_max = m_max
        self.data = data
        self.tide_system = tide_system
        self.normalization = normalization

    @property
    def is_normalized(self) -> bool:
        """Whether the coefficients are fully normalized."""
        return self.normalization == "fully_normalized"

    # ------------------------------------------------------------------
    # Factory methods
    # ------------------------------------------------------------------

    @classmethod
    def from_file(cls, filepath: str | Path) -> GravityModel:
        """Load a gravity model from a GFC format file.

        The file is decoded as UTF-8, with undecodable bytes replaced.
        Non-UTF-8 bytes in header text therefore cannot abort the load,
        regardless of the platform's default encoding.

        Args:
            filepath: Path to the ``.gfc`` file.

        Returns:
            GravityModel: Loaded gravity model.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If required header fields are missing or a
                coefficient line is malformed.
        """
        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"Gravity model file not found: {filepath}")
        # Explicit encoding: ``open``'s default is locale-dependent (e.g.
        # cp1252 on Windows), which would make parsing platform-specific.
        with open(filepath, encoding="utf-8", errors="replace") as f:
            return cls._parse_gfc(f)

    @classmethod
    def from_type(cls, model_type: str) -> GravityModel:
        """Load a packaged gravity model by name.

        Available models:

        - ``"EGM2008_360"`` — truncated 360x360 EGM2008
        - ``"GGM05S"`` — full 180x180 GGM05S
        - ``"JGM3"`` — full 70x70 JGM3

        Args:
            model_type: One of the packaged model names.

        Returns:
            GravityModel: Loaded gravity model.

        Raises:
            ValueError: If the model type is not recognized.
        """
        if model_type not in _PACKAGED_MODELS:
            raise ValueError(
                f"Unknown gravity model type: {model_type!r}. "
                f"Available: {list(_PACKAGED_MODELS.keys())}"
            )
        filename = _PACKAGED_MODELS[model_type]
        data_pkg = importlib.resources.files("astrojax.data.gravity_models")
        resource = data_pkg.joinpath(filename)
        # ``as_file`` yields a real filesystem path even when the package
        # is installed as a zip archive.
        with importlib.resources.as_file(resource) as path:
            return cls.from_file(path)

    # ------------------------------------------------------------------
    # Coefficient access
    # ------------------------------------------------------------------

    def get(self, n: int, m: int) -> tuple[float, float]:
        """Retrieve the (C_nm, S_nm) coefficients for degree *n*, order *m*.

        Args:
            n: Degree of the harmonic.
            m: Order of the harmonic (``0 <= m <= n``).

        Returns:
            tuple[float, float]: (C_nm, S_nm) coefficient pair.  ``S_n0``
            is always ``0.0``.

        Raises:
            ValueError: If *n* or *m* is negative, if ``m > n``, or if
                (n, m) exceeds the model bounds.
        """
        # Negative indices would wrap around in numpy, and m > n would read
        # the upper triangle, where S coefficients live.  Both silently
        # return the wrong coefficient, so reject them explicitly.
        if n < 0 or m < 0:
            raise ValueError(
                f"Degree and order must be non-negative, got (n={n}, m={m})."
            )
        if m > n:
            raise ValueError(f"Order (m={m}) cannot exceed degree (n={n}).")
        if n > self.n_max or m > self.m_max:
            raise ValueError(
                f"Requested (n={n}, m={m}) exceeds model bounds "
                f"(n_max={self.n_max}, m_max={self.m_max})."
            )
        c_nm = float(self.data[n, m])
        # S_n0 is identically zero and is not stored.
        s_nm = float(self.data[m - 1, n]) if m > 0 else 0.0
        return c_nm, s_nm

    # ------------------------------------------------------------------
    # Model truncation
    # ------------------------------------------------------------------

    def set_max_degree_order(self, n: int, m: int) -> None:
        """Truncate the model to a smaller degree and order.

        Coefficients beyond the new limits are discarded.  This is
        irreversible.

        Args:
            n: New maximum degree (must be >= 0 and <= current ``n_max``).
            m: New maximum order (must be >= 0, <= *n* and <= current
                ``m_max``).

        Raises:
            ValueError: If validation fails.
        """
        if n < 0 or m < 0:
            raise ValueError(
                f"Degree and order must be non-negative, got (n={n}, m={m})."
            )
        if m > n:
            raise ValueError(f"Maximum order (m={m}) cannot exceed maximum degree (n={n}).")
        if n > self.n_max:
            raise ValueError(f"Requested degree (n={n}) exceeds model's n_max={self.n_max}.")
        if m > self.m_max:
            raise ValueError(f"Requested order (m={m}) exceeds model's m_max={self.m_max}.")
        if n == self.n_max and m == self.m_max:
            return

        # Keep the matrix square: S_nm lives in column n, so n + 1 columns
        # are needed regardless of the new order limit.
        new_size = n + 1
        self.data = self.data[:new_size, :new_size].copy()
        self.n_max = n
        self.m_max = m

    # ------------------------------------------------------------------
    # GFC parser
    # ------------------------------------------------------------------

    @classmethod
    def _parse_gfc(cls, fileobj) -> GravityModel:
        """Parse an ICGEM GFC format file.

        Header parsing (everything before ``end_of_head``):

        - Recognised keys: ``modelname``, ``earth_gravity_constant``,
          ``radius``, ``max_degree``, ``tide_system`` and
          ``norm``/``normalization``.
        - Each value is the last whitespace-separated token on its line.

        Coefficient parsing (after the header):

        - Every line starting with ``gfc`` is read as
          ``gfc n m C S [sigma_C sigma_S ...]``.
        - Terms beyond ``max_degree`` are dropped.
        - Fortran ``D`` exponents are accepted in the gravity constant, the
          radius and the coefficients.

        Args:
            fileobj: Iterable of text lines (e.g. an open file) with GFC
                content.

        Returns:
            GravityModel: Parsed gravity model.

        Raises:
            ValueError: If the ``end_of_head`` marker or a required header
                field is missing, or a ``gfc`` line has fewer than five
                fields.
        """
        model_name = "Unknown"
        gm = 0.0
        radius = 0.0
        n_max = 0
        m_max = 0
        tide_system = "unknown"
        normalization = "fully_normalized"

        # Read header.  The same iterator is reused for the data section,
        # so coefficient parsing resumes right after ``end_of_head``.
        in_header = True
        lines = iter(fileobj)
        for line in lines:
            line = line.strip()
            if line.startswith("end_of_head"):
                in_header = False
                break

            parts = line.split()
            if len(parts) < 2:
                continue

            key = parts[0].lower()
            value = parts[-1]

            # Any other header key (e.g. ``errors``) is ignored.
            if key == "modelname":
                model_name = value
            elif key == "earth_gravity_constant":
                gm = float(value.replace("D", "e").replace("d", "e"))
            elif key == "radius":
                radius = float(value.replace("D", "e").replace("d", "e"))
            elif key == "max_degree":
                n_max = int(value)
                m_max = n_max
            elif key == "tide_system":
                tide_system = value
            elif key in ("norm", "normalization"):
                normalization = value

        if in_header:
            raise ValueError("GFC file missing 'end_of_head' marker.")
        if gm == 0.0:
            raise ValueError("GFC header missing 'earth_gravity_constant'.")
        if radius == 0.0:
            raise ValueError("GFC header missing 'radius'.")
        if n_max == 0:
            raise ValueError("GFC header missing 'max_degree'.")

        # Read coefficient data into a square matrix (m_max == n_max here);
        # see the class docstring for the C/S layout.
        data = np.zeros((n_max + 1, m_max + 1), dtype=np.float64)

        for raw_line in lines:
            line = raw_line.strip()
            # NOTE: this prefix test also matches ``gfct`` lines (used by
            # time-variable models); their extra epoch fields are ignored.
            if not line.startswith("gfc"):
                continue

            # Replace Fortran-style D/d exponent notation
            parts = line.replace("D", "e").replace("d", "e").split()

            # gfc  n  m  C  S  [sig_C  sig_S]
            if len(parts) < 5:
                raise ValueError(f"Malformed GFC coefficient line: {line!r}")
            n = int(parts[1])
            m = int(parts[2])
            c = float(parts[3])
            s = float(parts[4])

            if n <= n_max and m <= m_max:
                data[n, m] = c
                if m > 0:
                    data[m - 1, n] = s

        return cls(
            model_name=model_name,
            gm=gm,
            radius=radius,
            n_max=n_max,
            m_max=m_max,
            data=data,
            tide_system=tide_system,
            normalization=normalization,
        )

    def __repr__(self) -> str:
        return (
            f"GravityModel(name={self.model_name!r}, "
            f"n_max={self.n_max}, m_max={self.m_max}, "
            f"gm={self.gm:.6e}, radius={self.radius:.1f})"
        )
385

386

387
# ---------------------------------------------------------------------------
388
# Spherical harmonic gravity acceleration
389
# ---------------------------------------------------------------------------
390

391

392
def _factorial_product(n: int, m: int) -> float:
1✔
393
    """Compute (n-m)!/(n+m)! efficiently without full factorials.
394

395
    Args:
396
        n: Degree.
397
        m: Order.
398

399
    Returns:
400
        float: The factorial ratio.
401
    """
402
    p = 1.0
1✔
403
    for i in range(n - m + 1, n + m + 1):
1✔
404
        p /= i
1✔
405
    return p
1✔
406

407

408
def accel_gravity_spherical_harmonics(
    r_eci: ArrayLike,
    R_eci_to_ecef: ArrayLike,
    gravity_model: GravityModel,
    n_max: int,
    m_max: int,
) -> Array:
    """Acceleration from spherical harmonic gravity field expansion.

    Computes the gravitational acceleration using recursively-computed
    associated Legendre functions (V/W matrix method).  The position is
    transformed to the body-fixed frame, the acceleration is computed
    there, and transformed back to ECI.

    Args:
        r_eci: Position in ECI frame [m].  Shape ``(3,)`` or ``(6,)``
            (only first 3 elements used).
        R_eci_to_ecef: Rotation matrix from ECI to ECEF, shape ``(3, 3)``.
        gravity_model: Loaded gravity model with Stokes coefficients.
        n_max: Maximum degree for evaluation (must be <= model's n_max).
            Must be a concrete Python ``int``, since it sizes loops and
            arrays.
        m_max: Maximum order for evaluation (must be <= n_max and
            <= model's m_max).  Must be a concrete Python ``int``.

    Returns:
        Acceleration in ECI frame [m/s^2], shape ``(3,)``.

    Raises:
        ValueError: If ``n_max`` or ``m_max`` is negative, if
            ``m_max > n_max``, or if either exceeds the model's limits.

    Examples:
        ```python
        import jax.numpy as jnp
        from astrojax.orbit_dynamics.gravity import (
            GravityModel, accel_gravity_spherical_harmonics,
        )
        model = GravityModel.from_type("JGM3")
        r = jnp.array([6878e3, 0.0, 0.0])
        R = jnp.eye(3)
        a = accel_gravity_spherical_harmonics(r, R, model, 20, 20)
        ```
    """
    # Validate the truncation limits before any array work.  They set the
    # array shapes and loop bounds used by the recursion.  Out-of-range
    # values would otherwise surface as shape-mismatch errors or as
    # out-of-bounds indexing, which JAX does not raise on.
    if n_max < 0 or m_max < 0:
        raise ValueError(
            f"Degree and order must be non-negative, "
            f"got (n_max={n_max}, m_max={m_max})."
        )
    if m_max > n_max:
        raise ValueError(
            f"Maximum order (m_max={m_max}) cannot exceed "
            f"maximum degree (n_max={n_max})."
        )
    if n_max > gravity_model.n_max or m_max > gravity_model.m_max:
        raise ValueError(
            f"Requested (n_max={n_max}, m_max={m_max}) exceeds model bounds "
            f"(n_max={gravity_model.n_max}, m_max={gravity_model.m_max})."
        )

    _float = get_dtype()
    r = jnp.asarray(r_eci, dtype=_float)[:3]
    R = jnp.asarray(R_eci_to_ecef, dtype=_float)

    # Transform to body-fixed frame
    r_bf = R @ r

    # Convert model data to JAX array
    CS = jnp.asarray(gravity_model.data, dtype=_float)

    # Compute acceleration in body-fixed frame
    a_bf = _compute_spherical_harmonics(
        r_bf,
        CS,
        n_max,
        m_max,
        gravity_model.radius,
        gravity_model.gm,
        gravity_model.is_normalized,
    )

    # Transform back to ECI (R is a rotation, so its transpose is its inverse)
    return R.T @ a_bf
1✔
469

470

471
def _compute_spherical_harmonics(
    r_bf: Array,
    CS: Array,
    n_max: int,
    m_max: int,
    r_ref: float,
    gm: float,
    is_normalized: bool,
) -> Array:
    """Core V/W recursion for spherical harmonic gravity using ``jax.lax.fori_loop``.

    Implements the algorithm from Montenbruck & Gill (2012), p. 56-68.
    Uses ``jax.lax.fori_loop`` so the loop body is compiled once,
    avoiding the large traced-graph overhead of Python ``for`` loops.

    The computation runs in three stages:

    1. Fill the unnormalized V/W matrices up to degree ``n_max + 1`` and
       order ``m_max + 1``.  The extra degree and order are needed
       because the acceleration formulas read ``V[n+1, m+1]``.
    2. Convert the stored (possibly normalized) Stokes coefficients to
       unnormalized C/S arrays.
    3. Accumulate the body-fixed acceleration components.

    Args:
        r_bf: Position in body-fixed frame [m], shape ``(3,)``.
        CS: Coefficient matrix from GravityModel, shape ``(N+1, M+1)``.
            Uses the GravityModel layout: C_nm at ``[n, m]``, S_nm at
            ``[m-1, n]``.  It must have at least ``n_max + 1`` rows and
            columns.
        n_max: Maximum degree for evaluation.  Must be a concrete Python
            int, since it sizes arrays and Python loops.
        m_max: Maximum order for evaluation.  Assumed ``<= n_max``;
            larger values would index past the V/W matrices.
        r_ref: Reference radius [m].
        gm: Gravitational parameter [m^3/s^2].
        is_normalized: Whether coefficients are fully normalized.

    Returns:
        Acceleration in body-fixed frame [m/s^2], shape ``(3,)``.
    """
    dtype = r_bf.dtype

    # Auxiliary quantities
    r_sqr = jnp.dot(r_bf, r_bf)
    rho = r_ref * r_ref / r_sqr

    # Normalized coordinates (R_ref * r_i / |r|^2), as used by the
    # M&G recursion.
    x0 = r_ref * r_bf[0] / r_sqr
    y0 = r_ref * r_bf[1] / r_sqr
    z0 = r_ref * r_bf[2] / r_sqr

    # V and W intermediary matrices, sized n_max + 2 so that degree
    # n_max + 1 (read by the accumulation loop as V[n + 1, ...]) exists.
    size = n_max + 2
    V = jnp.zeros((size, size), dtype=dtype)
    W = jnp.zeros((size, size), dtype=dtype)

    # Seed values: V00 = R_ref / |r|, V10 = z0 * V00 (W_n0 is zero).
    V = V.at[0, 0].set(r_ref / jnp.sqrt(r_sqr))
    V = V.at[1, 0].set(z0 * V[0, 0])

    # --- Loop 1: Zonal recursion V(n, 0) ---
    # Two-term recursion in n for degrees 2 .. n_max + 1.
    def zonal_body(n: int, V: Array) -> Array:
        nf = n.astype(dtype)
        V = V.at[n, 0].set(
            ((2.0 * nf - 1.0) * z0 * V[n - 1, 0] - (nf - 1.0) * rho * V[n - 2, 0])
            / nf
        )
        return V

    V = jax.lax.fori_loop(2, n_max + 2, zonal_body, V)

    # --- Loop 2: Tesseral/sectorial recursion ---
    # Orders 1 .. m_max + 1.  Order m_max + 1 is needed because the
    # accumulation reads V[n + 1, m + 1] for m up to m_max.
    def tesseral_outer(m: int, state: tuple[Array, Array]) -> tuple[Array, Array]:
        V, W = state
        mf = m.astype(dtype)

        # Diagonal terms V(m,m) and W(m,m)
        V = V.at[m, m].set(
            (2.0 * mf - 1.0) * (x0 * V[m - 1, m - 1] - y0 * W[m - 1, m - 1])
        )
        W = W.at[m, m].set(
            (2.0 * mf - 1.0) * (x0 * W[m - 1, m - 1] + y0 * V[m - 1, m - 1])
        )

        # Sub-diagonal terms V(m+1, m), W(m+1, m) — guarded for m <= n_max
        # When m == n_max + 1, row m + 1 does not exist.  The index is
        # clamped to the last row and the existing value is written back
        # unchanged, so the update becomes a no-op.
        idx = jnp.minimum(m + 1, size - 1)
        should_write = m <= n_max
        new_V_sub = (2.0 * mf + 1.0) * z0 * V[m, m]
        new_W_sub = (2.0 * mf + 1.0) * z0 * W[m, m]
        V = V.at[idx, m].set(jnp.where(should_write, new_V_sub, V[idx, m]))
        W = W.at[idx, m].set(jnp.where(should_write, new_W_sub, W[idx, m]))

        # Inner recursion for n = m+2 .. n_max+1
        def tesseral_inner(
            n: int, state: tuple[Array, Array]
        ) -> tuple[Array, Array]:
            V, W = state
            nf = n.astype(dtype)
            V = V.at[n, m].set(
                (
                    (2.0 * nf - 1.0) * z0 * V[n - 1, m]
                    - (nf + mf - 1.0) * rho * V[n - 2, m]
                )
                / (nf - mf)
            )
            W = W.at[n, m].set(
                (
                    (2.0 * nf - 1.0) * z0 * W[n - 1, m]
                    - (nf + mf - 1.0) * rho * W[n - 2, m]
                )
                / (nf - mf)
            )
            return V, W

        V, W = jax.lax.fori_loop(m + 2, n_max + 2, tesseral_inner, (V, W))
        return V, W

    V, W = jax.lax.fori_loop(1, m_max + 2, tesseral_outer, (V, W))

    # --- Precompute denormalized coefficient arrays ---
    # Build normalization matrix using Python loops (pure numpy, no tracing)
    # N_n0 = sqrt(2n + 1);  N_nm = sqrt(2 (2n + 1) (n - m)! / (n + m)!)
    # for m >= 1.  Entries with m > n stay 1 and are never read below.
    # NOTE(review): for high degree/order, (n - m)! / (n + m)! from
    # _factorial_product underflows to 0.0 in float64 (e.g. near
    # n = m = 360).  That would zero the corresponding denormalized
    # coefficients; confirm the usable degree range for large models
    # such as EGM2008_360.
    norm = np.ones((n_max + 1, m_max + 1), dtype=np.float64)
    if is_normalized:
        for n in range(n_max + 1):
            nf = float(n)
            norm[n, 0] = math.sqrt(2.0 * nf + 1.0)
            for m in range(1, min(n, m_max) + 1):
                norm[n, m] = math.sqrt(
                    2.0 * (2.0 * nf + 1.0) * _factorial_product(n, m)
                )

    norm_jax = jnp.array(norm, dtype=dtype)

    # Denormalized C coefficients: C_arr[n, m] = norm[n, m] * CS[n, m]
    C_arr = CS[: n_max + 1, : m_max + 1] * norm_jax

    # Denormalized S coefficients from upper-triangle storage
    # S_arr[n, m] = norm[n, m] * CS[m-1, n]  for m >= 1;  column 0 stays zero
    S_arr = jnp.zeros((n_max + 1, m_max + 1), dtype=dtype)
    if m_max > 0:
        # CS[m-1, n] for m=1..m_max, n=0..n_max  =>  CS[0:m_max, 0:n_max+1].T
        S_arr = S_arr.at[:, 1 : m_max + 1].set(
            CS[0:m_max, 0 : n_max + 1].T * norm_jax[:, 1 : m_max + 1]
        )

    # --- Loop 3: Acceleration accumulation ---
    # Sum over orders m = 0 .. m_max and degrees n = m .. n_max.
    def accum_outer(
        m: int, state: tuple[Array, Array, Array]
    ) -> tuple[Array, Array, Array]:
        ax, ay, az = state
        mf = m.astype(dtype)

        def accum_inner(
            n: int, inner_state: tuple[Array, Array, Array]
        ) -> tuple[Array, Array, Array]:
            ax, ay, az = inner_state
            nf = n.astype(dtype)
            C = C_arr[n, m]
            S = S_arr[n, m]

            # Both branches are evaluated every iteration, because m is a
            # traced loop index.  The jnp.where calls below select one.

            # m == 0 branch (zonal)
            ax_z = ax - C * V[n + 1, 1]
            ay_z = ay - C * W[n + 1, 1]
            az_z = az - (nf + 1.0) * C * V[n + 1, 0]

            # m != 0 branch (tesseral/sectorial)
            # When m == 0, the m - 1 column index below is -1.  It presumably
            # wraps to the last column; the result is discarded by the
            # jnp.where calls in any case.
            Fac = 0.5 * (nf - mf + 1.0) * (nf - mf + 2.0)
            ax_t = ax + (
                0.5 * (-C * V[n + 1, m + 1] - S * W[n + 1, m + 1])
                + Fac * (C * V[n + 1, m - 1] + S * W[n + 1, m - 1])
            )
            ay_t = ay + (
                0.5 * (-C * W[n + 1, m + 1] + S * V[n + 1, m + 1])
                + Fac * (-C * W[n + 1, m - 1] + S * V[n + 1, m - 1])
            )
            az_t = az + (nf - mf + 1.0) * (-C * V[n + 1, m] - S * W[n + 1, m])

            is_zonal = m == 0
            ax = jnp.where(is_zonal, ax_z, ax_t)
            ay = jnp.where(is_zonal, ay_z, ay_t)
            az = jnp.where(is_zonal, az_z, az_t)
            return ax, ay, az

        ax, ay, az = jax.lax.fori_loop(m, n_max + 1, accum_inner, (ax, ay, az))
        return ax, ay, az

    zero = jnp.zeros((), dtype=dtype)
    ax, ay, az = jax.lax.fori_loop(0, m_max + 1, accum_outer, (zero, zero, zero))

    # Scale by GM/R_ref^2
    scale = gm / (r_ref * r_ref)
    return scale * jnp.array([ax, ay, az])
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc