• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 16182589621

09 Jul 2025 11:45PM UTC coverage: 11.881% (+0.6%) from 11.258%
16182589621

push

github

Jemoka
[wip] adds AST implementation of rollout generator

7 of 80 new or added lines in 7 files covered. (8.75%)

24 of 202 relevant lines covered (11.88%)

0.24 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/astra_rl/methods/ast.py
1
"""
2
ast.py
3
ASTProblem
4
"""
5

NEW
6
import random
×
NEW
7
from random import Random
×
NEW
8
from types import ModuleType
×
NEW
9
from typing import Sequence, Optional, Union
×
10

NEW
11
import torch
×
12

NEW
13
from astra_rl.core.problem import ASTRAProblem
×
NEW
14
from astra_rl.core.moderator import Moderator
×
NEW
15
from astra_rl.core.rollout import Node, Graph, RolloutGenerator
×
16

17

NEW
18
class ASTProblem(ASTRAProblem[str, str]):
    """ASTRA problem over plain strings with a three-term weighted reward.

    The reward for each (context, attack, response) triple is the sum of:

    * ``ppl_weight`` times the exponentiated target log-probability of the
      attack given the context,
    * ``combined_weight`` times the moderator score of ``attack + response``,
    * ``defense_toxicity_weight`` times the moderator score of the response
      alone.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        """Store the moderator (via the base class) and the reward weights.

        Args:
            moderator: Scorer handed to the base-class constructor and used
                by :meth:`reward` through ``self.moderator``.
            ppl_weight: Multiplier for the exponentiated log-probability term.
            combined_weight: Multiplier for the attack+response moderation
                score.
            defense_toxicity_weight: Multiplier for the response-only
                moderation score.
        """
        super().__init__(moderator)

        self.defense_toxicity_weight = defense_toxicity_weight
        self.combined_weight = combined_weight
        self.ppl_weight = ppl_weight

    def advance(self, state: str, action: str, next: str) -> str:
        """Return the new state: ``state``, ``action`` and ``next`` concatenated."""
        return "".join((state, action, next))

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        """Compute one scalar reward per (context, attack, response) triple.

        Args:
            context: Prompts the attacks were generated from.
            attack: Attacker utterances, aligned with ``context``.
            response: Target utterances, aligned with ``attack``.

        Returns:
            A list of floats, one per triple.
        """
        # Each attack is glued to its response so the exchange is scored as one text.
        exchanges = [utterance + reply for utterance, reply in zip(attack, response)]
        exchange_scores: Sequence[float] = self.moderator.moderate(exchanges)
        reply_scores: Sequence[float] = self.moderator.moderate(response)

        # NOTE(review): this is exp(log-prob) of the attack under the target;
        # the name says "ppl" but no inversion/normalisation happens here -
        # confirm that is the intended quantity.
        target_logprobs = self.get_target_logprobs(context, attack)
        attack_ppls: torch.Tensor = target_logprobs.exp()

        # Moderator scores arrive as Python sequences; move them next to the
        # log-prob tensor before combining.
        device = attack_ppls.device
        ppl_term = self.ppl_weight * attack_ppls
        exchange_term = self.combined_weight * torch.tensor(exchange_scores).to(device)
        reply_term = self.defense_toxicity_weight * torch.tensor(reply_scores).to(
            device
        )

        totals: Sequence[float] = (ppl_term + exchange_term + reply_term).cpu().tolist()
        return totals
×
56

57

NEW
58
# Shorthand for a rollout ``Node`` specialised with ``str`` for both of its
# type parameters, mirroring the ``[str, str]`` parameterisation used by
# ``ASTProblem`` in this module.
ASTNode = Node[str, str]
×
59

60

NEW
61
class ASTTreeRolloutGenerator(RolloutGenerator[str, str]):
    """Rollout generator that expands a prompt into a tree of exchanges.

    Starting from one prompt drawn from ``prompts``, every level samples
    ``tree_width`` attacker continuations, obtains a target response for
    each, scores them with ``problem.reward`` and recurses on the advanced
    state until ``tree_depth`` levels have been produced.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        """Remember the prompt pool and the shape of the tree.

        Args:
            problem: Problem handed to the base-class constructor; the
                methods below reach it as ``self.problem``.
            prompts: Pool of root prompts that :meth:`rollout` samples from.
            tree_width: Number of children generated at every node.
            tree_depth: Number of levels generated below the root prompt.
        """
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(self, prompt: str, depth: int = 3) -> Sequence[Node[str, str]]:
        """Recursively build the child nodes hanging off ``prompt``.

        Args:
            prompt: Conversation state to expand.
            depth: Remaining levels to generate; zero or less yields no nodes.

        Returns:
            Up to ``tree_width`` nodes, each carrying its own subtree.
        """
        # `<=` rather than `==` so a negative depth terminates instead of
        # recursing without bound.
        if depth <= 0:
            return []

        # The same prompt is repeated once per child so the batched
        # attacker/target calls return `tree_width` independent samples.
        prompts = [prompt] * self.tree_width
        attacks = self.problem.rollout_prompt_with_attacker(prompts)
        defenses = self.problem.rollout_prompt_with_target(
            [prompt + attack for attack in attacks]
        )
        rewards = self.problem.reward(prompts, attacks, defenses)

        # `prompts` stays in the zip so the result is capped at `tree_width`
        # entries even if a rollout call returns more items than requested.
        return [
            Node(
                context,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(context, attack, defense), depth - 1
                ),
            )
            for context, attack, defense, reward in zip(
                prompts, attacks, defenses, rewards
            )
        ]

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        """Sample a root prompt and build its rollout tree.

        Args:
            seed: When given (including ``0``), the root prompt is chosen
                with a dedicated ``Random(seed)``; when ``None``, the global
                ``random`` module is used. Only the prompt choice is seeded
                here.

        Returns:
            A ``Graph`` built from the chosen prompt and its child nodes.

        Raises:
            IndexError: If ``prompts`` is empty.
        """
        rng: Union[Random, ModuleType] = random
        # Compare against None explicitly: `if seed:` would treat the valid
        # seed 0 as "no seed" and fall back to the unseeded global generator.
        if seed is not None:
            rng = Random(seed)
        prompt = rng.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth)

        return Graph(prompt, nodes)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc