• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 16201881450

10 Jul 2025 05:29PM UTC coverage: 10.039% (-1.8%) from 11.881%
16201881450

push

github

Jemoka
DPO implementation

2 of 80 new or added lines in 5 files covered. (2.5%)

1 existing line in 1 file now uncovered.

26 of 259 relevant lines covered (10.04%)

0.2 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/astra_rl/methods/ast.py
1
"""
2
ast.py
3
ASTProblem
4
"""
5

6
import random
×
7
from random import Random
×
8
from types import ModuleType
×
9
from typing import Sequence, Optional, Union
×
10

11
import torch
×
12

13
from astra_rl.core.problem import ASTRAProblem
×
14
from astra_rl.core.moderator import Moderator
×
15
from astra_rl.core.rollout import Node, Graph, RolloutGenerator
×
16

17

18
class ASTProblem(ASTRAProblem[str, str]):
    """AST problem where states, attacks, and responses are all plain strings.

    The scalar reward for each (context, attack, response) triple is a
    weighted sum of three terms:

    * ``ppl_weight`` times ``exp`` of the log-probabilities returned by
      ``_get_target_logprobs_and_validate(context, attack)``,
    * ``combined_weight`` times the moderator score of ``attack + response``,
    * ``defense_toxicity_weight`` times the moderator score of ``response``.

    NOTE(review): the first term is named "ppl" but is computed as
    ``exp(logprob)``, i.e. a likelihood rather than a perplexity. Confirm the
    intended scale against the helper's definition.
    """

    def __init__(
        self,
        moderator: Moderator[str, str],
        ppl_weight: float = 0.025,
        combined_weight: float = 0.1,
        defense_toxicity_weight: float = 0.1,
    ):
        """Store the moderator (via the base class) and the reward weights.

        Args:
            moderator: scorer forwarded to ``ASTRAProblem`` and used in ``reward``.
            ppl_weight: coefficient on the exponentiated attack log-probabilities.
            combined_weight: coefficient on the moderator score of attack + response.
            defense_toxicity_weight: coefficient on the moderator score of the
                response alone (presumably a toxicity score, per the name).
        """
        super().__init__(moderator)

        # Mixing coefficients consumed by reward().
        self.ppl_weight, self.combined_weight, self.defense_toxicity_weight = (
            ppl_weight,
            combined_weight,
            defense_toxicity_weight,
        )

    def advance(self, state: str, action: str, next: str) -> str:
        """Return the next state: ``state`` followed by ``action``, then ``next``."""
        new_state = state + action + next
        return new_state

    def reward(
        self, context: Sequence[str], attack: Sequence[str], response: Sequence[str]
    ) -> Sequence[float]:
        """Score a batch of (context, attack, response) triples.

        Args:
            context: conversation states the attacks were generated from.
            attack: attacker utterances, aligned with ``context``.
            response: target/defender replies, aligned with ``attack``.

        Returns:
            A Python list with one float per triple (see the class docstring
            for the weighted sum that is computed).
        """
        moderate = self.moderator.moderate

        # Attack and reply are scored together as a single utterance.
        joint_scores: Sequence[float] = moderate(
            [atk + resp for atk, resp in zip(attack, response)]
        )
        # The reply is also scored on its own.
        reply_scores: Sequence[float] = moderate(response)
        likelihoods: torch.Tensor = self._get_target_logprobs_and_validate(
            context, attack
        ).exp()

        # Moderator scores arrive as Python sequences; move them next to the
        # likelihood tensor so the weighted sum runs on a single device.
        dev = likelihoods.device
        mixed = (
            self.ppl_weight * likelihoods
            + self.combined_weight * torch.tensor(joint_scores).to(dev)
            + self.defense_toxicity_weight * torch.tensor(reply_scores).to(dev)
        )
        return mixed.cpu().tolist()
58

59

60
# Convenience alias: the generic rollout-tree Node specialised with ``str`` for
# both type parameters, matching the [str, str] specialisation used by the
# AST problem and rollout generator in this module.
ASTNode = Node[str, str]
×
61

62

63
class ASTTreeRolloutGenerator(RolloutGenerator[str, str]):
    """Build trees of attacker/target rollouts for an :class:`ASTProblem`.

    A rollout draws one starting prompt, then expands it ``tree_depth`` levels
    deep. At every node, ``tree_width`` attacks are sampled from the same
    state. The target responds to each attack, the triples are scored with
    ``problem.reward``, and the subtree recurses on
    ``problem.advance(state, attack, defense)``.
    """

    def __init__(
        self,
        problem: ASTProblem,
        prompts: Sequence[str],
        tree_width: int = 2,
        tree_depth: int = 3,
    ):
        """
        Args:
            problem: provides attacker/target rollouts, reward, and state advance.
            prompts: pool of initial states; ``rollout`` draws one per call.
            tree_width: number of sibling attacks sampled at every node.
            tree_depth: number of attack/defense levels expanded below the root;
                values <= 0 yield a graph with no child nodes.
        """
        super().__init__(problem)

        self.prompts = prompts
        self.tree_width = tree_width
        self.tree_depth = tree_depth

    def __handle_prompt(self, prompt: str, depth: int = 3) -> Sequence[Node[str, str]]:
        """Expand ``prompt`` into up to ``tree_width`` child nodes, ``depth`` levels deep.

        Returns an empty list when ``depth`` is 0 or negative. The original
        checked only ``== 0``, so a negative depth recursed without bound
        while issuing model calls.
        """
        if depth <= 0:
            return []

        # The same state is repeated tree_width times so the attacker produces
        # one candidate per sibling.
        states = [prompt for _ in range(self.tree_width)]
        attacks = self.problem._rollout_prompt_with_attacker_and_validate(states)
        # The target replies to the state with the attack appended.
        defenses = self.problem._rollout_prompt_with_target_and_validate(
            [prompt + attack for attack in attacks]
        )
        rewards = self.problem.reward(states, attacks, defenses)

        # zip over `states` too, so the node count stays capped at
        # tree_width even if a helper returns a longer sequence.
        return [
            Node(
                state,
                attack,
                defense,
                reward,
                self.__handle_prompt(
                    self.problem.advance(state, attack, defense), depth - 1
                ),
            )
            for state, attack, defense, reward in zip(
                states, attacks, defenses, rewards
            )
        ]

    def rollout(self, seed: Optional[int] = None) -> Graph[str, str]:
        """Draw one starting prompt and build its full rollout tree.

        Args:
            seed: when not None (0 included), the prompt is chosen with a
                private ``random.Random(seed)`` for reproducibility. Otherwise
                the global ``random`` module state is used. This seed only
                governs the prompt choice made here.

        Returns:
            ``Graph(prompt, nodes)`` rooted at the chosen prompt.
        """
        # `if seed:` would treat seed=0 as "unseeded"; compare against None.
        rng: Union[Random, ModuleType] = random if seed is None else Random(seed)
        prompt = rng.choice(self.prompts)
        nodes = self.__handle_prompt(prompt, self.tree_depth)

        return Graph(prompt, nodes)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc