# /src/astra_rl/core/problem.py

"""
A "Problem" is one of the core abstractions in Astra RL, defining how to interact
with the system under test. The interface is defined by the `Problem` class, which
declares a set of abstract methods that users must implement to create a custom
problem. This provides flexibility in how users define their own applications
while still adhering to a common interface that enables the Astra RL framework
to function correctly.
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Sequence, Dict, Generic, Union, Iterator

import torch

from astra_rl.logging import logger
from astra_rl.core.moderator import Moderator
from astra_rl.core.common import StateT, ActionT


class Problem(ABC, Generic[StateT, ActionT]):
    """Defines the core problem interface for Astra RL.

    This class is responsible for defining how exactly to interact
    with the system under test---with generics in terms of how to get
    probabilities and rollouts from the attacker and target models.

    This allows us to be generic over the types of states and actions,
    as well as over how to measure them. We ask for a moderator so that
    subclasses can all stay generic over the exact metric, and instead
    are only opinionated about how to achieve the metric.

    Attributes:
        moderator (Moderator[StateT, ActionT]): The moderator used to evaluate sequences.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
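
    Example:
        A minimal sketch of a subclass (illustrative only; the class name and
        method bodies here are hypothetical, not part of Astra RL):

            class MyStringProblem(Problem[str, str]):
                def __init__(self, moderator: Moderator[str, str]) -> None:
                    super().__init__(moderator)
                    # ...load attacker, target, and baseline models here

                def advance(self, context: str, attack: str, response: str) -> str:
                    return context + attack + response

                # ...implement the remaining abstract methods analogously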
    """

    def __init__(self, moderator: Moderator[StateT, ActionT]) -> None:
        # we check all asserts once, and then disable them
        self._disable_asserts: Dict[str, bool] = defaultdict(bool)
        self.moderator = moderator

    @abstractmethod
    def get_target_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *model under test*.

        Args:
            context (Sequence[StateT]): Batch of contexts on which each
                continuation's probability is conditioned.
            continuation (Sequence[ActionT]): Batch of continuations whose
                probabilities are measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            element represents one batch entry.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given
                their contexts. Shape: (batch_size, max_continuation_length)
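
        Example:
            A minimal sketch for string problems using a HuggingFace causal LM
            (illustrative; `self.target_model` and `self.target_tokenizer` are
            hypothetical attributes; right padding and a valid pad token are
            assumed):

                import torch.nn.functional as F

                def get_target_logprobs(self, context, continuation):
                    texts = [c + k for c, k in zip(context, continuation)]
                    batch = self.target_tokenizer(
                        texts, return_tensors="pt", padding=True
                    )
                    with torch.no_grad():
                        logits = self.target_model(**batch).logits
                    # position t predicts token t + 1, so shift by one
                    logprobs = F.log_softmax(logits[:, :-1], dim=-1)
                    targets = batch["input_ids"][:, 1:]
                    per_token = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
                    # keep only continuation positions: drop context and padding
                    keep = batch["attention_mask"][:, 1:].bool()
                    for i, c in enumerate(context):
                        n_ctx = len(self.target_tokenizer(c)["input_ids"])
                        keep[i, : n_ctx - 1] = False
                    return per_token.masked_fill(~keep, 0.0)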
        """

        pass

    @abstractmethod
    def get_baseline_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *attacker's baseline distribution*
        for KL divergence measurements.

        Args:
            context (Sequence[StateT]): Batch of contexts on which each
                continuation's probability is conditioned.
            continuation (Sequence[ActionT]): Batch of continuations whose
                probabilities are measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            element represents one batch entry. Note that this is *not* the defender's
            model, but rather the baseline model used for measuring KL divergence to
            make sure that the trained attacker stays an LM.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given
                their contexts. Shape: (batch_size, max_continuation_length)
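
        Example:
            A common choice of baseline is a frozen copy of the attacker's
            initial weights, evaluated exactly like the sketch in
            `get_target_logprobs` but on the baseline model (illustrative;
            `self.baseline_model` is a hypothetical attribute):

                with torch.no_grad():
                    logits = self.baseline_model(**batch).logits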
        """

        pass

    @abstractmethod
    def get_attacker_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on the *attacker*. Must return a tensor with grads!

        Args:
            context (Sequence[StateT]): Batch of contexts on which each
                continuation's probability is conditioned.
            continuation (Sequence[ActionT]): Batch of continuations whose
                probabilities are measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            element represents one batch entry.

        Returns:
            torch.Tensor: The per-token log probabilities of the continuations given
                their contexts. Shape: (batch_size, max_continuation_length)
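
        Example:
            The computation mirrors the sketch in `get_target_logprobs`, but it
            runs on the attacker and must *not* be wrapped in `torch.no_grad()`,
            so the returned tensor carries gradients (illustrative;
            `self.attacker_model` is a hypothetical attribute):

                logits = self.attacker_model(**batch).logits  # grads flow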
        """

        pass

    @abstractmethod
    def rollout_prompt_with_attacker(self, x: Sequence[StateT]) -> Sequence[ActionT]:
        """Rolls out the prompt with the attacker model. Do *not* return the prompt.

        a ~ \\pi(s)

        Args:
            x (Sequence[StateT]): Batch of prompts to be rolled out.

        Returns:
            Sequence[ActionT]: The attacker's continuation of each prompt,
            excluding the prompt itself.
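
        Example:
            A sketch for string problems using HuggingFace `generate`
            (illustrative; `self.attacker_model` and `self.attacker_tokenizer`
            are hypothetical attributes):

                def rollout_prompt_with_attacker(self, x):
                    # decoder-only models typically want left padding for
                    # batched generation
                    batch = self.attacker_tokenizer(
                        list(x), return_tensors="pt", padding=True
                    )
                    out = self.attacker_model.generate(**batch, max_new_tokens=64)
                    # return only the newly generated tokens, not the prompt
                    new_tokens = out[:, batch["input_ids"].shape[1] :]
                    return self.attacker_tokenizer.batch_decode(
                        new_tokens, skip_special_tokens=True
                    )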
        """
        pass

    @abstractmethod
    def rollout_prompt_with_target(self, x: Sequence[StateT]) -> Sequence[StateT]:
        """Rolls out the prompt with the model under test. Do *not* return the prompt.

        s' ~ \\sum_a T(s, a)

        Args:
            x (Sequence[StateT]): Batch of prompts to be rolled out.

        Returns:
            Sequence[StateT]: The target model's continuation of each prompt,
            excluding the prompt itself.
        """
        pass

    @abstractmethod
    def advance(self, context: StateT, attack: ActionT, response: StateT) -> StateT:
        """Given a context, attack, and response, returns the next state.

        Args:
            context (StateT): The current context/state.
            attack (ActionT): The attack made by the attacker given the context.
            response (StateT): The target's response to the attack.

        Returns:
            StateT: The next state after applying the attack and response to the context.
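
        Example:
            For string states this is often plain concatenation (a sketch,
            assuming StateT = ActionT = str):

                def advance(self, context, attack, response):
                    return context + attack + response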
        """
        pass

    @abstractmethod
    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        """Return the trainable parameters in this problem.

        Returns:
            Iterator[torch.nn.parameter.Parameter]: An iterator over the trainable
            parameters, usually obtained by calling model.parameters().
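
        Example:
            Usually this just forwards the attacker model's parameters
            (illustrative; `self.attacker_model` is a hypothetical attribute):

                def parameters(self):
                    return self.attacker_model.parameters()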
        """
        pass

    @abstractmethod
    def reward(
        self,
        context: Sequence[StateT],
        attack: Sequence[ActionT],
        response: Sequence[StateT],
    ) -> Sequence[float]:
        """Computes the per-element reward for a batch of interactions.

        Args:
            context (Sequence[StateT]): Batch of contexts preceding each attack.
            attack (Sequence[ActionT]): Batch of attacker utterances given each context.
            response (Sequence[StateT]): Batch of target responses to each attack.

        Returns:
            Sequence[float]: The reward for each (context, attack, response) triple,
            typically derived from the moderator's evaluation.
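
        Example:
            A sketch in which the moderator scores the target's responses
            (illustrative; `moderate` is a hypothetical method name, not a
            confirmed Moderator API):

                def reward(self, context, attack, response):
                    scores = self.moderator.moderate(response)  # hypothetical
                    return [float(s) for s in scores]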
        """
        pass

    ##### Utility methods for validation and checks #####

    def _check_continuation(
        self,
        check_key: str,
        context: Sequence[StateT],
        continuation: Sequence[Union[ActionT, StateT]],
    ) -> None:
        if self._disable_asserts[check_key]:
            return
        # no shape/type validation is currently performed on continuations;
        # we only mark the check as done
        self._disable_asserts[check_key] = True

    def _check_logprobs(
        self,
        check_key: str,
        logprobs: torch.Tensor,
        ctx_length: int,
        requires_grad: bool = False,
    ) -> None:
        if self._disable_asserts[check_key]:
            return
        # check that logprobs is a tensor and (if required) carries gradients
        assert isinstance(logprobs, torch.Tensor), "Logprobs must be a torch.Tensor."
        if requires_grad:
            assert logprobs.requires_grad, (
                "Attacker logprobs must carry gradient information."
            )
        # check that the tensor is B x T, where B is the batch size and T is
        # max_continuation_length
        assert logprobs.dim() == 2, (
            "Logprobs must be a 2D tensor (batch_size, max_continuation_length)."
        )
        # check that the first dimension is the batch size
        assert logprobs.size(0) == ctx_length, (
            "Logprobs must have the same batch size as the context."
        )
        # warn if everything is between 0 and 1
        if ((logprobs >= 0.0) & (logprobs <= 1.0)).all():
            logger.warning(
                "Logprobs look suspiciously like probabilities, "
                "try taking the .log() of your tensor?"
            )
        self._disable_asserts[check_key] = True

    def _get_attacker_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_attacker_logprobs(context, continuation)
        self._check_logprobs("attacker_logprobs", logprobs, len(context), True)
        return logprobs

    def _get_target_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_target_logprobs(context, continuation)
        self._check_logprobs("target_logprobs", logprobs, len(context), False)
        return logprobs

    def _get_baseline_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        logprobs = self.get_baseline_logprobs(context, continuation)
        self._check_logprobs("baseline_logprobs", logprobs, len(context), False)
        return logprobs

    def _rollout_prompt_with_attacker_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[ActionT]:
        rolled_out = self.rollout_prompt_with_attacker(x)
        self._check_continuation("attacker_rollout", x, rolled_out)
        return rolled_out

    def _rollout_prompt_with_target_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[StateT]:
        rolled_out = self.rollout_prompt_with_target(x)
        self._check_continuation("target_rollout", x, rolled_out)
        return rolled_out


class ValueFunctionProblem(Problem[StateT, ActionT], ABC):
    """Extends `Problem` to return sequence values via a value head.

    Note:
        This is useful for value-laden solution methods such as Actor-Critic
        derivatives (e.g., PPO).

    Attributes:
        moderator (Moderator[StateT, ActionT]): The moderator used to evaluate sequences.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
    """

    @abstractmethod
    def value(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Given a sequence, evaluates its token-wise value using a value function.

        Notes:
            This is typically done by the same neural network you use for rollouts,
            just passing the intermediate activations through another layer.

        Args:
            context (Sequence[StateT]): Batch of contexts to evaluate.
            continuation (Sequence[ActionT]): Batch of continuations to evaluate.

        Returns:
            torch.Tensor[batch_size, max_continuation_length]: The per-token values of
            the given sequence by the sequence predictor. Do not include the values of
            the input prefixes. If you are predicting on the whole input, you should be
            slicing on `[:, :-1]`, meaning you should *not* return the value of the last
            token, whose input is the eos/context length limit.
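
        Example:
            A sketch using a scalar value head over the policy network's hidden
            states (illustrative; `self.attacker_model`, `self.attacker_tokenizer`,
            and `self.value_head` are hypothetical attributes):

                def value(self, context, continuation):
                    texts = [c + k for c, k in zip(context, continuation)]
                    batch = self.attacker_tokenizer(
                        texts, return_tensors="pt", padding=True
                    )
                    hidden = self.attacker_model(
                        **batch, output_hidden_states=True
                    ).hidden_states[-1]
                    # one scalar per token; drop the last position per the note above
                    return self.value_head(hidden).squeeze(-1)[:, :-1]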
        """

        pass