• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 16859565540

05 Aug 2025 09:12PM UTC coverage: 44.276%. Remained the same
16859565540

push

github

web-flow
Merge pull request #8 from sisl/de/docstrings

Add API Reference Docstring Implementation

205 of 463 relevant lines covered (44.28%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.44
/src/astra_rl/core/problem.py
1
"""
2
A "Problem" is one of the core abstractions in Astra RL, defining how to interact
3
with the system under test. The interface is defined by the `Problem` class, which
4
defines a set of abstract methods that users must implement to create a custom problem.
5
This provides flexibility in terms of how users can define their own applications
6
while still adhering to a common interface that enables the Astra RL framework
7
to function correctly.
8
"""
9

10
from abc import ABC, abstractmethod
2✔
11
from collections import defaultdict
2✔
12
from typing import Sequence, Dict, Generic, Union, Iterator
2✔
13

14
import torch
2✔
15

16
from astra_rl.logging import logger
2✔
17
from astra_rl.core.moderator import Moderator
2✔
18
from astra_rl.core.common import StateT, ActionT
2✔
19

20

21
class Problem(ABC, Generic[StateT, ActionT]):
    """Defines the core problem interface for Astra RL.

    This class is responsible for defining how exactly to interact
    with the system under test---with generics in terms of how to get
    probabilities and rollouts from the attacker and target models.

    This allows for us to be generic over the types of states, actions
    as well as how to measure them. We ask for a moderator as a way to
    ensure that subclasses can all be generic over the exact metric, and
    instead can only be opinionated about how to achieve the metric.

    Attributes:
        moderator (Moderator[StateT, ActionT]): The moderator used to evaluate sequences.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
    """

    def __init__(self, moderator: Moderator[StateT, ActionT]) -> None:
        # Each validation check is run once per key; after the first run the
        # key is flipped to True here and later calls short-circuit.
        self._disable_asserts: Dict[str, bool] = defaultdict(bool)
        self.moderator = moderator

    @abstractmethod
    def get_target_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on *model under test*.

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context on which the
                                 continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                      probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The log probabilities of the continuations given their contexts.
        """

        pass

    @abstractmethod
    def get_baseline_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on *attacker's baseline distribution* for KL
           divergence measurements.

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context on which the
                                 continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                      probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element. Note that this is *not* the defender's model, but
            rather the baseline model used for measuring KL divergence to make sure that
            the trained attacker stays an LM.

        Returns:
            torch.Tensor: The log probabilities of the continuations given their contexts.
        """

        pass

    @abstractmethod
    def get_attacker_logprobs(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Evaluates P(continuation|context) on *attacker*. This must return tensor w/ grads!

        Args:
            context (Sequence[str]): Sequence of strings, where each string is a context on which the
                                 continuation's probability is conditioned.
            continuation (Sequence[str]): Sequence of strings, where each string is a continuation whose
                                      probability is measured.

        Note:
            This should be batched; i.e., len(context) == len(continuation) and each
            represents a batch element.

        Returns:
            torch.Tensor: The log probabilities of the continuations given their contexts.
        """

        pass

    @abstractmethod
    def rollout_prompt_with_attacker(self, x: Sequence[StateT]) -> Sequence[ActionT]:
        """Rolls out the prompt with the attacker model. Do *not* return the prompt.

        a ~ \\pi(s)

        Args:
            x (Sequence[str]): Sequence of strings representing the prompt to be rolled out.

        Returns:
            Sequence[str]: The rolled out prompt with the adversary model.
        """
        pass

    @abstractmethod
    def rollout_prompt_with_target(self, x: Sequence[StateT]) -> Sequence[StateT]:
        """Rolls out the prompt with the model under test. Do *not* return the prompt.

        s' ~ \\sum_a T(s, a)

        Args:
            x (Sequence[str]): Sequence of strings representing the prompt to be rolled out.

        Returns:
            Sequence[str]: The rolled out prompt with the model under test.
        """
        pass

    @abstractmethod
    def advance(self, context: StateT, attack: ActionT, response: StateT) -> StateT:
        """Given a context and continuation, returns the next state.

        Args:
            context (str): Sequence of strings representing the context.
            attack (str): Sequence of strings representing the attack given context.
            response (str): Sequence of strings representing the defense against attack.

        Returns:
                str: The next state after applying the continuation to the context.
        """
        pass

    @abstractmethod
    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        """Return the trainable parameters in this problem.

        Returns:
            Iterator[torch.nn.parameter.Parameter]: An iterator over the trainable parameters.
            usually just by calling model.parameters()
        """
        pass

    @abstractmethod
    def reward(
        self,
        context: Sequence[StateT],
        attack: Sequence[ActionT],
        response: Sequence[StateT],
    ) -> Sequence[float]:
        """Computes the reward for a batch of (context, attack, response) triples.

        Args:
            context (Sequence[StateT]): Batch of contexts preceding each attack.
            attack (Sequence[ActionT]): Batch of attacks produced given each context.
            response (Sequence[StateT]): Batch of target-model responses to each attack.

        Note:
            This should be batched; i.e., len(context) == len(attack) == len(response)
            and each index represents one batch element.

        Returns:
            Sequence[float]: One reward per batch element.
        """
        pass

    ##### Utility methods for validation and checks #####

    def _check_continuation(
        self,
        check_key: str,
        context: Sequence[StateT],
        continuation: Sequence[Union[ActionT, StateT]],
    ) -> None:
        """One-shot validation hook for rollout continuations.

        NOTE(review): this currently performs no actual validation of
        `context`/`continuation` -- it only records that the check ran so
        subsequent calls with the same `check_key` return early. Presumably
        a placeholder for future rollout sanity checks; confirm intent.
        """
        if self._disable_asserts[check_key]:
            return
        self._disable_asserts[check_key] = True

    def _check_logprobs(
        self,
        check_key: str,
        logprobs: torch.Tensor,
        ctx_length: int,
        requires_grad: bool = False,
    ) -> None:
        """Validates the shape/type/grad of a logprobs tensor, once per `check_key`.

        Args:
            check_key (str): Identifies which logprobs source is being checked
                (e.g. "attacker_logprobs"); used both to run the check only once
                and to label diagnostic messages.
            logprobs (torch.Tensor): The tensor returned by the user's
                get_*_logprobs implementation.
            ctx_length (int): Expected batch size (len of the context batch).
            requires_grad (bool): Whether the tensor must carry gradients
                (True only for the attacker's logprobs, which are trained).
        """
        if self._disable_asserts[check_key]:
            return
        # check that logprobs is a tensor and (if required) has gradients.
        # Messages are keyed on check_key so failures from the target/baseline
        # paths are not mislabeled as attacker failures.
        assert isinstance(logprobs, torch.Tensor), (
            f"{check_key} must be a torch.Tensor."
        )
        if requires_grad:
            assert logprobs.requires_grad, (
                f"{check_key} must carry gradient information."
            )
        # check that the size of the tensor is B x 1, where B is the batch size
        if logprobs.dim() == 1:
            logprobs = logprobs.unsqueeze(1)
        assert logprobs.dim() == 2, (
            f"{check_key} must be a 2D tensor (B, 1) or a 1D list of numbers (B,)."
        )
        # check that the first dimension is the batch size
        assert logprobs.size(0) == ctx_length, (
            f"{check_key} must have the same batch size as the context."
        )
        # warn if everything is between 0 and 1: log probabilities should be
        # <= 0, so an all-[0, 1] tensor is probably raw probabilities.
        if ((logprobs >= 0.0) & (logprobs <= 1.0)).all():
            logger.warning(
                f"{check_key} *log*probs looks suspiciously like probabilities, "
                "try taking the .log() of your tensor?"
            )
        self._disable_asserts[check_key] = True

    def _get_attacker_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Calls get_attacker_logprobs and (once) validates the result; grads required."""
        logprobs = self.get_attacker_logprobs(context, continuation)
        self._check_logprobs("attacker_logprobs", logprobs, len(context), True)
        return logprobs

    def _get_target_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Calls get_target_logprobs and (once) validates the result."""
        logprobs = self.get_target_logprobs(context, continuation)
        self._check_logprobs("target_logprobs", logprobs, len(context), False)
        return logprobs

    def _get_baseline_logprobs_and_validate(
        self, context: Sequence[StateT], continuation: Sequence[ActionT]
    ) -> torch.Tensor:
        """Calls get_baseline_logprobs and (once) validates the result."""
        logprobs = self.get_baseline_logprobs(context, continuation)
        self._check_logprobs("baseline_logprobs", logprobs, len(context), False)
        return logprobs

    def _rollout_prompt_with_attacker_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[ActionT]:
        """Calls rollout_prompt_with_attacker and runs the (one-shot) continuation check."""
        rolled_out = self.rollout_prompt_with_attacker(x)
        self._check_continuation("attacker_rollout", x, rolled_out)
        return rolled_out

    def _rollout_prompt_with_target_and_validate(
        self, x: Sequence[StateT]
    ) -> Sequence[StateT]:
        """Calls rollout_prompt_with_target and runs the (one-shot) continuation check."""
        rolled_out = self.rollout_prompt_with_target(x)
        self._check_continuation("target_rollout", x, rolled_out)
        return rolled_out
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc