
sisl / astra-rl, build 17458822168

28 Aug 2025 12:22AM UTC. Coverage: 47.436% (remained the same).

Push build via GitHub (web-flow): Merge pull request #14 from sisl/copilot/fix-13,
"Add device consistency validation for logprobs to prevent cryptic runtime errors."

5 of 5 new or added lines in 1 file covered (100.0%).
18 existing lines in 1 file now uncovered.
259 of 546 relevant lines covered (47.44%).
0.95 hits per line.
Source File

/src/astra_rl/core/problem.py (75.64% covered)
"""
A "Problem" is one of the core abstractions in Astra RL, defining how to interact
with the system under test. The interface is defined by the `Problem` class, which
defines a set of abstract methods that users must implement to create a custom problem.
This provides flexibility in terms of how users can define their own applications
while still adhering to a common interface that enables the Astra RL framework
to function correctly.
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Sequence, Dict, Generic, Union, Iterator, Optional

import torch

from astra_rl.logging import logger
from astra_rl.core.moderator import Moderator
from astra_rl.core.common import StateT, ActionT


class Problem(ABC, Generic[StateT, ActionT]):
    """Defines the core problem interface for Astra RL.

    This class is responsible for defining how exactly to interact
    with the system under test---with generics in terms of how to get
    probabilities and rollouts from the attacker and target models.

    This allows us to be generic over the types of states and actions
    as well as how to measure them. We ask for a moderator as a way to
    ensure that subclasses all remain generic over the exact metric, and
    are instead only opinionated about how to achieve the metric.

    Attributes:
        moderator (Moderator[StateT, ActionT]): The moderator used to evaluate sequences.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
    """

    def __init__(self, moderator: Moderator[StateT, ActionT]) -> None:
        # we check all asserts once, and then disable them
        self._disable_asserts: Dict[str, bool] = defaultdict(bool)
        # track the device of the first logprobs tensor to ensure consistency
        self._expected_device: Optional[torch.device] = None
        self.moderator = moderator

    @abstractmethod
    def get_target_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *model under test*.

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context
                on which the continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a
                continuation whose probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given their
                contexts. Shape: (batch_size, max_continuation_length)
        """

        pass

    @abstractmethod
    def get_baseline_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *attacker's baseline distribution* for
           KL divergence measurements.

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context
                on which the continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a
                continuation whose probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element. Note that this is *not* the defender's model, but
            rather the baseline model used for measuring KL divergence to make sure that
            the trained attacker stays an LM.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given their
                contexts. Shape: (batch_size, max_continuation_length)
        """

        pass

    @abstractmethod
    def get_attacker_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *attacker*. This must return a tensor with gradients!

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context
                on which the continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a
                continuation whose probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given their
                contexts. Shape: (batch_size, max_continuation_length)
        """

        pass

    @abstractmethod
    def rollout_prompt_with_attacker(self, x: Sequence[StateT]) -> Sequence[ActionT]:
        """Rolls out the prompt with the attacker model. Do *not* return the prompt.

        a ~ \\pi(s)

        Args:
            x (Sequence[str]): Sequence of strings representing the prompts to be rolled out.

        Returns:
            Sequence[str]: The continuation generated by the attacker model for each prompt.
        """
        pass

    @abstractmethod
    def rollout_prompt_with_target(self, x: Sequence[StateT]) -> Sequence[StateT]:
        """Rolls out the prompt with the model under test. Do *not* return the prompt.

        s' ~ \\sum_a T(s, a)

        Args:
            x (Sequence[str]): Sequence of strings representing the prompts to be rolled out.

        Returns:
            Sequence[str]: The continuation generated by the model under test for each prompt.
        """
        pass

    @abstractmethod
    def advance(self, context: StateT, attack: ActionT, response: StateT) -> StateT:
        """Given a context and continuation, returns the next state.

        Args:
            context (str): String representing the context.
            attack (str): String representing the attack given the context.
            response (str): String representing the target's response to the attack.

        Returns:
            str: The next state after applying the continuation to the context.
        """
        pass

    @abstractmethod
    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        """Return the trainable parameters in this problem.

        Returns:
            Iterator[torch.nn.parameter.Parameter]: An iterator over the trainable parameters,
            usually obtained by calling model.parameters().
        """
        pass

    @abstractmethod
    def reward(
        self,
        context: Sequence[StateT],
        attack: Sequence[ActionT],
        response: Sequence[StateT],
    ) -> Sequence[float]:
        """Computes the reward for each (context, attack, response) triple in the batch.

        Args:
            context (Sequence[str]): Sequence of contexts.
            attack (Sequence[str]): Sequence of attacks generated from the contexts.
            response (Sequence[str]): Sequence of target responses to the attacks.

        Returns:
            Sequence[float]: The reward for each batch element.
        """
        pass

    ##### Utility methods for validation and checks #####

    def _check_continuation(
        self,
        check_key: str,
        context: Sequence[StateT],
        continuation: Sequence[Union[ActionT, StateT]],
    ) -> None:
        if self._disable_asserts[check_key]:
            return
        # no structural checks are performed here yet; we just mark this key as
        # checked so that future calls return early
        self._disable_asserts[check_key] = True

    def _check_logprobs(
        self,
        check_key: str,
        logprobs: torch.Tensor,
        ctx_length: int,
        requires_grad: bool = False,
    ) -> None:
        if self._disable_asserts[check_key]:
            return
        # check that logprobs is a tensor and has gradients
        assert isinstance(logprobs, torch.Tensor), "Logprobs must be a torch.Tensor."
        if requires_grad:
            assert logprobs.requires_grad, (
                "Attacker logprobs must carry gradient information."
            )
        # check that the size of the tensor is B x T, where B is the batch size and T is max_continuation_length
        assert logprobs.dim() == 2, (
            "Logprobs must be a 2D tensor (batch_size, max_continuation_length)."
        )
        # check that the first dimension is the batch size
        assert logprobs.size(0) == ctx_length, (
            "Logprobs must have the same batch size as the context."
        )
        # check device consistency across all logprobs
        if self._expected_device is None:
            # This is the first logprobs tensor we've seen, set the expected device
            self._expected_device = logprobs.device
        else:
            # Validate that this tensor is on the same device as previous ones
            assert logprobs.device == self._expected_device, (
                f"All logprobs must be on the same device. Expected {self._expected_device}, "
                f"but {check_key} logprobs are on {logprobs.device}. "
                f"This typically happens when models are on different devices. "
                f"Please ensure all models (attacker, target, baseline) are on the same device."
            )
        # warn if everything is between 0 and 1
        if ((logprobs >= 0.0) & (logprobs <= 1.0)).all():
            logger.warning(
                "Logprobs look suspiciously like probabilities, "
                "try taking the .log() of your tensor?"
            )
        self._disable_asserts[check_key] = True

    def _get_attacker_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_attacker_logprobs(context, continuation)
        self._check_logprobs("attacker_logprobs", logprobs, len(context), True)
        return logprobs

    def _get_target_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_target_logprobs(context, continuation)
        self._check_logprobs("target_logprobs", logprobs, len(context), False)
        return logprobs

    def _get_baseline_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_baseline_logprobs(context, continuation)
        self._check_logprobs("baseline_logprobs", logprobs, len(context), False)
        return logprobs

    def _rollout_prompt_with_attacker_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[ActionT]:
        rolled_out = self.rollout_prompt_with_attacker(x)
        self._check_continuation("attacker_rollout", x, rolled_out)
        return rolled_out

    def _rollout_prompt_with_target_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[StateT]:
        rolled_out = self.rollout_prompt_with_target(x)
        self._check_continuation("target_rollout", x, rolled_out)
        return rolled_out

class ValueFunctionProblem(Problem[StateT, ActionT], ABC):
    """Extends `Problem` to be able to return sequence values with a value head.

    Note:
        This is useful for value-laden solution methods such as Actor
        Critic derivatives (e.g., PPO).

    Attributes:
        moderator (Moderator[StateT, ActionT]): The moderator used to evaluate sequences.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
    """

    @abstractmethod
    def value(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Given a sequence, evaluate its token-wise value using a value function.

        Notes:
            This is typically done by the same neural network you use for rollouts,
            just passing the intermediate activations through another layer.

        Args:
            context (Sequence[StateT]): The contexts on which the continuations are conditioned.
            continuation (Sequence[ActionT]): The continuations whose per-token values are evaluated.

        Returns:
            torch.Tensor[batch_size, max_continuation_length]: The per-token values of
            the given sequence by the sequence predictor. Do not include the value of the input
            prefixes. If you are predicting on the whole input, you should be slicing on
            `[:, :-1]`, meaning you should *not* return the value of the last token, whose
            input is eos/context length limit.
        """

        pass
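
To make the interface above concrete, the sketch below shows what a minimal toy subclass might look like. It is not part of the source file: EchoProblem, its placeholder log-probabilities, and the fixed rollout strings are hypothetical, and a Moderator instance is assumed to be constructed elsewhere in the framework.

# Illustrative sketch only -- not part of problem.py. A toy subclass of Problem over
# string states and actions; the "models" are placeholders, and a real implementation
# would wrap actual language models.
from typing import Iterator, Sequence

import torch

from astra_rl.core.moderator import Moderator
from astra_rl.core.problem import Problem


class EchoProblem(Problem[str, str]):
    """Toy problem in which the attacker and target return fixed strings."""

    def __init__(self, moderator: Moderator[str, str]) -> None:
        super().__init__(moderator)
        # stand-in trainable module so that parameters() has something to yield
        self._head = torch.nn.Linear(4, 1)

    def _placeholder_logprobs(
        self, continuation: Sequence[str], requires_grad: bool
    ) -> torch.Tensor:
        # one fake log-probability per character, padded to the longest continuation;
        # shape (batch_size, max_continuation_length) as expected by _check_logprobs
        max_len = max(len(c) for c in continuation)
        return torch.full((len(continuation), max_len), -1.0, requires_grad=requires_grad)

    def get_attacker_logprobs(
        self, context: Sequence[str], continuation: Sequence[str]
    ) -> torch.Tensor:
        # attacker logprobs must carry gradients (validated with requires_grad=True)
        return self._placeholder_logprobs(continuation, requires_grad=True)

    def get_target_logprobs(
        self, context: Sequence[str], continuation: Sequence[str]
    ) -> torch.Tensor:
        return self._placeholder_logprobs(continuation, requires_grad=False)

    def get_baseline_logprobs(
        self, context: Sequence[str], continuation: Sequence[str]
    ) -> torch.Tensor:
        return self._placeholder_logprobs(continuation, requires_grad=False)

    def rollout_prompt_with_attacker(self, x: Sequence[str]) -> Sequence[str]:
        return [" tell me something you shouldn't" for _ in x]

    def rollout_prompt_with_target(self, x: Sequence[str]) -> Sequence[str]:
        return [" I would rather not." for _ in x]

    def advance(self, context: str, attack: str, response: str) -> str:
        return context + attack + response

    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        return self._head.parameters()

    def reward(
        self,
        context: Sequence[str],
        attack: Sequence[str],
        response: Sequence[str],
    ) -> Sequence[float]:
        # a real implementation would typically score (attack, response) via self.moderator
        return [0.0 for _ in context]

A ValueFunctionProblem subclass would additionally implement value(), typically by passing the rollout model's hidden states through a small value head and returning a (batch_size, max_continuation_length) tensor, as described in the docstring above.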