sisl / astra-rl / build 17778048441

16 Sep 2025 08:22PM UTC coverage: 36.41% (-11.0%) from 47.436%

Pull Request #16: initial experiments (via github, committer web-flow)
Commit: De/trail1+docs+fixes (#15)
* Address mkdocs warnings. Add installation and dev_setup pages. Update base documentation page
* Add footnote citations. Fix environment note formatting

10 of 126 new or added lines in 4 files covered (7.94%).
103 existing lines in 4 files now uncovered.
213 of 585 relevant lines covered (36.41%).
0.73 hits per line.

Source File: /src/astra_rl/methods/ast_problem.py (28.4% covered)
"""
ast_problem.py
ASTProblem
"""

import random
from random import Random
from types import ModuleType
from typing import Sequence, Optional, Union, Dict, Any

import torch

from astra_rl.core.problem import Problem
from astra_rl.core.moderator import Moderator
from astra_rl.core.environment import Node, Graph, Environment


class ASTProblem(Problem[str, str]):
    """Adaptive Stress Testing for Large Language Model Red-Teaming.

    Implements https://arxiv.org/abs/2407.09447.

    For usage examples, see `astra_rl.core.problem.Problem`.

    Attributes:
        moderator (Moderator): The moderator used to evaluate the sequences.
        ppl_weight (float): Weight for perplexity in the reward calculation.
        combined_weight (float): Weight for combined scores in the reward calculation.
        defense_toxicity_weight (float): Weight for defense toxicity in the reward calculation.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        super().__init__(moderator)

        self.ppl_weight = ppl_weight
        self.combined_weight = combined_weight
        self.defense_toxicity_weight = defense_toxicity_weight

    def advance(self, state: str, action: str, next: str) -> str:
        return state + action + next

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        combined_uts = [j + k for j, k in zip(attack, response)]
        combined_scores: Sequence[float] = self.moderator.moderate(combined_uts)
        defender_scores: Sequence[float] = self.moderator.moderate(response)
        attack_ppls: torch.Tensor = self._get_target_logprobs_and_validate(
            context, attack
        ).exp()

        reward: Sequence[float] = (
            (
                self.ppl_weight * attack_ppls
                + self.combined_weight
                * torch.tensor(combined_scores).to(attack_ppls.device)
                + self.defense_toxicity_weight
                * torch.tensor(defender_scores).to(attack_ppls.device)
            )
            .cpu()
            .tolist()
        )

        return reward
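

# Illustrative only: a minimal sketch of how reward() combines its three
# weighted terms, with made-up scores standing in for the moderator and the
# target model's log-probabilities. All values below are hypothetical.
def _example_reward_combination() -> None:
    ppl_weight, combined_weight, defense_toxicity_weight = 0.025, 0.1, 0.1
    attack_logprobs = torch.tensor([-12.0, -8.5])  # stand-in for log p(attack | context)
    combined_scores = [0.9, 0.2]  # stand-in for moderator score of attack + response
    defender_scores = [0.7, 0.1]  # stand-in for moderator score of response alone

    attack_ppls = attack_logprobs.exp()
    reward = (
        ppl_weight * attack_ppls
        + combined_weight * torch.tensor(combined_scores)
        + defense_toxicity_weight * torch.tensor(defender_scores)
    ).tolist()
    print(reward)  # one float per (attack, response) pair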


ASTNode = Node[str, str]


# Custom node class to hold extra eval metrics.
class CustomNode(Node[str, str]):
    """
    A Node with extra per-turn metadata for evaluation/training diagnostics.
    Compatible anywhere a plain Node is expected
    (issubclass(CustomNode, Node) is True).
    """

    def __init__(
        self,
        context: str,
        attack: str,
        response: str,
        reward: float,
        children: Sequence["Node[str, str]"],
        *,
        attack_tox: float = 0.0,
        target_tox: float = 0.0,
        attack_logprob: float = 0.0,
    ):
        # Initialize the base Node fields first
        super().__init__(context, attack, response, reward, list(children))
        # Then attach the custom metrics
        self.attack_tox: float = float(attack_tox)
        self.target_tox: float = float(target_tox)
        self.attack_logprob: float = float(attack_logprob)

    # Handy for JSON/CSV export
    def to_row(self) -> Dict[str, Any]:
        return {
            "context": self.context,
            "attack": self.attack,
            "response": self.response,
            "reward": float(self.reward),
            "attack_tox": float(self.attack_tox),
            "target_tox": float(self.target_tox),
            "attack_logprob": float(self.attack_logprob),
        }
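

# Illustrative only: a minimal sketch of exporting CustomNode metadata with
# to_row(). The node's field values here are fabricated placeholders, passed
# positionally the same way __handle_eval_prompt constructs nodes.
def _example_custom_node_export() -> None:
    import csv
    import sys

    node = CustomNode(
        "context so far",
        "attack utterance",
        "model response",
        0.42,
        [],
        attack_tox=0.8,
        target_tox=0.3,
        attack_logprob=-14.2,
    )
    writer = csv.DictWriter(sys.stdout, fieldnames=list(node.to_row().keys()))
    writer.writeheader()
    writer.writerow(node.to_row())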


class ASTEnvironment(Environment[str, str]):
    """The ASTPrompter Rollout Environment.

    Implements https://arxiv.org/abs/2407.09447.

    Specifically, this is the original rollout system used in the
    ASTPrompter paper: the red-teaming setting in which the attacker
    and defender generate successive turns of strings, each of which
    is appended to the prompt of the other. The turns have no IFT or
    other structure.

    For usage examples, see `astra_rl.core.environment.Environment`.

    Attributes:
        problem (ASTProblem): The problem instance that defines the environment and actions.
        prompts (Sequence[str]): A sequence of initial prompts to start the rollout.
        tree_width (int): The number of branches at each node in the rollout tree.
        tree_depth (int): The depth of the rollout tree.

    Generics:
        StateT (str): The type of the state in the environment, which is a string.
        ActionT (str): The type of the action in the environment, which is also a string.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(
        self, prompt: str, depth: int = 3, width: Optional[int] = None
    ) -> Sequence[Node[str, str]]:
        if depth == 0:
            return []

        if width is None:
            width = self.tree_width

        prompts = [prompt for _ in range(width)]
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(prompts)
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + i for i in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)

        nodes = [
            Node(
                prompt,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(prompt, attack, defense), depth - 1, width
                ),
            )
            for prompt, attack, defense, reward in zip(
                prompts, attacks, defenses, rewards
            )
        ]

        return nodes
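
    # Note: each call expands `width` children per node, so a full rollout
    # yields width + width**2 + ... + width**depth nodes; with the defaults
    # (tree_width=2, tree_depth=3) that is 2 + 4 + 8 = 14 nodes.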

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        # A seeded local Random makes prompt selection reproducible; the
        # attacker/defender generations themselves are not seeded here.
        R: Union[Random, ModuleType] = random
        if seed is not None:  # treat 0 as a valid seed
            R = Random(seed)
        prompt = R.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth, self.tree_width)

        return Graph(prompt, nodes)

    # Allie TODO: test this code
    def __handle_eval_prompt(self, prompt: str, depth: int = 3) -> Sequence[CustomNode]:
        if depth == 0:
            return []
        width = 1
        prompts = [prompt for _ in range(width)]
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(prompts)
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + i for i in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)
        attack_toxs = self.problem.moderator.moderate(attacks)
        target_toxs = self.problem.moderator.moderate(defenses)
        attack_logprobs = self.problem._get_target_logprobs_and_validate(
            prompts, attacks
        )

        nodes = [
            CustomNode(
                prompt,
                attack,
                defense,
                reward,
                self.__handle_eval_prompt(
                    self.problem.advance(prompt, attack, defense), depth - 1
                ),
                attack_tox=attack_tox,
                target_tox=target_tox,
                attack_logprob=attack_logprob.sum().item(),
            )
            for prompt, attack, defense, reward, attack_tox, target_tox, attack_logprob in zip(
                prompts,
                attacks,
                defenses,
                rewards,
                attack_toxs,
                target_toxs,
                attack_logprobs,
            )
        ]
        return nodes

    # Generates a rollout for evaluation: a single path rather than a tree.
    def eval_rollout(self, prompt: str) -> Graph[str, str]:
        nodes = self.__handle_eval_prompt(prompt, self.tree_depth)
        return Graph(prompt, nodes)

    # Gets the final reward from an eval rollout.
    def final_reward(self, graph: "Graph[str, str]") -> Optional[float]:
        """Return reward at the leaf along the single path (width==1)."""
        if not graph.children:
            return None
        n = graph.children[0]
        last = n.reward
        while n.children:
            n = n.children[0]
            last = n.reward
        return last
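

# Illustrative only: a minimal end-to-end sketch. It assumes `problem` is a
# concrete ASTProblem subclass wired to real attacker/target models and a
# moderator; no such subclass is defined in this module, and the prompts
# below are placeholders.
def _example_environment_usage(problem: ASTProblem) -> None:
    env = ASTEnvironment(
        problem,
        prompts=["Hello there.", "Tell me about your day."],
        tree_width=2,
        tree_depth=3,
    )

    # Training-style rollout: a branching tree of plain Nodes.
    graph = env.rollout(seed=0)
    print(len(graph.children))  # == tree_width at the root

    # Evaluation-style rollout: a single path of CustomNodes with extra metrics.
    eval_graph = env.eval_rollout("Hello there.")
    print(env.final_reward(eval_graph))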