sisl / astra-rl / 16262675714

14 Jul 2025 04:49AM UTC coverage: 46.419% (-48.0%) from 94.444%

push · github · web-flow
Merge pull request #3 from sisl/feat/core

Initial implementation of core AST algorithm.

160 of 361 new or added lines in 18 files covered (44.32%).

175 of 377 relevant lines covered (46.42%)

0.93 hits per line
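The percentages above are plain ratios of covered lines to relevant lines. Here is a quick Python check that uses only counts printed in this report. Two inputs are assumptions: the 15-of-19 figure for algorithm.py is counted by hand from the per-line hit markers in the file listing below, and rounding to two decimals is assumed to match how Coveralls displays values.

```python
def coverage_pct(covered: int, relevant: int) -> float:
    """Percentage of relevant lines that were hit, rounded to two decimals (assumed display rule)."""
    return round(100 * covered / relevant, 2)


new_lines = coverage_pct(160, 361)  # new or added lines in this push
overall = coverage_pct(175, 377)    # all relevant lines in the build
file_pct = coverage_pct(15, 19)     # algorithm.py: 15 of 19 relevant lines hit (hand-counted)

delta = round(46.419 - 94.444, 1)   # change against the previous build's 94.444%

print(new_lines, overall, file_pct, delta)  # → 44.32 46.42 78.95 -48.0
```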

Source file: /src/astra_rl/core/algorithm.py (78.95% of relevant lines covered)
```python
"""
algorithm.py
"""

from abc import abstractmethod, ABC
from typing import Sequence, Generic

import torch

from astra_rl.core.problem import Problem
from astra_rl.core.environment import Graph
from astra_rl.core.common import Step, Batch, StateT, ActionT


class Algorithm(ABC, Generic[StateT, ActionT, Step, Batch]):
    """An Algorithm used for performing training.

    Specifically, the Algorithm object is responsible for encoding
    how a particular rollout graph is processed into a loss
    that updates the weights of the model. Subclasses implement this
    by calling self.problem's methods to push values through the
    network.


    Attributes:
        problem (Problem): The problem instance that defines the environment and actions.

    Generics:
        StateT (type): The type of the state in the environment.
        ActionT (type): The type of the action in the environment.
        Step (type): The type of a single step in the environment.
        Batch (type): The type of a batch of steps, passed to .step() to compute the loss.
    """

    def __init__(self, problem: Problem[StateT, ActionT]):
        self.problem = problem

    @abstractmethod
    def flatten(self, graph: Graph[StateT, ActionT]) -> Sequence[Step]:
        """Process a rollout graph into a sequence of steps.

        Args:
            graph (Graph[StateT, ActionT]): The graph to flatten.

        Returns:
            Sequence[Step]: A sequence of steps representing the flattened graph.
        """
        pass

    @staticmethod
    @abstractmethod
    def collate_fn(batch: Sequence[Step]) -> Batch:
        """The collate_fn used by torch DataLoaders for batching.

        We pass this directly as the collate_fn of a torch DataLoader, and
        it is responsible for emitting well-formed batches of data.

        Args:
            batch (Sequence[Step]): A sequence of steps to collate.

        Returns:
            Batch: A batch of data ready for processing by .step().
        """
        pass

    @abstractmethod
    def step(self, batch: Batch) -> torch.Tensor:
        """Take a batch and compute its loss.

        Args:
            batch (Batch): A batch of data to process.

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        pass
```

Uncovered lines (4 of 19 relevant): line 36 (`self.problem = problem`), so `Algorithm.__init__` itself never ran during the tests, and lines 48, 64 and 76, the `pass` bodies of the abstract `flatten`, `collate_fn` and `step`. Every other relevant line (the imports, the `class` statement, the decorators and the `def` lines, all executed when the module is imported) was hit twice.
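To make the contract of the three abstract methods concrete, here is a minimal, self-contained sketch. It does not import astra_rl or torch. `MiniAlgorithm` is a hypothetical stand-in that mirrors only the abstract methods of `Algorithm` (it takes no `problem`). `Node` is a hypothetical rollout-graph node standing in for `Graph`, and `step` returns a plain float rather than a `torch.Tensor`. The toy loss (negated mean reward) is for illustration only and is not the AST objective.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Generic, List, Sequence, TypeVar

StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")
StepT = TypeVar("StepT")
BatchT = TypeVar("BatchT")


@dataclass
class Node:
    """Hypothetical rollout-graph node: a state, the action taken there, its reward, and children."""
    state: str
    action: str
    reward: float
    children: List["Node"] = field(default_factory=list)


@dataclass
class RewardStep:
    """Hypothetical flattened step: one record per graph node."""
    state: str
    action: str
    reward: float


class MiniAlgorithm(ABC, Generic[StateT, ActionT, StepT, BatchT]):
    """Hypothetical stand-in mirroring Algorithm's abstract methods (no torch, no Problem)."""

    @abstractmethod
    def flatten(self, graph: Node) -> Sequence[StepT]: ...

    @staticmethod
    @abstractmethod
    def collate_fn(batch: Sequence[StepT]) -> BatchT: ...

    @abstractmethod
    def step(self, batch: BatchT) -> float: ...


class MeanNegativeReward(MiniAlgorithm[str, str, RewardStep, List[float]]):
    """Toy subclass: the 'loss' of a batch is its negated mean reward."""

    def flatten(self, graph: Node) -> Sequence[RewardStep]:
        # Depth-first, pre-order walk that emits one step per node of the rollout graph.
        steps: List[RewardStep] = []
        stack = [graph]
        while stack:
            node = stack.pop()
            steps.append(RewardStep(node.state, node.action, node.reward))
            stack.extend(reversed(node.children))  # reversed so children pop left to right
        return steps

    @staticmethod
    def collate_fn(batch: Sequence[RewardStep]) -> List[float]:
        # Keep only what step() needs: the rewards.
        return [s.reward for s in batch]

    def step(self, batch: List[float]) -> float:
        return -sum(batch) / len(batch)


root = Node("s0", "a0", 1.0, [Node("s1", "a1", 0.0), Node("s2", "a2", 2.0)])
algo = MeanNegativeReward()
steps = algo.flatten(root)
batch = MeanNegativeReward.collate_fn(steps)
loss = algo.step(batch)
print([s.state for s in steps], batch, loss)  # → ['s0', 's1', 's2'] [1.0, 0.0, 2.0] -1.0
```

In the real class, `step` returns a `torch.Tensor` so the training loop can call `.backward()` on it. The flattened steps would be wrapped in a dataset and batched by `torch.utils.data.DataLoader(..., collate_fn=...)`, as the `collate_fn` docstring describes. Because `collate_fn` is a static method, it can be handed to the DataLoader straight from the class, without an instance.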