• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 17204280202

19 Aug 2025 11:37PM UTC coverage: 43.358% (-0.9%) from 44.276%
17204280202

push

github

web-flow
Merge pull request #12 from sisl/feat/ppo

PPO

32 of 88 new or added lines in 6 files covered. (36.36%)

1 existing line in 1 file now uncovered.

235 of 542 relevant lines covered (43.36%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

36.96
/src/astra_rl/methods/ast_problem.py
1
"""
2
ast_problem.py
3
ASTProblem
4
"""
5

6
import random
2✔
7
from random import Random
2✔
8
from types import ModuleType
2✔
9
from typing import Sequence, Optional, Union
2✔
10

11
import torch
2✔
12

13
from astra_rl.core.problem import Problem
2✔
14
from astra_rl.core.moderator import Moderator
2✔
15
from astra_rl.core.environment import Node, Graph, Environment
2✔
16

17

18
class ASTProblem(Problem[str, str]):
    """Adaptive Stress Testing (AST) problem for red-teaming large language models.

    Implements https://arxiv.org/abs/2407.09447.

    For usage examples, see `astra_rl.core.problem.Problem`.

    Attributes:
        moderator (Moderator): Handed to the base class; `reward` calls its
            `moderate` method to score the attack+response text and the
            response text.
        ppl_weight (float): Coefficient on the attack-likelihood term of the reward.
        combined_weight (float): Coefficient on the moderator score of the
            concatenated attack and response.
        defense_toxicity_weight (float): Coefficient on the moderator score of
            the response alone.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        super().__init__(moderator)

        # Mixing coefficients for the three reward terms computed in `reward`.
        (
            self.ppl_weight,
            self.combined_weight,
            self.defense_toxicity_weight,
        ) = (ppl_weight, combined_weight, defense_toxicity_weight)

    def advance(self, state: str, action: str, next: str) -> str:
        """Return the next state: `state`, `action` and `next` concatenated in order.

        In `ASTEnvironment`, `action` is the attacker's text and `next` is the
        target's reply. The parameter name `next` shadows the builtin but is
        kept for interface compatibility.
        """
        new_state = state + action + next
        return new_state

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        """Compute one scalar reward per (context, attack, response) triple.

        The reward is a weighted sum of three terms:
        ``ppl_weight * likelihood + combined_weight * score(attack + response)
        + defense_toxicity_weight * score(response)``. Here `score` is
        `self.moderator.moderate`, and `likelihood` is the exponential of the
        summed per-token log-probabilities returned by
        `_get_target_logprobs_and_validate(context, attack)`.

        Returns:
            A Python list of floats, moved to CPU.
        """
        # Moderator score for the whole exchange: each attack followed by its response.
        exchange_scores: Sequence[float] = self.moderator.moderate(
            [atk + resp for atk, resp in zip(attack, response)]
        )
        # Moderator score for the response on its own.
        response_scores: Sequence[float] = self.moderator.moderate(response)

        # Presumably per-token log-probs of `attack` given `context` under the
        # target model, with tokens on the last dim (helper defined elsewhere —
        # confirm). Summing before exp yields a sequence likelihood, which is
        # what the "ppl" weight is applied to.
        token_logprobs = self._get_target_logprobs_and_validate(context, attack)
        likelihood = torch.exp(token_logprobs.sum(dim=-1))
        device = likelihood.device

        total = (
            self.ppl_weight * likelihood
            + self.combined_weight * torch.tensor(exchange_scores).to(device)
            + self.defense_toxicity_weight * torch.tensor(response_scores).to(device)
        )
        return total.cpu().tolist()
71

72

73
# Rollout-tree node with both type arguments fixed to `str` (presumably state and
# action, matching `Environment[str, str]` below — confirm against `Node`).
ASTNode = Node[str, str]
2✔
74

75

76
class ASTEnvironment(Environment[str, str]):
    """The ASTPrompter rollout environment.

    Implements https://arxiv.org/abs/2407.09447.

    Specifically, this is the original rollout system used in the
    ASTPrompter paper: the red-teaming setting in which the attacker and
    defender generate successive turns of text, each of which is appended
    to the prompt of the other. The turns carry no IFT or other structure.

    For usage examples, see `astra_rl.core.environment.Environment`.

    Attributes:
        problem (ASTProblem): The problem instance that defines the environment and actions.
        prompts (Sequence[str]): Initial prompts; each rollout starts from one of them.
        tree_width (int): Number of sibling branches expanded at each node of the rollout tree.
        tree_depth (int): Number of attacker/defender exchanges along each root-to-leaf path.

    Generics:
        StateT (str): The type of the state in the environment, which is a string.
        ActionT (str): The type of the action in the environment, which is also a string.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(self, prompt: str, depth: int = 3) -> Sequence[Node[str, str]]:
        """Expand `prompt` into `tree_width` sibling nodes, recursing `depth` levels.

        Each sibling holds an attacker continuation, the target's reply to
        `prompt + attack`, and the problem's reward for that exchange. Its
        children are built from `problem.advance(prompt, attack, defense)` with
        one fewer level of depth.

        Args:
            prompt: Conversation state to branch from.
            depth: Remaining levels to expand; values <= 0 yield no nodes.

        Returns:
            The list of nodes at this level (empty at the depth limit).
        """
        # `<= 0` (not `== 0`) so a negative depth terminates immediately
        # instead of recursing until RecursionError.
        if depth <= 0:
            return []

        # Repeat the same state `tree_width` times so all siblings are
        # generated in one batched attacker call.
        prompts = [prompt] * self.tree_width
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(prompts)
        # The defender sees the prompt with the attack appended.
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + attack for attack in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)

        return [
            Node(
                branch_prompt,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(branch_prompt, attack, defense), depth - 1
                ),
            )
            for branch_prompt, attack, defense, reward in zip(
                prompts, attacks, defenses, rewards
            )
        ]

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        """Build one rollout tree from a randomly chosen initial prompt.

        Args:
            seed: If given (including 0), a fresh `Random(seed)` selects the
                prompt, so the choice is reproducible. If None, the global
                `random` module is used.

        Returns:
            A `Graph` rooted at the chosen prompt, holding `tree_depth` levels of nodes.
        """
        # Compare against None so that seed=0 is honoured; a truthiness check
        # would silently fall back to the unseeded global RNG for 0.
        rng: Union[Random, ModuleType] = random if seed is None else Random(seed)
        prompt = rng.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth)

        return Graph(prompt, nodes)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc