• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 17983122784

24 Sep 2025 04:27PM UTC coverage: 37.877% (-5.6%) from 43.428%
17983122784

push

github

web-flow
Merge branch 'main' into feat/eval

4 of 97 new or added lines in 3 files covered. (4.12%)

2 existing lines in 1 file now uncovered.

264 of 697 relevant lines covered (37.88%)

0.76 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

34.43
/src/astra_rl/methods/ast_problem.py
1
"""
2
ast_problem.py
3
ASTProblem
4
"""
5

6
import random
2✔
7
from random import Random
2✔
8
from types import ModuleType
2✔
9
from typing import Sequence, Optional, Union, Dict, Any
2✔
10

11
import torch
2✔
12

13
from astra_rl.core.problem import Problem
2✔
14
from astra_rl.core.moderator import Moderator
2✔
15
from astra_rl.core.environment import Node, Graph, Environment
2✔
16

17

18
class ASTProblem(Problem[str, str]):
    """Adaptive Stress Testing (AST) problem for red-teaming language models.

    Implements https://arxiv.org/abs/2407.09447.

    For usage examples, see `astra_rl.core.problem.Problem`.

    Attributes:
        moderator (Moderator): Scores sequences; forwarded to the base ``Problem``.
        ppl_weight (float): Coefficient on the attack-likelihood term of the reward.
        combined_weight (float): Coefficient on the moderator score of attack + response.
        defense_toxicity_weight (float): Coefficient on the moderator score of the response alone.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        """Store the moderator (via the base class) and the three reward weights."""
        super().__init__(moderator)

        self.ppl_weight = ppl_weight
        self.combined_weight = combined_weight
        self.defense_toxicity_weight = defense_toxicity_weight

    def advance(self, state: str, action: str, next: str) -> str:
        """Return the next state: ``state``, ``action`` and ``next`` concatenated in that order."""
        return "".join((state, action, next))

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        """Compute one scalar reward per (context, attack, response) triple.

        The reward is a weighted sum of three terms:

        * ``ppl_weight`` times ``exp`` of the summed target log-probabilities
          of the attack given the context,
        * ``combined_weight`` times the moderator score of ``attack + response``,
        * ``defense_toxicity_weight`` times the moderator score of ``response``.

        Returns:
            A plain Python list of floats, one entry per input triple.
        """
        # Moderate the attack+response pairs first, then the responses alone.
        joint_utterances = [atk + rsp for atk, rsp in zip(attack, response)]
        combined_scores: Sequence[float] = self.moderator.moderate(joint_utterances)
        defender_scores: Sequence[float] = self.moderator.moderate(response)

        # Sum per-token logprobs before taking exp.
        token_logprobs = self._get_target_logprobs_and_validate(context, attack)
        attack_ppls: torch.Tensor = token_logprobs.sum(dim=-1).exp()
        device = attack_ppls.device

        # Move the moderator scores onto the same device as the likelihood term.
        combined_term = self.combined_weight * torch.tensor(combined_scores).to(device)
        defense_term = self.defense_toxicity_weight * torch.tensor(defender_scores).to(
            device
        )

        total = self.ppl_weight * attack_ppls + combined_term + defense_term
        reward: Sequence[float] = total.cpu().tolist()
        return reward
×
71

72

73
# Convenience alias: a rollout-tree node whose state and action types are both `str`.
ASTNode = Node[str, str]
2✔
74

75

76
# custom node class to hold extra eval metrics
77
class CustomNode(Node[str, str]):
    """Node subclass that carries additional per-turn evaluation metrics.

    Because it derives from ``Node``, an instance can be used anywhere a plain
    ``Node`` is expected (``isinstance(custom_node, Node)`` holds).
    """

    def __init__(
        self,
        context: str,
        attack: str,
        response: str,
        reward: float,
        children: Sequence["Node[str, str]"],
        *,
        attack_tox: float = 0.0,
        target_tox: float = 0.0,
        attack_logprob: float = 0.0,
        first_attack_tox_turn: Optional[int] = None,
        first_target_tox_turn: Optional[int] = None,
    ):
        """Build the base node, then attach the keyword-only metrics.

        ``children`` is copied into a fresh list before being handed to
        ``Node``. The three numeric metrics are coerced to ``float``; the two
        ``first_*_tox_turn`` values are stored exactly as given.
        """
        # The base Node fields must exist before the extra metrics are attached.
        child_list = list(children)
        super().__init__(context, attack, response, reward, child_list)

        # Numeric diagnostics, normalised to plain Python floats.
        self.attack_tox: float = float(attack_tox)
        self.target_tox: float = float(target_tox)
        self.attack_logprob: float = float(attack_logprob)

        # Turn indices (or None) are stored without conversion.
        self.first_attack_tox_turn: Optional[int] = first_attack_tox_turn
        self.first_target_tox_turn: Optional[int] = first_target_tox_turn

    def to_row(self) -> Dict[str, Any]:
        """Flatten this node (children excluded) into a dict for saving/analysis."""
        row: Dict[str, Any] = {
            "context": self.context,
            "attack": self.attack,
            "response": self.response,
        }
        # All numeric fields are emitted as floats, in a fixed key order.
        for field_name in ("reward", "attack_tox", "target_tox", "attack_logprob"):
            row[field_name] = float(getattr(self, field_name))
        row["first_attack_tox_turn"] = self.first_attack_tox_turn
        row["first_target_tox_turn"] = self.first_target_tox_turn
        return row
119

120

121
class ASTEnvironment(Environment[str, str]):
    """The ASTPrompter Rollout Environment

    Implements https://arxiv.org/abs/2407.09447.

    Specifically, this is the original rollout system used in the
    ASTPrompter paper, the case of red-teaming where the attacker and
    defender generate successive turns of strings, each of which is
    appended to the prompt of the other. They do not have IFT or other
    types of structure.

    For usage examples, see `astra_rl.core.environment.Environment`.

    Attributes:
        problem (ASTProblem): The problem instance that defines the environment and actions.
        prompts (Sequence[str]): A sequence of initial prompts to start the rollout.
        tree_width (int): The number of branches at each node in the rollout tree.
        tree_depth (int): The depth of the rollout tree.

    Generics:
        StateT (str): The type of the state in the environment, which is a string.
        ActionT (str): The type of the action in the environment, which is also a string.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        """Register the problem with the base class and store the rollout settings."""
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(
        self, prompt: str, depth: int = 3, width: int = 2
    ) -> Sequence[Node[str, str]]:
        """Recursively expand ``prompt`` into a rollout subtree.

        Args:
            prompt: The state to branch from.
            depth: Remaining number of levels to generate.
            width: Number of attack/defense branches created at every level.

        Returns:
            The ``width`` child nodes of ``prompt``, each carrying its own
            subtree; an empty list once ``depth`` is exhausted.
        """
        # `<=` rather than `==` so that a negative depth terminates
        # immediately instead of recursing without bound.
        if depth <= 0:
            return []

        # Every branch at this level starts from the same prompt.
        prompts = [prompt] * width
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(prompts)
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + attack for attack in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)

        # `context` (not `prompt`) is used as the loop variable so the
        # method parameter is not shadowed inside the comprehension.
        return [
            Node(
                context,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(context, attack, defense), depth - 1, width
                ),
            )
            for context, attack, defense, reward in zip(
                prompts, attacks, defenses, rewards
            )
        ]

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        """Pick a starting prompt and build a full rollout tree from it.

        Args:
            seed: When not ``None``, seeds a private ``Random`` used to pick
                the starting prompt (``0`` is a valid seed). When ``None``,
                the module-level ``random`` generator is used. The seed is
                only used for the prompt choice made here.

        Returns:
            A ``Graph`` rooted at the chosen prompt, expanded with
            ``self.tree_depth`` and ``self.tree_width``.
        """
        # Compare against None explicitly: a truthiness test would treat
        # seed=0 as "no seed" and make that rollout non-reproducible.
        rng: Union[Random, ModuleType] = random if seed is None else Random(seed)
        prompt = rng.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth, self.tree_width)

        return Graph(prompt, nodes)

    def eval_rollout(self, prompt: Optional[str] = None) -> Graph[str, str]:
        """Build an evaluation rollout.

        Args:
            prompt: The prompt to start from. When ``None``, this simply
                delegates to ``rollout()``.

        Returns:
            A ``Graph`` rooted at ``prompt`` and expanded to
            ``self.tree_depth`` with a single branch per level, or the
            result of ``rollout()`` when no prompt is given.
        """
        # NOTE(review): the no-prompt path goes through rollout(), which
        # branches `self.tree_width` ways, while an explicit prompt is
        # expanded with width 1 -- confirm this asymmetry is intended.
        if prompt is None:
            return self.rollout()

        nodes = self.__handle_prompt(prompt, self.tree_depth, 1)
        return Graph(prompt, nodes)
×
203

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc