• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sisl / astra-rl / 16923134818

12 Aug 2025 11:14PM UTC coverage: 42.094% (-2.2%) from 44.276%
16923134818

push

github

Jemoka
[wiP] work on IFT problem formulation

7 of 31 new or added lines in 4 files covered. (22.58%)

205 of 487 relevant lines covered (42.09%)

0.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/astra_rl/methods/ift_ast_problem.py
1
"""
2
ift_ast_problem.py
3
AST Problem for Instruction Fine-Tuned Models
4
"""
5

NEW
6
from typing import Sequence, Dict
×
7

NEW
8
from abc import ABC, abstractmethod
×
NEW
9
import torch
×
10

NEW
11
from astra_rl.core.problem import Problem
×
NEW
12
from astra_rl.core.moderator import Moderator
×
13

14

NEW
15
class IFTASTProblem(Problem[Sequence[Dict[str, str]], str], ABC):
×
16
    """ASTProblem, implemented for instruction fine-tuned models
17

18
    For usage examples, see `astra_rl.core.problem.Problem`.
19

20
    Attributes:
21
        moderator (Moderator): The moderator used to evaluate the sequences.
22
        ppl_weight (float): Weight for perplexity in the reward calculation.
23
        combined_weight (float): Weight for combined scores in the reward calculation.
24
        defense_toxicity_weight (float): Weight for defense toxicity in the reward calculation.
25
    """
26

NEW
27
    def __init__(
×
28
        self,
29
        moderator: Moderator[Sequence[Dict[str, str]]],
30
        ppl_weight: float = 0.025,
31
        combined_weight: float = 0.1,
32
        defense_toxicity_weight: float = 0.1,
33
    ):
NEW
34
        super().__init__(moderator)
×
35

NEW
36
        self.ppl_weight = ppl_weight
×
NEW
37
        self.combined_weight = combined_weight
×
NEW
38
        self.defense_toxicity_weight = defense_toxicity_weight
×
39

NEW
40
    def advance(
×
41
        self, state: Sequence[Dict[str, str]], action: str, next: str
42
    ) -> Sequence[Dict[str, str]]:
NEW
43
        return list(state) + [{"role": "assistant", "content": action + next}]
×
44

NEW
45
    @abstractmethod
×
NEW
46
    def rollout_prompt_with_user_model(
×
47
        self, x: Sequence[Sequence[Dict[str, str]]]
48
    ) -> Sequence[Sequence[Dict[str, str]]]:
49
        """Rolls out a prompt using a model for the user
50

51
        Args:
52
            x (Sequence[StateT]): The input chat messages to roll out from.
53

54
        Returns:
55
            Sequence[StateT]: **REPEAT** the input message + added a user utterance on top.
56
        """
NEW
57
        pass
×
58

NEW
59
    def reward(
×
60
        self,
61
        context: Sequence[Sequence[Dict[str, str]]],
62
        attack: Sequence[str],
63
        response: Sequence[str],
64
    ) -> Sequence[float]:
65
        # IMPORTANT: this is profound---"combined toxicity" in IFT'd models also includes
66
        # the context, and defense toxicity is defense only without any context.
NEW
67
        combined_uts = [
×
68
            self.advance(i, j, k) for i, j, k in zip(context, attack, response)
69
        ]
NEW
70
        defender_uts = [self.advance([], "", k) for k in response]
×
NEW
71
        combined_scores: Sequence[float] = self.moderator.moderate(combined_uts)
×
NEW
72
        defender_scores: Sequence[float] = self.moderator.moderate(defender_uts)
×
73

74
        # calculate logprobs
NEW
75
        attack_ppls: torch.Tensor = self._get_target_logprobs_and_validate(
×
76
            context, attack
77
        ).exp()
78

NEW
79
        reward: Sequence[float] = (
×
80
            (
81
                self.ppl_weight * attack_ppls
82
                + self.combined_weight
83
                * torch.tensor(combined_scores).to(attack_ppls.device)
84
                + self.defense_toxicity_weight
85
                * torch.tensor(defender_scores).to(attack_ppls.device)
86
            )
87
            .cpu()
88
            .tolist()
89
        )
90

NEW
91
        return reward
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc