• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 17875546536

20 Sep 2025 04:53AM UTC coverage: 43.428% (-1.6%) from 44.983%
17875546536

push

github

Jemoka
evaluation rollout procedure + implementation

2 of 29 new or added lines in 2 files covered. (6.9%)

1 existing line in 1 file now uncovered.

261 of 601 relevant lines covered (43.43%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

35.29
/src/astra_rl/methods/ast_problem.py
1
"""ASTPrompter problem and rollout environment.

Defines `ASTProblem`, the adaptive-stress-testing red-teaming problem whose
reward combines an attack-likelihood term with moderator scores
(https://arxiv.org/abs/2407.09447), and `ASTEnvironment`, the tree-shaped
rollout that alternates attacker and defender turns.
"""
5

6
import random
2✔
7
from random import Random
2✔
8
from types import ModuleType
2✔
9
from typing import Sequence, Optional, Union
2✔
10

11
import torch
2✔
12

13
from astra_rl.core.problem import Problem
2✔
14
from astra_rl.core.moderator import Moderator
2✔
15
from astra_rl.core.environment import Node, Graph, Environment
2✔
16

17

18
class ASTProblem(Problem[str, str]):
    """Adaptive Stress Testing (AST) problem for red-teaming large language models.

    Implements https://arxiv.org/abs/2407.09447.

    For usage examples, see `astra_rl.core.problem.Problem`.

    Attributes:
        moderator (Moderator): Scores text sequences. It is applied both to the
            concatenated attack+response and to the response alone.
        ppl_weight (float): Weight on the attack-likelihood term of the reward.
        combined_weight (float): Weight on the moderator score of attack+response.
        defense_toxicity_weight (float): Weight on the moderator score of the response.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        """Store the moderator on the base problem and record the reward weights."""
        super().__init__(moderator)

        # Coefficients of the three reward terms combined in `reward`.
        self.ppl_weight = ppl_weight
        self.combined_weight = combined_weight
        self.defense_toxicity_weight = defense_toxicity_weight

    def advance(self, state: str, action: str, next: str) -> str:
        """Return the successor state: ``state``, then ``action``, then ``next``."""
        return state + action + next

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        """Compute one scalar reward per (context, attack, response) item.

        The reward for each item is::

            ppl_weight * exp(sum(target_logprobs(attack | context)))
            + combined_weight * moderate(attack + response)
            + defense_toxicity_weight * moderate(response)

        Args:
            context: The prompt that preceded each attack.
            attack: The attacker utterances.
            response: The defender utterances.

        Returns:
            A Python list of floats, one per item.
        """
        joined = [a + r for a, r in zip(attack, response)]
        joined_scores: Sequence[float] = self.moderator.moderate(joined)
        response_scores: Sequence[float] = self.moderator.moderate(response)

        # Per-token log-probs are summed over the last dim before exponentiating.
        # NOTE(review): exp(sum(logprobs)) is the likelihood of the whole attack,
        # not a perplexity (exp of the negative mean). The weight is still named
        # `ppl_weight`; confirm this is the intended term.
        token_logprobs = self._get_target_logprobs_and_validate(context, attack)
        likelihood: torch.Tensor = token_logprobs.sum(dim=-1).exp()
        device = likelihood.device

        # Accumulate left to right so floating-point association is (a + b) + c.
        total = self.ppl_weight * likelihood
        total = total + self.combined_weight * torch.tensor(joined_scores).to(device)
        total = total + (
            self.defense_toxicity_weight * torch.tensor(response_scores).to(device)
        )

        return total.cpu().tolist()
71

72

73
# Alias for the rollout-tree node type used by ASTEnvironment, whose state and
# action types are both plain strings.
ASTNode = Node[str, str]
74

75

76
class ASTEnvironment(Environment[str, str]):
    """The ASTPrompter rollout environment.

    Implements https://arxiv.org/abs/2407.09447.

    This is the original rollout system used in the ASTPrompter paper. It covers
    the red-teaming case where the attacker and the defender generate successive
    turns of strings, each of which is appended to the prompt of the other. The
    turns have no instruction-following (IFT) or other structure.

    For usage examples, see `astra_rl.core.environment.Environment`.

    Attributes:
        problem (ASTProblem): The problem instance that defines the environment and actions.
        prompts (Sequence[str]): A sequence of initial prompts to start the rollout.
        tree_width (int): The number of branches at each node in the rollout tree.
        tree_depth (int): The depth of the rollout tree.

    Generics:
        StateT (str): The type of the state in the environment, which is a string.
        ActionT (str): The type of the action in the environment, which is also a string.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(
        self, prompt: str, depth: int = 3, width: int = 2
    ) -> Sequence[Node[str, str]]:
        """Recursively expand the rollout subtree rooted at ``prompt``.

        The method samples ``width`` attacker continuations of ``prompt`` and
        rolls each one out against the target (defender). It scores every
        (attack, defense) pair with ``self.problem.reward``, then recurses on the
        advanced state for ``depth - 1`` further levels.

        Args:
            prompt: The conversation state to expand from.
            depth: Number of levels still to expand. A value <= 0 yields no nodes.
            width: Number of attacker continuations sampled at this level. A
                value <= 0 yields no nodes.

        Returns:
            One ``Node`` per sampled continuation, each carrying its own children.
        """
        # Use `<= 0` rather than `== 0`: a negative depth would otherwise recurse
        # (querying the models at every level) until RecursionError. A zero
        # width would otherwise call the models with an empty batch.
        if depth <= 0 or width <= 0:
            return []

        prompts = [prompt for _ in range(width)]
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(prompts)
        # The defender sees the prompt with the attacker's turn appended.
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + attack for attack in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)

        return [
            Node(
                state,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(state, attack, defense), depth - 1, width
                ),
            )
            for state, attack, defense, reward in zip(
                prompts, attacks, defenses, rewards
            )
        ]

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        """Pick a starting prompt and build a full rollout tree from it.

        Args:
            seed: Optional seed for choosing the starting prompt. When it is given
                (including ``0``), a private ``random.Random(seed)`` makes the
                choice reproducible. Otherwise the global ``random`` module is
                used. The seed only affects prompt selection here.

        Returns:
            A ``Graph`` rooted at the chosen prompt, ``self.tree_depth`` levels
            deep with ``self.tree_width`` branches per node.
        """
        rng: Union[Random, ModuleType] = random
        # Compare against None so that seed=0 is honored instead of silently
        # falling back to the unseeded global RNG.
        if seed is not None:
            rng = Random(seed)
        prompt = rng.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth, self.tree_width)

        return Graph(prompt, nodes)

    def eval_rollout(self, prompt: Optional[str] = None) -> Graph[str, str]:
        """Build an evaluation rollout: a single chain of turns from ``prompt``.

        Args:
            prompt: The starting prompt. When ``None``, this delegates to
                ``self.rollout()``, which picks a random prompt and uses the full
                ``self.tree_width`` branching.

        Returns:
            A ``Graph`` rooted at ``prompt`` with one branch per level and
            ``self.tree_depth`` levels.
        """
        if prompt is None:
            # NOTE(review): this fallback yields a training-style tree
            # (width = self.tree_width), not a width-1 chain. Confirm it is intended.
            return self.rollout()

        nodes = self.__handle_prompt(prompt, self.tree_depth, 1)
        return Graph(prompt, nodes)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc