sisl / astra-rl — build 18474409012

13 Oct 2025 06:07PM UTC coverage: 40.444% (+1.7%) from 38.778%

Pull Request #27: WIP: Package Generalization (github / web-flow)
Merge 8971d248e into fa925eab6

100 of 189 new or added lines in 10 files covered. (52.91%)

8 existing lines in 5 files now uncovered.

383 of 947 relevant lines covered (40.44%)

0.81 hits per line

Source File

/src/astra_rl/algorithms/ppo.py — 41.67% of lines covered
from dataclasses import dataclass
from abc import ABC
from typing import Generic, Sequence, List, Any, Dict

from astra_rl.core.algorithm import Algorithm
from astra_rl.core.system import ValueFunctionSystem
from astra_rl.core.common import StateT, ActionT
from astra_rl.core.sampler import Graph

import torch
import torch.nn.functional as F


@dataclass
class PPOStep(Generic[StateT, ActionT]):
    prefix: StateT
    suffix: ActionT
    reward: float


@dataclass
class PPOBatch(Generic[StateT, ActionT]):
    prefix: Sequence[StateT]
    suffix: Sequence[ActionT]
    reward: Sequence[float]


class PPO(
    Algorithm[StateT, ActionT, PPOStep[StateT, ActionT], PPOBatch[StateT, ActionT]],
    ABC,
):
    """Proximal Policy Optimization (PPO) algorithm with value function."""

    def __init__(
        self,
        system: ValueFunctionSystem[StateT, ActionT],
        clip_range: float = 0.1,
        vf_loss_coef: float = 1.0,
    ):
        super().__init__(system)

        self.system: ValueFunctionSystem[StateT, ActionT] = system
        self.clip_range = clip_range
        self.vf_loss_coef = vf_loss_coef

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[PPOStep[StateT, ActionT]]:
        # walk the rollout graph breadth-first: every node contributes one
        # (prefix, suffix, reward) step, and fronts with fewer than two
        # children are skipped (a standalone sketch of this traversal
        # follows the listing).

        res: List[PPOStep[StateT, ActionT]] = []
        bfs = [graph.children]
        while len(bfs):
            front = bfs.pop(0)
            if len(list(front)) < 2:
                # if there is no pair of siblings, skip this front
                continue

            for i in front:
                res.append(
                    PPOStep(prefix=i.context, suffix=i.utterance, reward=i.reward)
                )
                bfs.append(i.children)

        return res

    @staticmethod
    def collate_fn(x: Sequence[PPOStep[StateT, ActionT]]) -> PPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffixes = [i.suffix for i in x]
        rewards = [i.reward for i in x]

        return PPOBatch(prefix=prefixes, suffix=suffixes, reward=rewards)

    def step(
        self, batch: PPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, Dict[Any, Any]]:
        logprobs_tester = self.system._get_tester_logprobs_and_validate(
            batch.prefix, batch.suffix
        )
        logprobs_baseline = self.system._get_baseline_logprobs_and_validate(
            batch.prefix, batch.suffix
        )
        values = self.system.value(batch.prefix, batch.suffix)

        # Q(s,a) = R(s,a), which is janky but seems to be the standard;
        # it is also bootstrapped without discount throughout the stream
        Q = (
            torch.tensor(batch.reward)
            .to(logprobs_tester.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .repeat(1, *values.shape[1:])
        )
        A = Q - values

        # normalize advantages
        if A.size(-1) == 1:
            A = ((A - A.mean()) / (A.std() + 1e-8)).squeeze(-1)
        else:
            A = (A - A.mean()) / (A.std() + 1e-8)

        # compute ratio, should be 1 at the first iteration
        ratio = torch.exp(logprobs_tester - logprobs_baseline.detach())

        # compute clipped surrogate loss (see the numeric sketch after the listing)
        policy_loss_1 = A * ratio
        policy_loss_2 = A * torch.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
        policy_loss = -(torch.min(policy_loss_1, policy_loss_2)).mean()

        # compute value loss
        value_loss = F.mse_loss(Q, values)

        # compute final loss
        loss = policy_loss + self.vf_loss_coef * value_loss

        # create logging dict
        logging_dict: Dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "training/policy_loss": policy_loss.mean().cpu().item(),
            "training/value_loss": value_loss.mean().cpu().item(),
            "reward/mean_reward": torch.tensor(batch.reward).mean().cpu().item(),
            "reward/std_reward": torch.tensor(batch.reward).std().cpu().item(),
            "policy/logprobs": logprobs_tester.mean().detach().cpu().item(),
            "ref/logprobs": logprobs_baseline.mean().detach().cpu().item(),
        }

        return loss, logging_dict
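PPO.flatten above relies only on each rollout-graph node exposing context, utterance, reward, and children. Below is a minimal, self-contained sketch of that breadth-first traversal; Node and StubGraph are hypothetical stand-ins for the astra_rl.core.sampler.Graph structures (whose interface is only assumed from the code above), so this illustrates the walk rather than the library's API.

from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    # hypothetical rollout-graph node; only the attributes read by flatten()
    context: str
    utterance: str
    reward: float
    children: List["Node"] = field(default_factory=list)


@dataclass
class StubGraph:
    # hypothetical stand-in for Graph; flatten() only touches .children
    children: List[Node]


def flatten_bfs(graph: StubGraph) -> List[Tuple[str, str, float]]:
    # same breadth-first walk as PPO.flatten, emitting (prefix, suffix, reward)
    res: List[Tuple[str, str, float]] = []
    bfs = [graph.children]
    while bfs:
        front = bfs.pop(0)
        if len(front) < 2:  # fronts without at least two siblings are skipped
            continue
        for node in front:
            res.append((node.context, node.utterance, node.reward))
            bfs.append(node.children)
    return res


a = Node("hi", "hello there", 0.2)
b = Node("hi", "go away", -0.5,
         children=[Node("hi go away", "why?", 0.1),
                   Node("hi go away", "fine.", 0.0)])
print(flatten_bfs(StubGraph(children=[a, b])))
# [('hi', 'hello there', 0.2), ('hi', 'go away', -0.5),
#  ('hi go away', 'why?', 0.1), ('hi go away', 'fine.', 0.0)]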
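To see the loss arithmetic of PPO.step in isolation, here is a minimal numeric sketch using plain torch tensors. It assumes one scalar log-probability, value estimate, and reward per sequence (shape (batch,)); the numbers are made up and nothing below calls into the astra_rl system interfaces.

import torch
import torch.nn.functional as F

clip_range = 0.1
vf_loss_coef = 1.0

# dummy per-sequence quantities, shape (batch,); numbers are illustrative only
logprobs_tester = torch.tensor([-12.0, -9.5, -15.2])     # current (tester) policy
logprobs_baseline = torch.tensor([-12.0, -10.0, -14.8])  # frozen baseline policy
values = torch.tensor([0.10, 0.30, -0.20])               # value-head predictions
rewards = torch.tensor([0.25, 0.40, -0.60])

# Q(s, a) is taken to be the raw reward, as in PPO.step above
Q = rewards
A = Q - values                          # advantages
A = (A - A.mean()) / (A.std() + 1e-8)   # normalized advantages

# importance ratio between the tester and baseline policies
ratio = torch.exp(logprobs_tester - logprobs_baseline)

# clipped surrogate objective: elementwise minimum of the unclipped and
# clipped terms, negated so it can be minimized
policy_loss = -torch.min(
    A * ratio,
    A * torch.clamp(ratio, 1 - clip_range, 1 + clip_range),
).mean()

# value head regressed toward Q
value_loss = F.mse_loss(Q, values)

loss = policy_loss + vf_loss_coef * value_loss
print(policy_loss.item(), value_loss.item(), loss.item())

When the tester policy still matches the baseline, the ratio is 1 everywhere and the clipped and unclipped terms coincide, so clipping only starts to matter once the two policies diverge, consistent with the "should be 1 at the first iteration" comment in step.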