
sisl / astra-rl / build 16262675714

14 Jul 2025 04:49AM UTC · coverage: 46.419% (-48.0%) from 94.444%

Push (via GitHub, web-flow): Merge pull request #3 from sisl/feat/core

Initial implementation of core AST algorithm.

160 of 361 new or added lines in 18 files covered (44.32%).
175 of 377 relevant lines covered (46.42%).
0.93 hits per line.

Source File
/src/astra_rl/algorithms/dpo.py (40.63% of lines covered)
from dataclasses import dataclass
from typing import Generic, Sequence, List

from astra_rl.core.algorithm import Algorithm
from astra_rl.core.problem import Problem
from astra_rl.core.common import StateT, ActionT
from astra_rl.core.environment import Graph

import torch
import torch.nn.functional as F


@dataclass
class DPOStep(Generic[StateT, ActionT]):
    prefix: StateT

    suffix_pos: ActionT
    suffix_neg: ActionT


@dataclass
class DPOBatch(Generic[StateT, ActionT]):
    prefixes: Sequence[StateT]

    suffix_pos: Sequence[ActionT]
    suffix_neg: Sequence[ActionT]
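Aside: a minimal sketch of what one DPOStep and one collated DPOBatch might hold, assuming a text red-teaming setup where StateT and ActionT are both str. That assumption, and the placeholder strings, are ours for illustration; the types are generic.

example_step = DPOStep(
    prefix="<conversation so far>",
    suffix_pos="<attack suffix that earned the highest reward>",
    suffix_neg="<attack suffix that earned the lowest reward>",
)
example_batch = DPOBatch(
    prefixes=[example_step.prefix],
    suffix_pos=[example_step.suffix_pos],
    suffix_neg=[example_step.suffix_neg],
)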

class DPO(
    Algorithm[StateT, ActionT, DPOStep[StateT, ActionT], DPOBatch[StateT, ActionT]],
    Generic[StateT, ActionT],
):
    def __init__(self, problem: Problem[StateT, ActionT], beta: float = 0.1):
        super().__init__(problem)

        self.beta = beta

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[DPOStep[StateT, ActionT]]:
        # In DPO, we sample from each branch the most-rewarded and
        # least-rewarded actions and use them as our contrastive pair.

        pairs: List[DPOStep[StateT, ActionT]] = []

        # Breadth-first walk over the rollout graph; each frontier entry is
        # one node's list of children (sibling rollouts sharing a prefix).
        bfs = [graph.children]
        while len(bfs):
            front = bfs.pop(0)
            sorted_list = sorted(front, key=lambda x: x.reward, reverse=True)

            if len(sorted_list) < 2:
                # There is no contrastive pair at this node, so we skip it.
                continue

            # The best and worst siblings form the (chosen, rejected) pair.
            pos_entry = sorted_list[0]
            neg_entry = sorted_list[-1]

            assert pos_entry.context == neg_entry.context, (
                "paired rollouts for DPO must share the same prefix!"
            )

            pairs.append(
                DPOStep(
                    prefix=pos_entry.context,
                    suffix_pos=pos_entry.attack,
                    suffix_neg=neg_entry.attack,
                )
            )

            for i in sorted_list:
                bfs.append(i.children)

        return pairs

    @staticmethod
    def collate_fn(x: Sequence[DPOStep[StateT, ActionT]]) -> DPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffix_pos = [i.suffix_pos for i in x]
        suffix_neg = [i.suffix_neg for i in x]

        return DPOBatch(prefixes=prefixes, suffix_pos=suffix_pos, suffix_neg=suffix_neg)

    def step(self, batch: DPOBatch[StateT, ActionT]) -> torch.Tensor:
        attacker_logprobs_win = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        )
        attacker_logprobs_loss = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        )
        baseline_logprobs_win = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        )
        baseline_logprobs_loss = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        )

        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
        pi_logratios = attacker_logprobs_win - attacker_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios

        loss = -F.logsigmoid(self.beta * logits)

        # TODO: how do we do logging?
        # Ideally there's a logging package / logger for metrics that we can log to.
        # chosen_rewards = self.beta * (attacker_logprobs_win - baseline_logprobs_win).detach()
        # rejected_rewards = self.beta * (attacker_logprobs_loss - baseline_logprobs_loss).detach()

        return loss.mean()
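Aside: a minimal sketch of the pairing that flatten performs, using a hypothetical stand-in node type of our own (context, attack, reward, and children are the only attributes flatten reads; this class is not part of astra_rl):

from dataclasses import dataclass, field
from typing import List

@dataclass
class _Node:  # hypothetical stand-in for a rollout-graph node
    context: str
    attack: str
    reward: float
    children: List["_Node"] = field(default_factory=list)

# Three sibling rollouts that share one prefix. Sorting by reward, flatten
# would emit one DPOStep pairing the best sibling ("attack A", reward 0.9)
# as suffix_pos against the worst ("attack C", reward 0.1) as suffix_neg,
# then continue the walk into each sibling's children.
siblings = [
    _Node("shared prefix", "attack A", 0.9),
    _Node("shared prefix", "attack B", 0.5),
    _Node("shared prefix", "attack C", 0.1),
]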
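Because collate_fn is a plain static method, it can be handed directly to a torch DataLoader. A sketch, assuming steps came from DPO.flatten and attacker is a constructed DPO instance (both names are our assumptions):

from torch.utils.data import DataLoader

loader = DataLoader(steps, batch_size=8, shuffle=True, collate_fn=DPO.collate_fn)
for batch in loader:         # each batch is a DPOBatch
    loss = attacker.step(batch)
    loss.backward()          # assuming a standard optimizer loop around this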
class IPO(DPO[StateT, ActionT]):
    def step(self, batch: DPOBatch[StateT, ActionT]) -> torch.Tensor:
        attacker_logprobs_win = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        )
        attacker_logprobs_loss = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        )
        baseline_logprobs_win = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        )
        baseline_logprobs_loss = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        )

        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
        pi_logratios = attacker_logprobs_win - attacker_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios

        # IPO replaces DPO's logistic loss with a squared regression of the
        # margin toward 1 / (2 * beta).
        loss = (logits - 1 / (2 * self.beta)) ** 2

        return loss.mean()
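For a concrete feel of the two objectives, a self-contained numeric check with made-up margins (not part of the library): both start from the same logits, but DPO's logistic loss keeps shrinking as the margin grows, while IPO is minimized exactly at 1 / (2 * beta).

import torch
import torch.nn.functional as F

beta = 0.1
logits = torch.tensor([0.0, 2.0, 5.0, 8.0])  # hypothetical preference margins

dpo_loss = -F.logsigmoid(beta * logits)      # ~0.693, 0.598, 0.474, 0.371
ipo_loss = (logits - 1 / (2 * beta)) ** 2    # 25.0, 9.0, 0.0, 9.0 (target 5.0)
print(dpo_loss, ipo_loss)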