• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 17387194991

01 Sep 2025 09:06PM UTC coverage: 40.99% (-0.8%) from 41.752%
17387194991

push

github

alliegriffith
adding hf_ast trainer class plus updating ASTEnvironment

3 of 16 new or added lines in 1 file covered. (18.75%)

207 of 505 relevant lines covered (40.99%)

0.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

31.67
/src/astra_rl/methods/ast_problem.py
1
"""
2
ast_problem.py
3
ASTProblem
4
"""
5

6
import random
2✔
7
from random import Random
2✔
8
from types import ModuleType
2✔
9
from typing import Sequence, Optional, Union
2✔
10

11
import torch
2✔
12

13
from astra_rl.core.problem import Problem
2✔
14
from astra_rl.core.moderator import Moderator
2✔
15
from astra_rl.core.environment import Node, Graph, Environment
2✔
16

17

18
class ASTProblem(Problem[str, str]):
    """Adaptive Stress Testing (AST) problem for red-teaming large language models.

    Implements https://arxiv.org/abs/2407.09447.

    For usage examples, see `astra_rl.core.problem.Problem`.

    Attributes:
        moderator (Moderator): Scores text. It is applied both to the
            attack+response concatenation and to the response alone.
        ppl_weight (float): Weight for the attack "perplexity" term, computed
            as ``exp`` of the target log-probabilities of the attack.
        combined_weight (float): Weight for the moderator score of attack + response.
        defense_toxicity_weight (float): Weight for the moderator score of the
            response on its own.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        super().__init__(moderator)

        # Mixing coefficients of the three reward terms (see `reward`).
        self.ppl_weight = ppl_weight
        self.combined_weight = combined_weight
        self.defense_toxicity_weight = defense_toxicity_weight

    def advance(self, state: str, action: str, next: str) -> str:
        """Return the new conversation state: the old state, then the attack, then the reply."""
        with_attack = state + action
        return with_attack + next

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        """Compute one scalar reward per (context, attack, response) triple.

        For each index ``i``::

            reward_i = ppl_weight * exp(target_logprob(context_i, attack_i))
                       + combined_weight * moderate(attack_i + response_i)
                       + defense_toxicity_weight * moderate(response_i)

        Returns:
            A plain Python list of floats, moved to the CPU.
        """
        # Pair each attack with the reply it elicited. `zip` truncates to the
        # shorter of the two sequences.
        conversations = [atk + rsp for atk, rsp in zip(attack, response)]
        conversation_scores: Sequence[float] = self.moderator.moderate(conversations)
        response_scores: Sequence[float] = self.moderator.moderate(response)

        log_probs = self._get_target_logprobs_and_validate(context, attack)
        ppls: torch.Tensor = log_probs.exp()
        device = ppls.device

        def on_device(values: Sequence[float]) -> torch.Tensor:
            # Moderator scores arrive as Python sequences. Move them next to
            # the log-probs so the weighted sum runs on a single device.
            return torch.tensor(values).to(device)

        weighted = (
            self.ppl_weight * ppls
            + self.combined_weight * on_device(conversation_scores)
            + self.defense_toxicity_weight * on_device(response_scores)
        )
        return weighted.cpu().tolist()
71

72

73
# Alias for Node[str, str]. ASTEnvironment builds nodes of this type, and
# both the state and the action are plain strings.
ASTNode = Node[str, str]
2✔
74

75

76
class ASTEnvironment(Environment[str, str]):
    """The ASTPrompter rollout environment.

    Implements https://arxiv.org/abs/2407.09447.

    This is the original rollout system used in the ASTPrompter paper. In this
    red-teaming setup, the attacker and defender generate successive turns of
    text, and each turn is appended to the prompt of the other. The turns have
    no instruction-following (IFT) or other structure.

    For usage examples, see `astra_rl.core.environment.Environment`.

    Attributes:
        problem (ASTProblem): The problem instance that defines the environment and actions.
        prompts (Sequence[str]): Initial prompts. One is chosen at random per rollout.
        tree_width (int): Number of branches expanded at each node of the rollout tree.
        tree_depth (int): Depth of the rollout tree, i.e. the number of attacker/defender turns.

    Generics:
        StateT (str): The type of the state in the environment, which is a string.
        ActionT (str): The type of the action in the environment, which is also a string.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(
        self, prompt: str, depth: int = 3, width: Optional[int] = None
    ) -> Sequence[Node[str, str]]:
        """Recursively expand ``prompt`` into a rollout subtree.

        Args:
            prompt: Conversation state to branch from.
            depth: Number of remaining turns to expand.
            width: Branching factor at each level. Defaults to ``self.tree_width``.

        Returns:
            ``width`` nodes for this level, each carrying its own expanded
            subtree. Returns ``[]`` when ``depth`` or ``width`` is not positive.
        """
        if width is None:
            width = self.tree_width

        # Use `<= 0` rather than `== 0`. A negative depth (e.g. a misconfigured
        # tree_depth) would otherwise recurse until RecursionError. A
        # non-positive width has nothing to expand, so skip the model calls.
        if depth <= 0 or width <= 0:
            return []

        # Every branch starts from the same state. Branches presumably diverge
        # through sampling inside the attacker/target rollouts.
        prompts = [prompt] * width
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(prompts)
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + attack for attack in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)

        return [
            Node(
                branch_prompt,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(branch_prompt, attack, defense),
                    depth - 1,
                    width,
                ),
            )
            for branch_prompt, attack, defense, reward in zip(
                prompts, attacks, defenses, rewards
            )
        ]

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        """Sample a starting prompt and expand a full rollout tree from it.

        Args:
            seed: If given, a private ``random.Random(seed)`` chooses the prompt,
                so the choice is reproducible. Otherwise the global ``random``
                module is used.

        Returns:
            A ``Graph`` rooted at the chosen prompt.
        """
        # Use `is None`, not truthiness, so that seed=0 is honored instead of
        # silently falling back to the global RNG.
        rng: Union[Random, ModuleType] = random if seed is None else Random(seed)
        prompt = rng.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth)

        return Graph(prompt, nodes)

    def eval_rollout(self, prompt: str) -> Graph[str, str]:
        """Build a single-path (width 1) rollout from ``prompt`` for evaluation."""
        nodes = self.__handle_prompt(prompt, self.tree_depth, width=1)
        return Graph(prompt, nodes)

    def final_reward(self, graph: "Graph[str, str]") -> Optional[float]:
        """Return the reward at the leaf reached by always following the first child.

        For graphs built by ``eval_rollout`` (width 1), this is the reward at the
        end of the single path. Returns ``None`` if ``graph`` has no children.
        """
        last: Optional[float] = None
        children = graph.children
        while children:
            node = children[0]
            last = node.reward
            children = node.children
        return last
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc