• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 16258219062

14 Jul 2025 04:49AM UTC coverage: 46.419% (-48.0%) from 94.444%
16258219062

push

github

web-flow
Merge pull request #3 from sisl/feat/core

Initial implementation of core AST algorithm.

160 of 361 new or added lines in 18 files covered. (44.32%)

175 of 377 relevant lines covered (46.42%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

36.96
/src/astra_rl/methods/ast.py
1
"""
2
ast.py
3
ASTProblem
4
"""
5

6
import random
2✔
7
from random import Random
2✔
8
from types import ModuleType
2✔
9
from typing import Sequence, Optional, Union
2✔
10

11
import torch
2✔
12

13
from astra_rl.core.problem import Problem
2✔
14
from astra_rl.core.moderator import Moderator
2✔
15
from astra_rl.core.environment import Node, Graph, Environment
2✔
16

17

18
class ASTProblem(Problem[str, str]):
    """Adaptive Stress Testing for Large Language Model Red-Teaming.

    Implements https://arxiv.org/abs/2407.09447.

    For usage examples, see `astra_rl.core.problem.Problem`.

    Attributes:
        moderator (Moderator): The moderator used to evaluate the sequences.
        ppl_weight (float): Weight for the exponentiated attack log-likelihood
            term in the reward calculation.
        combined_weight (float): Weight for the moderator score of the combined
            attack + response text in the reward calculation.
        defense_toxicity_weight (float): Weight for the moderator score of the
            response alone in the reward calculation.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        super().__init__(moderator)

        self.ppl_weight = ppl_weight
        self.combined_weight = combined_weight
        self.defense_toxicity_weight = defense_toxicity_weight

    def advance(self, state: str, action: str, next: str) -> str:
        """Return the state reached after one attacker/defender exchange.

        States are plain strings, so advancing is string concatenation.

        Args:
            state: The text so far.
            action: The attacker's utterance.
            next: The defender's reply. The name shadows the builtin but is
                kept so keyword callers keep working.

        Returns:
            ``state + action + next``.
        """
        return state + action + next

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        """Compute one scalar reward per (context, attack, response) triple.

        reward_i = ppl_weight * exp(log p(attack_i | context_i))
                   + combined_weight * moderate(attack_i + response_i)
                   + defense_toxicity_weight * moderate(response_i)

        where the log-likelihood comes from
        ``self._get_target_logprobs_and_validate``.

        Args:
            context: The prompts the attacks were generated from.
            attack: The attacker utterances, one per context.
            response: The defender replies, one per attack.

        Returns:
            A list of floats, one per element of the batch.

        Raises:
            ValueError: If the three sequences do not have the same length.
                Otherwise ``zip`` would silently truncate and the tensors
                below could broadcast to a wrong shape.
        """
        if not (len(context) == len(attack) == len(response)):
            raise ValueError(
                "context, attack, and response must have the same length; got "
                f"{len(context)}, {len(attack)}, and {len(response)}"
            )

        combined_uts = [j + k for j, k in zip(attack, response)]
        combined_scores: Sequence[float] = self.moderator.moderate(combined_uts)
        defender_scores: Sequence[float] = self.moderator.moderate(response)

        attack_logprobs: torch.Tensor = self._get_target_logprobs_and_validate(
            context, attack
        )
        # NOTE(review): if the helper returns per-token log-probs, presumably
        # shaped (batch, seq_len), sum the token axis to get one sequence
        # log-likelihood per sample. Otherwise the (batch,) score tensors below
        # would mis-broadcast and .tolist() would produce nested lists. This
        # assumes padded positions carry 0 log-prob; confirm against Problem.
        # A 1-D result passes through unchanged.
        if attack_logprobs.dim() > 1:
            attack_logprobs = attack_logprobs.sum(dim=-1)
        attack_ppls = attack_logprobs.exp()

        device = attack_ppls.device
        reward: Sequence[float] = (
            (
                self.ppl_weight * attack_ppls
                + self.combined_weight * torch.tensor(combined_scores).to(device)
                + self.defense_toxicity_weight
                * torch.tensor(defender_scores).to(device)
            )
            .cpu()
            .tolist()
        )

        return reward
71

72

73
# Rollout-tree node type for the AST environment: `Node` specialized with two
# `str` type arguments, matching ASTEnvironment's `Environment[str, str]`
# (presumably state type and action type, in that order; confirm against Node).
ASTNode = Node[str, str]
74

75

76
class ASTEnvironment(Environment[str, str]):
    """The ASTPrompter Rollout Environment

    Implements https://arxiv.org/abs/2407.09447.

    Specifically, this is the original rollout system used in the
    ASTPrompter paper: the red-teaming case where the attacker and
    defender generate successive turns of strings, each of which is
    appended to the prompt of the other. The turns have no IFT
    (instruction-tuning) template or other structure.

    For usage examples, see `astra_rl.core.environment.Environment`.

    Attributes:
        problem (ASTProblem): The problem instance that defines the environment and actions.
        prompts (Sequence[str]): A sequence of initial prompts to start the rollout.
        tree_width (int): The number of branches at each node in the rollout tree.
        tree_depth (int): The depth of the rollout tree.

    Generics:
        StateT (str): The type of the state in the environment, which is a string.
        ActionT (str): The type of the action in the environment, which is also a string.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(self, prompt: str, depth: int = 3) -> Sequence[Node[str, str]]:
        """Expand ``prompt`` into a rollout subtree ``depth`` levels deep.

        At each level the attacker is sampled ``tree_width`` times on the same
        prompt. The defender then replies to ``prompt + attack``, each triple
        is scored with ``problem.reward``, and every branch recurses on
        ``problem.advance(prompt, attack, defense)``.

        Args:
            prompt: The current state text to branch from.
            depth: Remaining levels to expand.

        Returns:
            The child nodes at this level. The list is empty once ``depth``
            is exhausted.
        """
        # `<= 0` rather than `== 0`: a negative depth would otherwise recurse
        # until RecursionError.
        if depth <= 0:
            return []

        # Strings are immutable, so repeating the same object is safe here.
        prompts = [prompt] * self.tree_width
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(prompts)
        # The defender sees the prompt with the attacker's turn appended.
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + attack for attack in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)

        return [
            Node(
                context,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(context, attack, defense), depth - 1
                ),
            )
            for context, attack, defense, reward in zip(
                prompts, attacks, defenses, rewards
            )
        ]

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        """Sample one initial prompt and expand it into a full rollout tree.

        Args:
            seed: If given, a private ``random.Random(seed)`` picks the
                prompt, so the choice is reproducible (``0`` included).
                If ``None``, the module-level ``random`` generator is used.

        Returns:
            A ``Graph`` rooted at the chosen prompt whose children are the
            expanded nodes, ``self.tree_depth`` levels deep.
        """
        # `is not None` so that seed=0 is honored. A truthiness check silently
        # fell back to the global RNG for 0, making that rollout non-reproducible.
        rng: Union[Random, ModuleType] = random if seed is None else Random(seed)
        prompt = rng.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth)

        return Graph(prompt, nodes)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc