sisl / astra-rl / 18275623010

03 Oct 2025 08:38PM UTC · coverage: 38.778% (remained the same)
Build 18275623010 · push · github · web-flow

Merge pull request #24 from sisl/de/ux_improvements: "User Experience Naming Improvements"

61 of 143 new or added lines in 18 files covered (42.66%)
58 existing lines in 6 files now uncovered
311 of 802 relevant lines covered (38.78%)
0.78 hits per line

Source file: /src/astra_rl/core/algorithm.py (78.95% covered)

Uncovered lines: `self.system = system` in `__init__` (new in this build) and the `pass` bodies of `flatten`, `collate_fn`, and `step`.
```python
"""
algorithm.py
"""

from abc import abstractmethod, ABC
from typing import Sequence, Generic, Any, Dict

import torch

from astra_rl.core.system import System
from astra_rl.core.sampler import Graph
from astra_rl.core.common import Step, Batch, StateT, ActionT


class Algorithm(ABC, Generic[StateT, ActionT, Step, Batch]):
    """An Algorithm used for performing training.

    Specifically, the Algorithm object is responsible for encoding
    how a particular rollout graph is processed into a loss
    that updates the weights of the model. Subclasses implement
    this by calling self.system's methods to push values
    through the network.

    Attributes:
        system (System): The system instance that defines the sampler and actions.

    Generics:
        StateT (type): The type of the state in the sampler.
        ActionT (type): The type of the action in the sampler.
        Step (type): The type of a single step in the environment.
        Batch (type): The type of a batch of steps, passed to .step() to compute the loss.
    """

    def __init__(self, system: System[StateT, ActionT]):
        self.system = system

    @abstractmethod
    def flatten(self, graph: Graph[StateT, ActionT]) -> Sequence[Step]:
        """Process a rollout graph into a sequence of steps.

        Args:
            graph (Graph[StateT, ActionT]): The graph to flatten.

        Returns:
            Sequence[Step]: A sequence of steps representing the flattened graph.
        """
        pass

    @staticmethod
    @abstractmethod
    def collate_fn(batch: Sequence[Step]) -> Batch:
        """The collate_fn used by torch DataLoaders for batching.

        This is passed directly as the collate_fn of a torch DataLoader,
        and it is responsible for emitting well-formed batches of data.

        Args:
            batch (Sequence[Step]): A sequence of steps to collate.

        Returns:
            Batch: A batch of data ready for processing by .step().
        """
        pass

    @abstractmethod
    def step(self, batch: Batch) -> tuple[torch.Tensor, Dict[Any, Any]]:
        """Take a batch and compute its loss.

        Args:
            batch (Batch): A batch of data to process.

        Returns:
            tuple[torch.Tensor, Dict[Any, Any]]: A tuple containing:
                - torch.Tensor: The loss computed by the algorithm (for the current batch).
                - Dict[Any, Any]: Additional information for logging.
        """
        pass
```
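The class docstring describes a pipeline (rollout graph, then `flatten`, then `collate_fn` inside a DataLoader, then `step` producing a loss to backpropagate). A minimal sketch of that pipeline follows, under stated assumptions. `AlgorithmSketch` is a local stand-in that copies the abstract interface so the block runs without astra_rl installed; `ToyNode`, `ToyStep`, `ToyBatch`, `ToySystem`, and `ToyReinforce` are hypothetical names, not part of astra_rl, because the real `Graph`, `Step`, `Batch`, and `System` types are not shown on this page. The REINFORCE-style loss is an illustrative choice, not necessarily what astra_rl's own algorithms compute.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence, Tuple

import torch
from torch.utils.data import DataLoader


class AlgorithmSketch(ABC):  # stand-in copying the abstract Algorithm interface
    def __init__(self, system: Any):
        self.system = system

    @abstractmethod
    def flatten(self, graph: Any) -> Sequence[Any]: ...

    @staticmethod
    @abstractmethod
    def collate_fn(batch: Sequence[Any]) -> Any: ...

    @abstractmethod
    def step(self, batch: Any) -> Tuple[torch.Tensor, Dict[Any, Any]]: ...


@dataclass
class ToyNode:  # hypothetical rollout-graph node: one utterance and its reward
    context: str
    utterance: str
    reward: float
    children: List["ToyNode"] = field(default_factory=list)


@dataclass
class ToyStep:  # hypothetical Step type
    prefix: str
    suffix: str
    reward: float


@dataclass
class ToyBatch:  # hypothetical Batch type
    prefixes: List[str]
    suffixes: List[str]
    rewards: torch.Tensor


class ToySystem:  # hypothetical system: "log-prob" = one weight times suffix length
    def __init__(self) -> None:
        self.weight = torch.nn.Parameter(torch.zeros(()))

    def logprobs(self, prefixes: List[str], suffixes: List[str]) -> torch.Tensor:
        # prefixes are unused in this toy; a real system would condition on them
        lengths = torch.tensor([float(len(s)) for s in suffixes])
        return self.weight * lengths


class ToyReinforce(AlgorithmSketch):
    def flatten(self, graph: ToyNode) -> List[ToyStep]:
        # Pre-order walk: every non-root node becomes one training step.
        steps: List[ToyStep] = []
        def visit(node: ToyNode) -> None:
            for child in node.children:
                steps.append(ToyStep(child.context, child.utterance, child.reward))
                visit(child)
        visit(graph)
        return steps

    @staticmethod
    def collate_fn(batch: Sequence[ToyStep]) -> ToyBatch:
        # DataLoader hands over a list of ToySteps; pack them into one ToyBatch.
        return ToyBatch(
            prefixes=[s.prefix for s in batch],
            suffixes=[s.suffix for s in batch],
            rewards=torch.tensor([s.reward for s in batch], dtype=torch.float32),
        )

    def step(self, batch: ToyBatch) -> Tuple[torch.Tensor, Dict[Any, Any]]:
        # REINFORCE-style loss, chosen only for illustration.
        logprobs = self.system.logprobs(batch.prefixes, batch.suffixes)
        loss = -(batch.rewards * logprobs).mean()
        return loss, {"loss": loss.item(), "reward_mean": batch.rewards.mean().item()}


# graph -> flatten -> DataLoader(collate_fn) -> step -> backward -> optimizer step
root = ToyNode("", "", 0.0, [
    ToyNode("Hello", " there", 1.0, [ToyNode("Hello there", "!", 0.5)]),
    ToyNode("Hello", " x", -1.0),
])
system = ToySystem()
algo = ToyReinforce(system)
steps = algo.flatten(root)
loader = DataLoader(steps, batch_size=2, shuffle=False, collate_fn=ToyReinforce.collate_fn)
optimizer = torch.optim.SGD([system.weight], lr=0.1)

history = []
for batch in loader:
    loss, logs = algo.step(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    history.append(logs)
```

Because `collate_fn` is a staticmethod, `ToyReinforce.collate_fn` can be handed to `DataLoader` as a plain function with no bound instance. Stacking `@staticmethod` over `@abstractmethod` (abstractmethod innermost) still registers it as abstract, so a subclass that forgets to override it cannot be instantiated.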