
sisl / astra-rl / build 21745123035

02 Dec 2025 06:59PM UTC coverage: 35.456% (remained the same)

Triggered by a GitHub push (committer: web-flow).
De-Bug examples, add training logging/wandb, update quick start training (#32)

# Pull Request

## Description

- Small updates and bug fixes in the examples.
- Added logging and wandb support to the trainer class.
- Updated the quick-start training to use the correct data and trainer/trainer config.

## Issues
- Some custom classes still need to be made importable from astra_rl.
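The optional wandb logging this PR adds to the trainer could follow a pattern like the one below. This is a hypothetical sketch, not astra_rl's actual implementation: `MetricLogger`, its `history` buffer, and the `project` name are illustrative stand-ins; only `wandb.init` and `wandb.log` are real wandb APIs.

```python
class MetricLogger:
    """Buffers metrics locally; mirrors them to wandb only when enabled.

    Illustrative sketch only -- not the astra_rl trainer's real logger.
    """

    def __init__(self, use_wandb: bool = False):
        self.use_wandb = use_wandb
        self.history: list[tuple[int, dict]] = []
        if use_wandb:
            import wandb  # imported lazily, only when logging is enabled

            wandb.init(project="astra-rl")  # hypothetical project name
            self._wandb = wandb

    def log(self, metrics: dict, step: int) -> None:
        # always keep a local record; forward to wandb only if enabled
        self.history.append((step, metrics))
        if self.use_wandb:
            self._wandb.log(metrics, step=step)


logger = MetricLogger()  # wandb disabled by default
logger.log({"loss": 0.5}, step=1)
```

Keeping the wandb import behind the flag means users who never enable it do not need wandb installed.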

2 of 70 new or added lines in 3 files covered. (2.86%)

142 existing lines in 3 files now uncovered.

323 of 911 relevant lines covered (35.46%)

0.71 hits per line
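The reported percentages are simple ratios of covered to relevant lines; the headline figures above can be reproduced directly:

```python
# Overall coverage: 323 of 911 relevant lines covered.
covered, relevant = 323, 911
overall = round(100 * covered / relevant, 2)  # 35.46

# Coverage of new/added lines: 2 of 70 covered.
new_covered, new_total = 2, 70
new_pct = round(100 * new_covered / new_total, 2)  # 2.86

print(overall, new_pct)
```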

Source File: /src/astra_rl/training/trainer.py (40.43% covered)
```python
"""
trainer.py

The trainer is an opinionated system designed to make training new models easy. For full
customization of the model training pipeline, we recommend using the lower-level
`Harness` system in `harness.py`.
"""

import torch
from typing import Generic

from pydantic import BaseModel
from torch.optim import Optimizer

from astra_rl.training.harness import Harness
from astra_rl.core.sampler import Sampler
from astra_rl.core.algorithm import Algorithm
from astra_rl.core.common import ActionT, StateT, Batch, Step


class TrainingConfiguration(BaseModel):
    """A typechecked dataclass which configures the training procedure.

    Attributes:
        lr (float): Learning rate for the optimizer.
        batch_size (int): Size of each batch (after flattening from experience) for training.
        optimizer (str): Type of optimizer to use [choices: "adam", "adamw", "sgd", "rmsprop", "adagrad"].
        gradient_accumulation_steps (int): Number of steps to accumulate gradients before updating the model weights.
        training_steps (int): Total number of rollouts to run and train for.
        num_episodes_per_experience (int): Number of rollouts to run before making a gradient update.
    """

    # optimization configuration
    lr: float = 3e-3
    batch_size: int = 16
    optimizer: str = "adamw"
    gradient_accumulation_steps: int = 1  # how many batches to accumulate before stepping

    # training configuration
    training_steps: int = 1024  # how many rollouts to train for

    # rollout configuration
    num_episodes_per_experience: int = 8  # how many rollouts per gradient update


class Trainer(Generic[StateT, ActionT, Step, Batch]):
    """A high-level trainer that trains your policy at the push of a button.

    Example:
        Here is an example of how to use the `Trainer` class with the DPO algorithm
        and an AST problem sampler:

        >>> import torch
        >>> from astra_rl import Trainer, TrainingConfiguration
        >>> from astra_rl.algorithms.dpo import DPO
        >>> from astra_rl.methods.ast import ASTProblem, ASTSampler
        >>>
        >>> problem = ASTProblem()
        >>> sampler = ASTSampler(problem, ...)
        >>> algorithm = DPO(...)
        >>> config = TrainingConfiguration(
        ...     lr=1e-3,
        ...     batch_size=16,
        ...     optimizer="adamw",
        ...     gradient_accumulation_steps=1,
        ...     training_steps=1024,
        ...     num_episodes_per_experience=8,
        ... )
        >>> trainer = Trainer(config, sampler, algorithm)
        >>> trainer.train()

    Attributes:
        config (TrainingConfiguration): The configuration for the training process.
        harness (Harness): The harness that manages the training loop and interactions with the sampler. See `astra_rl.training.harness` for details.
        optimizer (Optimizer): The optimizer used for updating the model parameters.
        _global_step_counter (int): A counter for global steps, used for gradient accumulation.
    """

    optimizer: Optimizer

    def __init__(
        self,
        config: TrainingConfiguration,
        sampler: Sampler[StateT, ActionT],
        algorithm: Algorithm[StateT, ActionT, Step, Batch],
        use_wandb: bool = False,
    ):
        """
        Args:
            config (TrainingConfiguration): The configuration for the training process.
            sampler (Sampler): The sampler to run our algorithm in.
            algorithm (Algorithm): The algorithm used for training the tester agent.
            use_wandb (bool): Whether to log training metrics to Weights & Biases.
        """

        self.config = config
        self.harness = Harness(
            sampler, algorithm, config.num_episodes_per_experience, use_wandb=use_wandb
        )

        # TODO: initialize an LR scheduler?

        # initialize optimizer
        if config.optimizer == "adam":
            from torch.optim import Adam

            self.optimizer = Adam(sampler.system.parameters(), config.lr)
        elif config.optimizer == "adamw":
            from torch.optim import AdamW

            self.optimizer = AdamW(sampler.system.parameters(), config.lr)
        elif config.optimizer == "sgd":
            from torch.optim import SGD

            self.optimizer = SGD(sampler.system.parameters(), config.lr)
        elif config.optimizer == "rmsprop":
            from torch.optim import RMSprop

            self.optimizer = RMSprop(sampler.system.parameters(), config.lr)
        elif config.optimizer == "adagrad":
            from torch.optim import Adagrad

            self.optimizer = Adagrad(sampler.system.parameters(), config.lr)
        else:
            raise ValueError(f"Unknown optimizer configured: {config.optimizer}")

        # step counter, used for gradient accumulation, etc.
        self._global_step_counter = 0

    def train(self) -> None:
        """Run training with the specified config!

        Note:
            This method takes no arguments and returns nothing; it is
            used only for its side effects. It exists mainly to let the
            user control when training actually starts (instead of it
            starting immediately after Trainer construction).
        """
        for _ in range(self.config.training_steps):
            buf = self.harness.experience()
            for batch in buf:
                # increment the counter first for accumulation
                self._global_step_counter += 1
                loss: torch.Tensor = (
                    self.harness.step(batch)[0]
                    / self.config.gradient_accumulation_steps
                )
                # typing disabled here b/c mypy can't statically verify
                # that the loss has gradients
                loss.backward()  # type: ignore[no-untyped-call]

                # once enough gradients have accumulated, step!
                if (
                    self._global_step_counter % self.config.gradient_accumulation_steps
                    == 0
                ):
                    self.optimizer.step()
                    self.optimizer.zero_grad()
```
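The gradient-accumulation bookkeeping in `Trainer.train` can be sketched in isolation: increment a counter per batch, scale each loss by the accumulation factor, and step the optimizer only every N batches. A minimal standalone sketch, with a toy stand-in for the torch optimizer (no astra_rl or torch required):

```python
class ToyOptimizer:
    """Counts step() calls; a stand-in for a torch Optimizer."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1

    def zero_grad(self):
        pass  # a real optimizer would clear accumulated gradients here


def train_loop(num_batches: int, accumulation_steps: int) -> int:
    """Returns how many optimizer steps occur for the given batch count."""
    opt = ToyOptimizer()
    counter = 0
    for _ in range(num_batches):
        counter += 1  # incremented first, as in Trainer.train
        # here the real trainer computes loss / accumulation_steps
        # and calls loss.backward() to accumulate gradients
        if counter % accumulation_steps == 0:
            opt.step()
            opt.zero_grad()
    return opt.steps


print(train_loop(8, 4))  # 8 batches, accumulate 4 -> 2 optimizer steps
```

Note that scaling each loss by `1 / gradient_accumulation_steps` keeps the accumulated gradient equal in expectation to that of one large batch, which is why the trainer divides before calling `backward()`.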