sisl / astra-rl / 17082416515
19 Aug 2025 09:36PM UTC coverage: 44.677% (+0.4%) from 44.276%

Pull Request #12: PPO (github, Jemoka)
Merge remote-tracking branch 'origin/main' into feat/ppo

32 of 71 new or added lines in 6 files covered (45.07%).
1 existing line in 1 file is now uncovered.
235 of 526 relevant lines covered (44.68%).
0.89 hits per line.

Source File: /src/astra_rl/algorithms/ppo.py (41.67% of lines covered)

Per-line coverage: the module-level imports, dataclass fields, and method signatures are covered, while the newly added bodies of __init__, flatten, collate_fn, and step are not yet exercised by tests.

from dataclasses import dataclass
from abc import ABC
from typing import Generic, Sequence, List, Any, Dict

from astra_rl.core.algorithm import Algorithm
from astra_rl.core.problem import ValueFunctionProblem
from astra_rl.core.common import StateT, ActionT
from astra_rl.core.environment import Graph

import torch
import torch.nn.functional as F


@dataclass
class PPOStep(Generic[StateT, ActionT]):
    prefix: StateT
    suffix: ActionT
    reward: float


@dataclass
class PPOBatch(Generic[StateT, ActionT]):
    prefix: Sequence[StateT]
    suffix: Sequence[ActionT]
    reward: Sequence[float]


class PPO(
    Algorithm[StateT, ActionT, PPOStep[StateT, ActionT], PPOBatch[StateT, ActionT]],
    ABC,
):
    """Proximal Policy Optimization (PPO) algorithm with value function."""

    def __init__(
        self,
        problem: ValueFunctionProblem[StateT, ActionT],
        clip_range: float = 0.1,
        vf_loss_coef: float = 1.0,
    ):
        super().__init__(problem)

        self.problem: ValueFunctionProblem[StateT, ActionT] = problem
        self.clip_range = clip_range
        self.vf_loss_coef = vf_loss_coef

    def flatten(
        self, graph: Graph[StateT, ActionT]
    ) -> Sequence[PPOStep[StateT, ActionT]]:
        # walk the rollout graph breadth-first (mirroring DPO's flatten, which
        # samples the most- and least-rewarded actions from each branch as
        # contrastive pairs); here every (prefix, suffix, reward) triple is kept,
        # and sibling groups with fewer than two branches are skipped.

        res: List[PPOStep[StateT, ActionT]] = []
        bfs = [graph.children]
        while len(bfs):
            front = bfs.pop(0)
            if len(list(front)) < 2:
                # if there is no pair, we skip this node
                continue

            for i in front:
                res.append(PPOStep(prefix=i.context, suffix=i.attack, reward=i.reward))
                bfs.append(i.children)

        return res

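    # Illustration (hypothetical values): if the root's children are
    #   node_a: context="p", attack="x", reward=1.0
    #   node_b: context="p", attack="y", reward=0.0
    # then flatten() emits PPOStep(prefix="p", suffix="x", reward=1.0) and
    # PPOStep(prefix="p", suffix="y", reward=0.0), and queues node_a.children and
    # node_b.children as the next sibling groups; any group with fewer than two
    # nodes is skipped entirely (neither emitted nor descended into).
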
    @staticmethod
    def collate_fn(x: Sequence[PPOStep[StateT, ActionT]]) -> PPOBatch[StateT, ActionT]:
        prefixes = [i.prefix for i in x]
        suffixes = [i.suffix for i in x]
        rewards = [i.reward for i in x]

        return PPOBatch(prefix=prefixes, suffix=suffixes, reward=rewards)

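    # For example (hypothetical values): collate_fn([PPOStep("p", "x", 1.0),
    # PPOStep("q", "y", 0.0)]) returns
    # PPOBatch(prefix=["p", "q"], suffix=["x", "y"], reward=[1.0, 0.0]).
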
    def step(
        self, batch: PPOBatch[StateT, ActionT]
    ) -> tuple[torch.Tensor, Dict[Any, Any]]:
        logprobs_attacker = self.problem._get_attacker_logprobs_and_validate(
            batch.prefix, batch.suffix
        )
        logprobs_baseline = self.problem._get_baseline_logprobs_and_validate(
            batch.prefix, batch.suffix
        )
        values = self.problem.value(batch.prefix, batch.suffix)

        # Q(s,a) = R(s,a), which is jank but seems to be the standard;
        # also it's bootstrapped without discount throughout the stream
        Q = (
            torch.tensor(batch.reward)
            .to(logprobs_attacker.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .repeat(1, *values.shape[1:])
        )
        A = Q - values

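        # Shape example (hypothetical numbers): with batch.reward = [1.0, 0.0] and
        # values of shape (2, 4, 1), the reward tensor of shape (2,) is unsqueezed
        # to (2, 1, 1) and repeated to (2, 4, 1), so Q[0] is all 1.0 and Q[1] is
        # all 0.0 before the advantage A = Q - values is taken.
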
        # normalize advantages
        if A.size(-1) == 1:
            A = ((A - A.mean()) / (A.std() + 1e-8)).squeeze(-1)
        else:
            A = (A - A.mean()) / (A.std() + 1e-8)

        # compute ratio, should be 1 at the first iteration
        ratio = torch.exp(logprobs_attacker - logprobs_baseline.detach())
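
        # For reference, the lines below implement the standard PPO clipped
        # surrogate objective,
        #   L_policy = -E[ min(ratio * A, clip(ratio, 1 - clip_range, 1 + clip_range) * A) ],
        # here with the probability ratio taken against the baseline (reference)
        # policy's log-probabilities.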

        # compute clipped surrogate loss
        policy_loss_1 = A * ratio
        policy_loss_2 = A * torch.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
        policy_loss = -(torch.min(policy_loss_1, policy_loss_2)).mean()

        # compute value loss
        value_loss = F.mse_loss(Q, values)

        # compute final loss
        loss = policy_loss + self.vf_loss_coef * value_loss

        # create logging dict
        logging_dict: Dict[Any, Any] = {
            "training/loss": loss.mean().cpu().item(),
            "training/policy_loss": policy_loss.mean().cpu().item(),
            "training/value_loss": value_loss.mean().cpu().item(),
            "reward/mean_reward": torch.tensor(batch.reward).mean().cpu().item(),
            "reward/std_reward": torch.tensor(batch.reward).std().cpu().item(),
            "policy/logprobs": logprobs_attacker.mean().detach().cpu().item(),
            "ref/logprobs": logprobs_baseline.mean().detach().cpu().item(),
        }

        return loss, logging_dict
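
To check the loss arithmetic above in isolation, here is a minimal standalone sketch of the same computation on toy tensors. It uses only plain PyTorch; the clip_range and vf_loss_coef values mirror the defaults above, and the reward, log-probability, and value numbers are made up for illustration rather than taken from astra_rl.

import torch
import torch.nn.functional as F

clip_range = 0.1
vf_loss_coef = 1.0

# made-up per-sequence quantities
Q = torch.tensor([1.0, -0.5, 0.25])                   # Q(s, a) = R(s, a)
values = torch.tensor([0.8, -0.2, 0.0])               # value-head predictions
logprobs_attacker = torch.tensor([-1.2, -0.9, -1.5])  # current policy log-probs
logprobs_baseline = torch.tensor([-1.0, -1.0, -1.4])  # frozen reference log-probs

# advantage and its normalization
A = Q - values
A = (A - A.mean()) / (A.std() + 1e-8)

# clipped surrogate policy loss plus weighted value loss
ratio = torch.exp(logprobs_attacker - logprobs_baseline)
policy_loss = -torch.min(A * ratio, A * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)).mean()
value_loss = F.mse_loss(values, Q)
loss = policy_loss + vf_loss_coef * value_loss
print(float(loss))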