• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 16262675714

14 Jul 2025 04:49AM UTC coverage: 46.419% (-48.0%) from 94.444%
16262675714

push

github

web-flow
Merge pull request #3 from sisl/feat/core

Initial implementation of core AST algorithm.

160 of 361 new or added lines in 18 files covered. (44.32%)

175 of 377 relevant lines covered (46.42%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.44
/src/astra_rl/core/problem.py
1
"""
2
problem.py
3
Generic class of an AstraProblem
4
"""
5

6
from abc import ABC, abstractmethod
2✔
7
from collections import defaultdict
2✔
8
from typing import Sequence, Dict, Generic, Union, Iterator
2✔
9

10
import torch
2✔
11

12
from astra_rl.logging import logger
2✔
13
from astra_rl.core.moderator import Moderator
2✔
14
from astra_rl.core.common import StateT, ActionT
2✔
15

16

17
class Problem(ABC, Generic[StateT, ActionT]):
    """Defines the core problem interface for Astra RL.

    This class is responsible for defining how exactly to interact
    with the system under test---with generics in terms of how to get
    probabilities and rollouts from the attacker and target models.

    This allows for us to be generic over the types of states, actions
    as well as how to measure them. We ask for a moderator as a way to
    ensure that subclasses can all be generic over the exact metric, and
    instead can only be opinionated about how to achieve the metric.

    Subclasses implement the abstract ``get_*_logprobs``, ``rollout_*``,
    ``advance``, ``parameters`` and ``reward`` methods. The private
    ``_*_and_validate`` wrappers call those methods and sanity-check their
    output the first time each check key is seen.

    Attributes:
        moderator (Moderator[StateT, ActionT]): The moderator used to evaluate sequences.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
    """

    def __init__(self, moderator: Moderator[StateT, ActionT]) -> None:
        # we check all asserts once, and then disable them: each check key maps
        # to True once its check has passed, so it is skipped on later calls
        self._disable_asserts: Dict[str, bool] = defaultdict(bool)
        self.moderator = moderator

    @abstractmethod
    def get_target_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on *model under test*.

        Args:
            context (Sequence[StateT]): Batch of contexts on which each
                continuation's probability is conditioned.
            continuation (Sequence[ActionT]): Batch of continuations whose
                probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The log probabilities of the continuations given their contexts.
        """

        pass

    @abstractmethod
    def get_baseline_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on *attacker's baseline distribution* for KL
           divergence measurements.

        Args:
            context (Sequence[StateT]): Batch of contexts on which each
                continuation's probability is conditioned.
            continuation (Sequence[ActionT]): Batch of continuations whose
                probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element. Note that this is *not* the defender's model, but
            rather the baseline model used for measuring KL divergence to make sure that
            the trained attacker stays an LM.

        Returns:
            torch.Tensor: The log probabilities of the continuations given their contexts.
        """

        pass

    @abstractmethod
    def get_attacker_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on *attacker*. This must return tensor w/ grads!

        Args:
            context (Sequence[StateT]): Batch of contexts on which each
                continuation's probability is conditioned.
            continuation (Sequence[ActionT]): Batch of continuations whose
                probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The log probabilities of the continuations given their contexts.
        """

        pass

    @abstractmethod
    def rollout_prompt_with_attacker(self, x: Sequence[StateT]) -> Sequence[ActionT]:
        """Rolls out the prompt with the attacker model. Do *not* return the prompt.

        a ~ \\pi(s)

        Args:
            x (Sequence[StateT]): Batch of prompts to be rolled out.

        Returns:
            Sequence[ActionT]: The attacker's continuation for each prompt, one
                per batch element, without the prompt itself.
        """
        pass

    @abstractmethod
    def rollout_prompt_with_target(self, x: Sequence[StateT]) -> Sequence[StateT]:
        """Rolls out the prompt with the model under test. Do *not* return the prompt.

        s' ~ \\sum_a T(s, a)

        Args:
            x (Sequence[StateT]): Batch of prompts to be rolled out.

        Returns:
            Sequence[StateT]: The model under test's continuation for each
                prompt, one per batch element, without the prompt itself.
        """
        pass

    @abstractmethod
    def advance(self, context: StateT, attack: ActionT, response: StateT) -> StateT:
        """Given a context and continuation, returns the next state.

        Args:
            context (StateT): The current state.
            attack (ActionT): The attacker's action given ``context``.
            response (StateT): The target's response to ``attack``.

        Returns:
            StateT: The next state after applying the attack and response to the context.
        """
        pass

    @abstractmethod
    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        """Return the trainable parameters in this problem.

        Returns:
            Iterator[torch.nn.parameter.Parameter]: An iterator over the trainable
                parameters, usually obtained by calling ``model.parameters()``.
        """
        pass

    @abstractmethod
    def reward(
        self,
        context: Sequence[StateT],
        attack: Sequence[ActionT],
        response: Sequence[StateT],
    ) -> Sequence[float]:
        """Score a batch of (context, attack, response) triples.

        Args:
            context (Sequence[StateT]): Batch of contexts.
            attack (Sequence[ActionT]): The attacker's action for each context.
            response (Sequence[StateT]): The target's response to each attack.

        Returns:
            Sequence[float]: The rewards for the batch; presumably one per batch
                element (TODO confirm against the trainer that consumes it).
        """
        pass

    ##### Utility methods for validation and checks #####

    def _check_continuation(
        self,
        check_key: str,
        context: Sequence[StateT],
        continuation: Sequence[Union[ActionT, StateT]],
    ) -> None:
        """Sanity-check a rollout's output once per ``check_key``.

        Args:
            check_key (str): Identifies this check; once it passes it is skipped
                on later calls.
            context (Sequence[StateT]): The batch of prompts that was rolled out.
            continuation (Sequence[Union[ActionT, StateT]]): The rollout's output.

        Raises:
            AssertionError: If the rollout did not return exactly one
                continuation per context (the batched contract above).
        """
        if self._disable_asserts[check_key]:
            return
        assert len(continuation) == len(context), (
            f"Rollout returned {len(continuation)} continuations for "
            f"{len(context)} contexts; expected one continuation per context."
        )
        self._disable_asserts[check_key] = True

    def _check_logprobs(
        self,
        check_key: str,
        logprobs: torch.Tensor,
        ctx_length: int,
        requires_grad: bool = False,
    ) -> None:
        """Sanity-check a logprob tensor once per ``check_key``.

        Args:
            check_key (str): Identifies this check; once it passes it is skipped
                on later calls. It also labels error messages, e.g.
                ``"target_logprobs"`` is reported as ``"Target logprobs"``.
            logprobs (torch.Tensor): The tensor returned by a ``get_*_logprobs`` method.
            ctx_length (int): The expected batch size, i.e. ``len(context)``.
            requires_grad (bool): Whether ``logprobs`` must carry gradients.

        Raises:
            AssertionError: If ``logprobs`` is not a tensor, lacks gradients when
                required, is not 1D/2D, or its first dimension is not ``ctx_length``.
        """
        if self._disable_asserts[check_key]:
            return
        # label messages by source so a failing target/baseline check is not
        # reported as an attacker problem
        label = check_key.replace("_", " ").capitalize()
        # check that logprobs is a tensor and has gradients
        assert isinstance(logprobs, torch.Tensor), f"{label} must be a torch.Tensor."
        if requires_grad:
            assert logprobs.requires_grad, (
                f"{label} must carry gradient information."
            )
        # accept a 1D (B,) tensor by viewing it as (B, 1); the view is local to
        # this check and is not returned to the caller
        if logprobs.dim() == 1:
            logprobs = logprobs.unsqueeze(1)
        assert logprobs.dim() == 2, (
            f"{label} must be a 1D tensor (B,) or a 2D tensor whose first "
            "dimension is the batch size B."
        )
        # check that the first dimension is the batch size
        assert logprobs.size(0) == ctx_length, (
            f"{label} must have the same batch size as the context."
        )
        # true log-probabilities are <= 0, so a non-empty tensor lying entirely in
        # [0, 1] was most likely built from probabilities; an empty batch is skipped
        # because .all() of an empty tensor is vacuously True
        if logprobs.numel() > 0 and bool(
            ((logprobs >= 0.0) & (logprobs <= 1.0)).all()
        ):
            logger.warning(
                f"{label} look suspiciously like probabilities, "
                "try taking the .log() of your tensor?"
            )
        self._disable_asserts[check_key] = True

    def _get_attacker_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Call ``get_attacker_logprobs`` and check its output (gradients required)."""
        logprobs = self.get_attacker_logprobs(context, continuation)
        self._check_logprobs(
            "attacker_logprobs", logprobs, len(context), requires_grad=True
        )
        return logprobs

    def _get_target_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Call ``get_target_logprobs`` and check its output."""
        logprobs = self.get_target_logprobs(context, continuation)
        self._check_logprobs(
            "target_logprobs", logprobs, len(context), requires_grad=False
        )
        return logprobs

    def _get_baseline_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Call ``get_baseline_logprobs`` and check its output."""
        logprobs = self.get_baseline_logprobs(context, continuation)
        self._check_logprobs(
            "baseline_logprobs", logprobs, len(context), requires_grad=False
        )
        return logprobs

    def _rollout_prompt_with_attacker_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[ActionT]:
        """Call ``rollout_prompt_with_attacker`` and check its output."""
        rolled_out = self.rollout_prompt_with_attacker(x)
        self._check_continuation("attacker_rollout", x, rolled_out)
        return rolled_out

    def _rollout_prompt_with_target_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[StateT]:
        """Call ``rollout_prompt_with_target`` and check its output."""
        rolled_out = self.rollout_prompt_with_target(x)
        self._check_continuation("target_rollout", x, rolled_out)
        return rolled_out
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc