sisl / astra-rl, build 18474409012 (github, via web-flow)
13 Oct 2025 06:07PM UTC

Coverage: 40.444% (+1.7%) from 38.778%

Pull Request #27: WIP: Package Generalization
Merge 8971d248e into fa925eab6

100 of 189 new or added lines in 10 files covered (52.91%).
8 existing lines in 5 files are now uncovered.
383 of 947 relevant lines covered (40.44%).
0.81 hits per line.

Source file: /src/astra_rl/algorithms/dpo.py (36.0% covered)

from dataclasses import dataclass
from typing import Generic, Sequence, List, Any, Dict

from astra_rl.core.algorithm import Algorithm
from astra_rl.core.system import TrainableSystem
from astra_rl.core.common import StateT, ActionT
from astra_rl.core.sampler import Graph

import torch
import torch.nn.functional as F


@dataclass
class DPOStep(Generic[StateT, ActionT]):
    prefix: StateT

    suffix_pos: ActionT
    suffix_neg: ActionT


@dataclass
class DPOBatch(Generic[StateT, ActionT]):
    prefixes: Sequence[StateT]

    suffix_pos: Sequence[ActionT]
    suffix_neg: Sequence[ActionT]
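
# --- Illustrative sketch (editor's addition; not part of the original module). ---
# With plain strings standing in for StateT and ActionT, a single preference
# pair and the corresponding one-element batch look like this. The example
# strings are hypothetical.
_example_step = DPOStep(
    prefix="user: how do I sort a list in Python?",
    suffix_pos="assistant: call sorted(xs), or xs.sort() to sort in place.",
    suffix_neg="assistant: you can't.",
)
_example_batch = DPOBatch(
    prefixes=[_example_step.prefix],
    suffix_pos=[_example_step.suffix_pos],
    suffix_neg=[_example_step.suffix_neg],
)
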

class DPO(
    Algorithm[StateT, ActionT, DPOStep[StateT, ActionT], DPOBatch[StateT, ActionT]],
    Generic[StateT, ActionT],
):
    """Direct Preference Optimization (DPO) algorithm."""

    system: TrainableSystem[StateT, ActionT]

    def __init__(self, system: TrainableSystem[StateT, ActionT], beta: float = 0.1):
        super().__init__(system)

        self.beta = beta

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[DPOStep[StateT, ActionT]]:
        # In DPO, we sample from each branch the most rewarded and least
        # rewarded actions in order to use them as our contrastive pairs.

        pairs: List[DPOStep[StateT, ActionT]] = []
        bfs = [graph.children]
        while len(bfs):
            front = bfs.pop(0)
            sorted_list = sorted(list(front), key=lambda x: x.reward, reverse=True)

            if len(sorted_list) < 2:
                # if there is no pair, we skip this node
                continue

            pos_entry = sorted_list[0]
            neg_entry = sorted_list[-1]

            assert pos_entry.context == neg_entry.context, (
                "paired rollouts for DPO must share the same prefix!"
            )

            pairs.append(
                DPOStep(
                    prefix=pos_entry.context,
                    suffix_pos=pos_entry.utterance,
                    suffix_neg=neg_entry.utterance,
                )
            )

            for i in sorted_list:
                bfs.append(i.children)

        return pairs
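
    # --- Illustrative sketch (editor's addition; not part of the original module). ---
    # flatten() only relies on each graph node exposing .reward, .context,
    # .utterance, and .children. Per sibling group, the selection rule is just
    # "highest reward vs. lowest reward"; with a hypothetical stand-in node it
    # reduces to:
    #
    #     @dataclass
    #     class _Node:                      # hypothetical, for illustration only
    #         context: str
    #         utterance: str
    #         reward: float
    #         children: list
    #
    #     siblings = [
    #         _Node("hi", "a", reward=0.9, children=[]),
    #         _Node("hi", "b", reward=0.1, children=[]),
    #         _Node("hi", "c", reward=0.5, children=[]),
    #     ]
    #     ranked = sorted(siblings, key=lambda n: n.reward, reverse=True)
    #     pos, neg = ranked[0], ranked[-1]  # utterances "a" (0.9) and "b" (0.1)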

    @staticmethod
    def collate_fn(x: Sequence[DPOStep[StateT, ActionT]]) -> DPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffix_pos = [i.suffix_pos for i in x]
        suffix_neg = [i.suffix_neg for i in x]

        return DPOBatch(prefixes=prefixes, suffix_pos=suffix_pos, suffix_neg=suffix_neg)

    def step(
        self, batch: DPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, Dict[Any, Any]]:
        # Sum per-token logprobs to get sequence-level logprobs.
        tester_logprobs_win = self.system._get_tester_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)
        tester_logprobs_loss = self.system._get_tester_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)
        baseline_logprobs_win = self.system._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)
        baseline_logprobs_loss = self.system._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)

        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
        pi_logratios = tester_logprobs_win - tester_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios

        loss = -F.logsigmoid(self.beta * logits)

        # Calculate additional quantities for logging.
        # TODO: CHECK ME for correctness and completion!
        chosen_rewards = self.beta * (tester_logprobs_win - baseline_logprobs_win)
        rejected_rewards = self.beta * (tester_logprobs_loss - baseline_logprobs_loss)
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        reward_margin = chosen_rewards - rejected_rewards

        logging_dict: Dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "reward/chosen_rewards": chosen_rewards.mean().cpu().item(),
            "reward/rejected_rewards": rejected_rewards.mean().cpu().item(),
            "reward/reward_accuracies": reward_accuracies.mean().cpu().item(),
            "reward/reward_margin": reward_margin.mean().cpu().item(),
            "policy/logprobs_chosen": tester_logprobs_win.mean().detach().cpu().item(),
            "policy/logprobs_rejected": tester_logprobs_loss.mean().detach().cpu().item(),
            "ref/logprobs_chosen": baseline_logprobs_win.mean().detach().cpu().item(),
            "ref/logprobs_rejected": baseline_logprobs_loss.mean().detach().cpu().item(),
        }
        # TODO: Add this from old code?
        # "policy/rollout": wandb.Html(str(r"<span>"+batch["prompt_win"][0][0]+"</span><span style='color:Tomato;'>"+batch["prompt_win"][0][1]+r"</span><span style='color:DodgerBlue'>"+batch["prompt_win"][0][2]+r"</span>")),

        return loss.mean(), logging_dict
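
# --- Illustrative sketch (editor's addition; not part of the original module). ---
# The loss above is the standard DPO objective,
#   L = -log sigmoid(beta * [(log pi(y_w|x) - log pi(y_l|x))
#                            - (log pi_ref(y_w|x) - log pi_ref(y_l|x))]),
# shown here on made-up sequence log-probabilities; only the arithmetic
# mirrors DPO.step().
def _toy_dpo_loss(beta: float = 0.1) -> torch.Tensor:
    tester_win = torch.tensor([-4.0, -3.5])     # log pi(y_w | x), summed over tokens
    tester_loss = torch.tensor([-5.0, -6.0])    # log pi(y_l | x)
    baseline_win = torch.tensor([-4.5, -4.0])   # log pi_ref(y_w | x)
    baseline_loss = torch.tensor([-4.8, -5.5])  # log pi_ref(y_l | x)

    logits = (tester_win - tester_loss) - (baseline_win - baseline_loss)
    return -F.logsigmoid(beta * logits).mean()  # same reduction as DPO.step()
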

class IPO(DPO[StateT, ActionT]):
    def step(
        self, batch: DPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, Dict[Any, Any]]:
        # Sum per-token logprobs to get sequence-level logprobs.
        tester_logprobs_win = self.system._get_tester_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)
        tester_logprobs_loss = self.system._get_tester_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)
        baseline_logprobs_win = self.system._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        ).sum(dim=-1)
        baseline_logprobs_loss = self.system._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        ).sum(dim=-1)

        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
        pi_logratios = tester_logprobs_win - tester_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios

        loss = (logits - 1 / (2 * self.beta)) ** 2

        # Calculate additional quantities for logging.
        # TODO: CHECK ME for correctness and completion!
        chosen_rewards = self.beta * (tester_logprobs_win - baseline_logprobs_win)
        rejected_rewards = self.beta * (tester_logprobs_loss - baseline_logprobs_loss)
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        reward_margin = chosen_rewards - rejected_rewards

        logging_dict: Dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "reward/chosen_rewards": chosen_rewards.mean().cpu().item(),
            "reward/rejected_rewards": rejected_rewards.mean().cpu().item(),
            "reward/reward_accuracies": reward_accuracies.mean().cpu().item(),
            "reward/reward_margin": reward_margin.mean().cpu().item(),
            "policy/logprobs_chosen": tester_logprobs_win.mean().detach().cpu().item(),
            "policy/logprobs_rejected": tester_logprobs_loss.mean().detach().cpu().item(),
            "ref/logprobs_chosen": baseline_logprobs_win.mean().detach().cpu().item(),
            "ref/logprobs_rejected": baseline_logprobs_loss.mean().detach().cpu().item(),
        }
        # TODO: Add this from old code?
        # "policy/rollout": wandb.Html(str(r"<span>"+batch["prompt_win"][0][0]+"</span><span style='color:Tomato;'>"+batch["prompt_win"][0][1]+r"</span><span style='color:DodgerBlue'>"+batch["prompt_win"][0][2]+r"</span>")),

        return loss.mean(), logging_dict
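
# --- Illustrative sketch (editor's addition; not part of the original module). ---
# IPO.step() differs from DPO.step() only in the loss: rather than
# -log sigmoid(beta * logits), it regresses the policy/reference log-ratio gap
# toward 1 / (2 * beta) with a squared error (cf. the IPO objective of
# Azar et al., 2023). On a range of toy margins, the two losses behave like this:
def _compare_dpo_ipo(beta: float = 0.1) -> Dict[str, list]:
    logits = torch.tensor([-1.0, 0.0, 1.0, 5.0, 20.0])  # pi_logratios - ref_logratios
    dpo_loss = -F.logsigmoid(beta * logits)    # monotone: keeps pushing the margin up
    ipo_loss = (logits - 1 / (2 * beta)) ** 2  # minimized exactly at logits == 5.0
    return {"dpo": dpo_loss.tolist(), "ipo": ipo_loss.tolist()}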