
sisl / astra-rl, build 16457978964 (github)
22 Jul 2025 11:47PM UTC coverage: 44.198% (-2.2%) from 46.419%

Pull Request #4: Wandb logging
Commit by maxlampe: "fix: Logprob calculation error in DPO/IPO"

10 of 40 new or added lines in 4 files covered (25.0%).
2 existing lines in 1 file now uncovered.
179 of 405 relevant lines covered (44.2%).
0.88 hits per line.

Source File: /src/astra_rl/algorithms/dpo.py (35.14% of lines covered)

Only the module-level definitions in this file are hit in this coverage run; the bodies of __init__, flatten, collate_fn, and step (in both DPO and IPO), including the newly added logging lines, are uncovered.
from dataclasses import dataclass
from typing import Generic, Sequence, List, Any

from astra_rl.core.algorithm import Algorithm
from astra_rl.core.problem import Problem
from astra_rl.core.common import StateT, ActionT
from astra_rl.core.environment import Graph

import torch
import torch.nn.functional as F


@dataclass
class DPOStep(Generic[StateT, ActionT]):
    prefix: StateT

    suffix_pos: ActionT
    suffix_neg: ActionT


@dataclass
class DPOBatch(Generic[StateT, ActionT]):
    prefixes: Sequence[StateT]

    suffix_pos: Sequence[ActionT]
    suffix_neg: Sequence[ActionT]


class DPO(
    Algorithm[StateT, ActionT, DPOStep[StateT, ActionT], DPOBatch[StateT, ActionT]],
    Generic[StateT, ActionT],
):
    def __init__(self, problem: Problem[StateT, ActionT], beta: float = 0.1):
        super().__init__(problem)

        self.beta = beta

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[DPOStep[StateT, ActionT]]:
        # in DPO, we sample from each branch the most rewarded
        # and least rewarded actions in order to use them as our contrastive
        # pairs.

        pairs: List[DPOStep[StateT, ActionT]] = []
        bfs = [graph.children]
        while len(bfs):
            front = bfs.pop(0)
            sorted_list = sorted(list(front), key=lambda x: x.reward, reverse=True)

            if len(sorted_list) < 2:
                # if there is no pair, we skip this node
                continue

            pos_entry = sorted_list[0]
            neg_entry = sorted_list[-1]

            assert pos_entry.context == neg_entry.context, (
                "paired rollouts for DPO must share the same prefix!"
            )

            pairs.append(
                DPOStep(
                    prefix=pos_entry.context,
                    suffix_pos=pos_entry.attack,
                    suffix_neg=neg_entry.attack,
                )
            )

            for i in sorted_list:
                bfs.append(i.children)

        return pairs
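
    # Illustration of the pairing above (hypothetical values): if one branch of
    # the rollout graph contains three sibling attacks sharing the prefix "hi"
    # with rewards 0.9, 0.5, and 0.1, flatten() emits a single DPOStep whose
    # suffix_pos is the 0.9-reward attack and whose suffix_neg is the 0.1-reward
    # attack; the middle rollout is not paired, but the children of all three
    # siblings are still enqueued for the next BFS level.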

    @staticmethod
    def collate_fn(x: Sequence[DPOStep[StateT, ActionT]]) -> DPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffix_pos = [i.suffix_pos for i in x]
        suffix_neg = [i.suffix_neg for i in x]

        return DPOBatch(prefixes=prefixes, suffix_pos=suffix_pos, suffix_neg=suffix_neg)

    def step(
        self, batch: DPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, dict[Any, Any]]:
        attacker_logprobs_win = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        )
        attacker_logprobs_loss = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        )
        baseline_logprobs_win = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        )
        baseline_logprobs_loss = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        )

        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
        pi_logratios = attacker_logprobs_win - attacker_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios

        loss = -F.logsigmoid(self.beta * logits)

        # Calculate additional quantities for logging
        # TODO: CHECK ME for correctness and completion!
        chosen_rewards = self.beta * (attacker_logprobs_win - baseline_logprobs_win)
        rejected_rewards = self.beta * (attacker_logprobs_loss - baseline_logprobs_loss)
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        reward_margin = chosen_rewards - rejected_rewards

        logging_dict: dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "reward/chosen_rewards": chosen_rewards.mean().cpu().item(),
            "reward/rejected_rewards": rejected_rewards.mean().cpu().item(),
            "reward/reward_accuracies": reward_accuracies.mean().cpu().item(),
            "reward/reward_margin": reward_margin.mean().cpu().item(),
            "policy/logprobs_chosen": attacker_logprobs_win.mean()
            .detach()
            .cpu()
            .item(),
            "policy/logprobs_rejected": attacker_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
            "ref/logprobs_chosen": baseline_logprobs_win.mean().detach().cpu().item(),
            "ref/logprobs_rejected": baseline_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
        }
        # TODO: Add this from old code?
        # "policy/rollout": wandb.Html(str(r"<span>"+batch["prompt_win"][0][0]+"</span><span style='color:Tomato;'>"+batch["prompt_win"][0][1]+r"</span><span style='color:DodgerBlue'>"+batch["prompt_win"][0][2]+r"</span>")),

        return loss.mean(), logging_dict
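
The step above follows the reference DPO implementation linked in the code comment: each preference pair contributes -log sigmoid(beta * [(log pi_chosen - log pi_rejected) - (log ref_chosen - log ref_rejected)]). Below is a minimal standalone sketch of that arithmetic with hand-picked log-probabilities; the tensor values are illustrative only and the astra_rl Problem API is not involved.

# Minimal sketch of the DPO arithmetic in DPO.step, using toy log-probabilities
# instead of the attacker/baseline models (illustrative values only).
import torch
import torch.nn.functional as F

beta = 0.1
# Summed log-probs of the chosen/rejected suffixes under the policy ("attacker")
# and the frozen reference ("baseline"), one value per pair in the batch.
policy_chosen = torch.tensor([-12.0, -15.0])
policy_rejected = torch.tensor([-14.0, -13.0])
ref_chosen = torch.tensor([-13.0, -14.5])
ref_rejected = torch.tensor([-13.5, -13.2])

pi_logratios = policy_chosen - policy_rejected
ref_logratios = ref_chosen - ref_rejected
logits = pi_logratios - ref_logratios

dpo_loss = -F.logsigmoid(beta * logits)            # per-pair DPO loss
chosen_rewards = beta * (policy_chosen - ref_chosen)
rejected_rewards = beta * (policy_rejected - ref_rejected)
reward_margin = chosen_rewards - rejected_rewards  # equals beta * logits

print(dpo_loss.mean().item(), reward_margin.mean().item())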


class IPO(DPO[StateT, ActionT]):
    def step(
        self, batch: DPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, dict[Any, Any]]:
        attacker_logprobs_win = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        )
        attacker_logprobs_loss = self.problem._get_attacker_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        )
        baseline_logprobs_win = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_pos
        )
        baseline_logprobs_loss = self.problem._get_baseline_logprobs_and_validate(
            batch.prefixes, batch.suffix_neg
        )

        # https://github.com/eric-mitchell/direct-preference-optimization/blob/ \
        # f8b8c0f49dc92a430bae41585f9d467d3618fe2f/trainers.py#L70-L87
        pi_logratios = attacker_logprobs_win - attacker_logprobs_loss
        ref_logratios = baseline_logprobs_win - baseline_logprobs_loss
        logits = pi_logratios - ref_logratios

        loss = (logits - 1 / (2 * self.beta)) ** 2

        # Calculate additional quantities for logging
        # TODO: CHECK ME for correctness and completion!
        chosen_rewards = self.beta * (attacker_logprobs_win - baseline_logprobs_win)
        rejected_rewards = self.beta * (attacker_logprobs_loss - baseline_logprobs_loss)
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        reward_margin = chosen_rewards - rejected_rewards

        logging_dict: dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "reward/chosen_rewards": chosen_rewards.mean().cpu().item(),
            "reward/rejected_rewards": rejected_rewards.mean().cpu().item(),
            "reward/reward_accuracies": reward_accuracies.mean().cpu().item(),
            "reward/reward_margin": reward_margin.mean().cpu().item(),
            "policy/logprobs_chosen": attacker_logprobs_win.mean()
            .detach()
            .cpu()
            .item(),
            "policy/logprobs_rejected": attacker_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
            "ref/logprobs_chosen": baseline_logprobs_win.mean().detach().cpu().item(),
            "ref/logprobs_rejected": baseline_logprobs_loss.mean()
            .detach()
            .cpu()
            .item(),
        }
        # TODO: Add this from old code?
        # "policy/rollout": wandb.Html(str(r"<span>"+batch["prompt_win"][0][0]+"</span><span style='color:Tomato;'>"+batch["prompt_win"][0][1]+r"</span><span style='color:DodgerBlue'>"+batch["prompt_win"][0][2]+r"</span>")),

        return loss.mean(), logging_dict
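
IPO reuses the same log-ratio logits but swaps the logistic loss for a squared regression toward 1 / (2 * beta). A short sketch on the same toy logits as the example above (illustrative values only, not the library API):

# Contrast the DPO and IPO objectives on the toy logits from the sketch above.
import torch
import torch.nn.functional as F

beta = 0.1
logits = torch.tensor([1.5, -0.7])             # pi_logratios - ref_logratios

dpo_loss = -F.logsigmoid(beta * logits)        # DPO: logistic loss on beta-scaled logits
ipo_loss = (logits - 1 / (2 * beta)) ** 2      # IPO: squared distance to 1 / (2 * beta)

print(dpo_loss.mean().item(), ipo_loss.mean().item())

Both step methods return loss.mean() together with logging_dict, which is presumably what the Wandb logging added in Pull Request #4 reports at each training step.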