# /src/astra_rl/core/system.py
"""
A "System" is one of the core abstractions in Astra RL, defining how to interact
with the system under test. The system is defined by the `System` class, which
defines a set of abstract methods that users must implement to create a custom system.
This provides flexibility in terms of how users can define their own applications
while still adhering to a common system interface that enables the Astra RL framework
to function correctly.
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Sequence, Dict, Generic, Union, Iterator, Optional, Tuple, cast

import torch

from astra_rl.utils import logger
from astra_rl.core.common import StateT, ActionT


class System(ABC, Generic[StateT, ActionT]):
    """Base system defining a problem setup with a minimal interface.

    This is the minimal interface that all systems must implement. For simple
    evaluation-only systems, only these three methods are required:
    - rollout: Generate one step of interaction
    - advance: Transition to the next state
    - reward: Compute rewards from interactions

    For training systems (e.g., adversarial training), use specialized subclasses
    like `TrainableSystem` or `AdversarialSystem` that add methods for computing
    logprobs and accessing trainable parameters.

    Generics:
        StateT (type): The type of the state in the system.
        ActionT (type): The type of the action in the system.
    """

    def __init__(self) -> None:
        # We check each assert once, and then disable it.
        self._disable_asserts: Dict[str, bool] = defaultdict(bool)
        # Track the device of the first logprobs tensor to ensure consistency.
        self._expected_device: Optional[torch.device] = None

    @abstractmethod
    def rollout(
        self, states: Sequence[StateT]
    ) -> Tuple[Sequence[Optional[ActionT]], Sequence[StateT]]:
        """Generate one step: actions and responses (batched).

        For simple systems, actions are None and responses are the generations.
        For adversarial systems, actions are challenges and responses are the
        target's replies.

        Args:
            states: Batch of states to roll out from.

        Returns:
            Tuple of (challenges, responses), both sequences matching the input batch size.
        """
        pass

    @abstractmethod
    def advance(
        self, state: StateT, action: Optional[ActionT], response: StateT
    ) -> StateT:
        """Transition to the next state.

        Args:
            state: Current state.
            action: Action taken (a challenge for adversarial systems, None for static eval).
            response: Response from the target model.

        Returns:
            Next state after applying the action and response.
        """
        pass

    @abstractmethod
    def reward(
        self,
        context: Sequence[StateT],
        challenge: Sequence[Optional[ActionT]],
        response: Sequence[StateT],
    ) -> Sequence[float]:
        """Compute rewards from interactions.

        Args:
            context: Batch of starting states.
            challenge: Batch of actions taken (may be None for static eval).
            response: Batch of responses from the target.

        Returns:
            Sequence of reward values, one per batch element.
        """
        pass


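##### Illustrative sketch (not part of the original module) #####
# A minimal concrete evaluation-only `System`, assuming StateT = ActionT = str.
# `EchoSystem`, its "(echoed)" suffix, and the "unsafe" keyword reward are all
# hypothetical and exist only to show the three required methods.
class EchoSystem(System[str, str]):
    """Toy system whose "model under test" simply echoes the state back."""

    def rollout(
        self, states: Sequence[str]
    ) -> Tuple[Sequence[Optional[str]], Sequence[str]]:
        # No challenge generation: actions are None, responses echo each state.
        return [None] * len(states), [s + " (echoed)" for s in states]

    def advance(self, state: str, action: Optional[str], response: str) -> str:
        # Append the turn onto the running transcript.
        return state + (action or "") + response

    def reward(
        self,
        context: Sequence[str],
        challenge: Sequence[Optional[str]],
        response: Sequence[str],
    ) -> Sequence[float]:
        # Reward 1.0 whenever the (hypothetical) unsafe keyword appears.
        return [1.0 if "unsafe" in r else 0.0 for r in response]

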
class TrainableSystem(System[StateT, ActionT], ABC):
    """System with trainable components and logprobs computation.

    This extends the base System with methods needed for gradient-based training:
    - get_tester_logprobs: For computing gradients
    - get_baseline_logprobs: For KL-divergence penalties
    - get_target_logprobs: For certain algorithms
    - parameters: For optimizer setup

    This is the base class for adversarial training systems.
    """

    @abstractmethod
    def get_target_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *model under test*.

        Args:
            context: Sequence of contexts on which the continuation's probability is conditioned.
            continuation: Sequence of continuations whose probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given
                their contexts. Shape: (batch_size, max_continuation_length).
        """
        pass

    @abstractmethod
    def get_baseline_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *tester's baseline distribution* for
        KL-divergence measurements.

        Args:
            context: Sequence of contexts on which the continuation's probability is conditioned.
            continuation: Sequence of continuations whose probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element. Note that this is *not* the defender's model, but
            rather the baseline model used for measuring the KL divergence that keeps the
            trained tester a fluent language model.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given
                their contexts. Shape: (batch_size, max_continuation_length).
        """
        pass

    @abstractmethod
    def get_tester_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *tester*. This must return a tensor
        that carries gradients!

        Args:
            context: Sequence of contexts on which the continuation's probability is conditioned.
            continuation: Sequence of continuations whose probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given
                their contexts. Shape: (batch_size, max_continuation_length).
        """
        pass

    @abstractmethod
    def rollout_prompt_with_tester(self, x: Sequence[StateT]) -> Sequence[ActionT]:
        """Rolls out the prompt with the tester model. Do *not* return the prompt.

        a ~ \\pi(s)

        Args:
            x: Sequence of prompts (states) to be rolled out.

        Returns:
            The continuations generated by the tester (adversary) model.
        """
        pass

    @abstractmethod
    def rollout_prompt_with_target(self, x: Sequence[StateT]) -> Sequence[StateT]:
        """Rolls out the prompt with the model under test. Do *not* return the prompt.

        s' ~ \\sum_a T(s, a)

        Args:
            x: Sequence of prompts (states) to be rolled out.

        Returns:
            The continuations generated by the model under test.
        """
        pass

    @abstractmethod
    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        """Return the trainable parameters in this system.

        Returns:
            Iterator over the trainable parameters, usually from calling model.parameters().
        """
        pass

    def rollout(
        self, states: Sequence[StateT]
    ) -> Tuple[Sequence[Optional[ActionT]], Sequence[StateT]]:
        """Default implementation for adversarial systems.

        Generates challenges from the tester, then responses from the target.
        """
        challenges = self.rollout_prompt_with_tester(states)
        # Advance each state with its challenge (and an empty response) to build
        # the prompt the target will continue.
        full_prompts = [
            self.advance(s, c, cast(StateT, "")) for s, c in zip(states, challenges)
        ]
        responses = self.rollout_prompt_with_target(full_prompts)
        return challenges, responses

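    # Example usage (hypothetical): for an adversarial subclass `MyRedTeamSystem`,
    # one batched rollout step looks like:
    #
    #   system = MyRedTeamSystem()
    #   challenges, responses = system.rollout(["<conversation so far>"])
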
    ##### Utility methods for validation and checks #####

    def _check_continuation(
        self,
        check_key: str,
        context: Sequence[StateT],
        continuation: Sequence[Union[ActionT, StateT]],
    ) -> None:
        # Currently a placeholder: no shape checks yet, but the check is still
        # run once and then disabled, mirroring _check_logprobs.
        if self._disable_asserts[check_key]:
            return
        self._disable_asserts[check_key] = True

    def _check_logprobs(
        self,
        check_key: str,
        logprobs: torch.Tensor,
        ctx_length: int,
        requires_grad: bool = False,
    ) -> None:
        if self._disable_asserts[check_key]:
            return
        # Check that logprobs is a tensor and (if required) carries gradients.
        assert isinstance(logprobs, torch.Tensor), "Logprobs must be a torch.Tensor."
        if requires_grad:
            assert logprobs.requires_grad, (
                "Tester logprobs must carry gradient information."
            )
        # Check that the tensor is B x T, where B is the batch size and T is
        # max_continuation_length.
        assert logprobs.dim() == 2, (
            "Logprobs must be a 2D tensor (batch_size, max_continuation_length)."
        )
        # Check that the first dimension is the batch size.
        assert logprobs.size(0) == ctx_length, (
            "Logprobs must have the same batch size as the context."
        )
        # Check device consistency across all logprobs.
        if self._expected_device is None:
            # This is the first logprobs tensor we've seen; record its device.
            self._expected_device = logprobs.device
        else:
            # Validate that this tensor is on the same device as previous ones.
            assert logprobs.device == self._expected_device, (
                f"All logprobs must be on the same device. Expected {self._expected_device}, "
                f"but {check_key} logprobs are on {logprobs.device}. "
                f"This typically happens when models are on different devices. "
                f"Please ensure all models (tester, target, baseline) are on the same device."
            )
        # Warn if every value lies in [0, 1]: logprobs should be <= 0, so this
        # usually means probabilities were passed instead.
        if ((logprobs >= 0.0) & (logprobs <= 1.0)).all():
            logger.warning(
                "Logprobs looks suspiciously like probabilities, "
                "try taking the .log() of your tensor?"
            )
        self._disable_asserts[check_key] = True

    def _get_tester_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_tester_logprobs(context, continuation)
        self._check_logprobs("tester_logprobs", logprobs, len(context), True)
        return logprobs

    def _get_target_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_target_logprobs(context, continuation)
        self._check_logprobs("target_logprobs", logprobs, len(context), False)
        return logprobs

    def _get_baseline_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_baseline_logprobs(context, continuation)
        self._check_logprobs("baseline_logprobs", logprobs, len(context), False)
        return logprobs

    def _rollout_prompt_with_tester_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[ActionT]:
        rolled_out = self.rollout_prompt_with_tester(x)
        self._check_continuation("tester_rollout", x, rolled_out)
        return rolled_out

    def _rollout_prompt_with_target_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[StateT]:
        rolled_out = self.rollout_prompt_with_target(x)
        self._check_continuation("target_rollout", x, rolled_out)
        return rolled_out


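##### Illustrative sketch (not part of the original module) #####
# A minimal sketch of how implementations commonly produce the per-token
# logprobs that `_check_logprobs` validates, assuming a decoder-style model.
# `logits` and `continuation_ids` are hypothetical inputs: the logits at each
# continuation position and the ids of the realized continuation tokens.
def _example_gather_continuation_logprobs(
    logits: torch.Tensor, continuation_ids: torch.Tensor
) -> torch.Tensor:
    # logits: (batch_size, max_continuation_length, vocab_size)
    # continuation_ids: (batch_size, max_continuation_length)
    log_probs = torch.log_softmax(logits, dim=-1)
    # Select the logprob of each realized token, yielding the expected
    # (batch_size, max_continuation_length) shape.
    return log_probs.gather(-1, continuation_ids.unsqueeze(-1)).squeeze(-1)

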
class ValueFunctionSystem(TrainableSystem[StateT, ActionT], ABC):
    """Extends `TrainableSystem` to return per-token sequence values via a value head.

    Note:
        This is useful for value-based solution methods such as actor-critic
        derivatives (e.g., PPO).

    Generics:
        StateT (type): The type of the state in the system.
        ActionT (type): The type of the action in the system.
    """

    @abstractmethod
    def value(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Given a sequence, evaluate its token-wise value using a value function.

        Notes:
            This is typically done by the same neural network you use for rollouts,
            just passing the intermediate activations through another layer.

        Args:
            context: The context sequence.
            continuation: The continuation sequence to evaluate.

        Returns:
            torch.Tensor[batch_size, max_continuation_length]: The per-token values
            of the given sequence from the value predictor. Do not include values
            for the input prefix. If you are predicting on the whole input, slice
            with `[:, :-1]`; i.e., do *not* return the value of the last token,
            whose input is the EOS token or the context-length limit.
        """
        pass
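

##### Illustrative sketch (not part of the original module) #####
# The value head the `value` docstring describes: reuse the rollout network's
# hidden states and pass them through one extra linear layer. `hidden_states`
# is a hypothetical input of shape (batch_size, seq_len, hidden_dim).
class _ExampleValueHead(torch.nn.Module):
    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.head = torch.nn.Linear(hidden_dim, 1)  # one scalar value per token

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        values = self.head(hidden_states).squeeze(-1)  # (batch_size, seq_len)
        # Drop the last position, whose input is EOS / the context-length
        # limit, as the `value` docstring instructs.
        return values[:, :-1]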