• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 26003515046

17 May 2026 09:41PM UTC coverage: 89.661% (+1.0%) from 88.629%
26003515046

Pull #18

github

web-flow
Merge cabb4c1ae into ab068c7f4
Pull Request #18: Implement the remaining GPU kernels

225 of 243 new or added lines in 9 files covered. (92.59%)

6 existing lines in 2 files now uncovered.

2064 of 2302 relevant lines covered (89.66%)

14.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.61
/src/training/train_handler.rs
1
use chrono::{TimeDelta, Utc};
2
use std::time::Duration;
3
use tracing::info;
4

5
use crate::{
6
    data_handling::TrainingDataset,
7
    error::Result,
8
    model::{ModelSnapshot, PredictiveCodingModelConfig, save_model_config, save_snapshot},
9
};
10

11
use super::{TrainConfig, save_training_config};
12

13
/// Timing breakdown of a single training step, as a sequence of named phases.
14
#[derive(Debug, Clone)]
15
pub struct StepProfile {
16
    pub phases: Vec<(String, Duration)>,
17
}
18

19
impl StepProfile {
20
    pub fn new() -> Self {
22✔
21
        StepProfile { phases: Vec::new() }
22✔
22
    }
22✔
23

24
    pub fn record(&mut self, name: impl Into<String>, duration: Duration) {
48✔
25
        self.phases.push((name.into(), duration));
48✔
26
    }
48✔
27

NEW
28
    pub fn total(&self) -> Duration {
×
NEW
29
        self.phases.iter().map(|(_, d)| *d).sum()
×
NEW
30
    }
×
31
}
32

33
impl Default for StepProfile {
NEW
34
    fn default() -> Self {
×
NEW
35
        Self::new()
×
NEW
36
    }
×
37
}
38

39
pub trait TrainingHandler {
40
    fn get_config(&self) -> &TrainConfig;
41
    fn model_snapshot(&mut self) -> Result<ModelSnapshot>;
42
    fn model_config(&self) -> PredictiveCodingModelConfig;
43
    fn pin_input(&mut self) -> Result<()>;
44
    fn pin_output(&mut self) -> Result<()>;
45
    fn get_data(&self) -> &dyn TrainingDataset;
46
    fn get_file_output_prefix(&self) -> &String;
47

48
    fn pre_training_hook(&mut self) -> Result<()> {
×
49
        Ok(())
×
50
    }
×
51

52
    /// Execute one training step, returning per-phase timing breakdown.
53
    /// Implement this in handlers to define the training logic with profiling.
54
    fn profiled_train_step(&mut self, step: u32) -> Result<StepProfile>;
55

56
    /// Execute one training step, discarding the profile. Delegates to
57
    /// `profiled_train_step` by default.
58
    fn train_step(&mut self, step: u32) -> Result<()> {
14✔
59
        self.profiled_train_step(step)?;
14✔
60
        Ok(())
14✔
61
    }
14✔
62

63
    fn report_hook(&mut self, _step: u32, _mean_step_time: TimeDelta) -> Result<()> {
×
64
        Ok(())
×
65
    }
×
66

67
    /// Called once the model has completed training. By default, it saves the final model snapshot to "{file_output_prefix}_final_model.json".
68
    fn post_training_hook(&mut self) -> Result<()> {
5✔
69
        let final_output_path: String =
5✔
70
            format!("{}_{}", self.get_file_output_prefix(), "final_model.json");
5✔
71

72
        info!(
5✔
73
            "Finished training, saving final model to {}",
74
            final_output_path
75
        );
76
        let snapshot = self.model_snapshot()?;
5✔
77
        save_snapshot(&snapshot, &final_output_path)
5✔
78
    }
5✔
79

80
    // Any actions that need to be called with an awareness of the training step can use these hooks. By default, they do nothing.
81
    // e.g. if you want to anneal the learning rate, that would go here
82
    fn pre_step_hook(&mut self, _step: u32) -> Result<()> {
10✔
83
        Ok(())
10✔
84
    }
10✔
85
    fn post_step_hook(&mut self, _step: u32) -> Result<()> {
10✔
86
        Ok(())
10✔
87
    }
10✔
88
}
89

90
pub fn run_supervised_training_loop(handler: &mut dyn TrainingHandler) -> Result<()> {
4✔
91
    handler.pre_training_hook()?;
4✔
92

93
    // Supervised learning
94
    handler.pin_input()?;
4✔
95
    handler.pin_output()?;
4✔
96

97
    let training_config: &TrainConfig = handler.get_config();
4✔
98
    let training_steps: u32 = training_config.training_steps;
4✔
99
    let report_interval: u32 = training_config.report_interval;
4✔
100
    let snapshot_interval: u32 = training_config.snapshot_interval;
4✔
101

102
    info!(
4✔
103
        "Beginning training loop for {} steps with report interval {} and snapshot interval {}",
104
        training_steps, report_interval, snapshot_interval
105
    );
106

107
    let model_config: PredictiveCodingModelConfig = handler.model_config();
4✔
108
    info!(
4✔
109
        "Model architecture:\n\tlayer sizes: {:?}\n\tgamma: {}\n\talpha: {}\n\tactivation function: {:?}\n\tconvergence steps: {}\n\tconvergence threshold: {}",
110
        model_config.layer_sizes,
111
        model_config.gamma,
112
        model_config.alpha,
113
        model_config.activation_function,
114
        model_config.convergence_steps,
115
        model_config.convergence_threshold
116
    );
117

118
    // Write the config and training params to a file
119
    save_model_config(
4✔
120
        &model_config,
4✔
121
        &format!("{}_model_config.json", &handler.get_file_output_prefix()),
4✔
122
    )?;
×
123
    save_training_config(
4✔
124
        handler.get_config(),
4✔
125
        &format!("{}_training_config.json", &handler.get_file_output_prefix()),
4✔
126
    )?;
×
127

128
    // Track the time taken for each step in the current reporting epoch
129
    let mut step_times: Vec<TimeDelta> = Vec::new();
4✔
130

131
    // Main loop
132
    for step in 0..training_steps {
12✔
133
        let start_time = Utc::now();
12✔
134
        handler.pre_step_hook(step)?;
12✔
135
        handler.train_step(step)?;
12✔
136
        handler.post_step_hook(step)?;
12✔
137
        let elapsed: TimeDelta = Utc::now() - start_time;
12✔
138
        step_times.push(elapsed);
12✔
139

140
        if (report_interval > 0) && (step % report_interval == 0) {
12✔
141
            let mean_step_time: TimeDelta =
6✔
142
                step_times.iter().sum::<TimeDelta>() / step_times.len() as i32;
6✔
143
            handler.report_hook(step, mean_step_time)?;
6✔
144
            step_times.clear();
6✔
145
        }
6✔
146

147
        if (snapshot_interval > 0) && (step % snapshot_interval == 0) {
12✔
148
            let oname: String = format!(
7✔
149
                "{}_snapshot_step_{}.json",
150
                handler.get_file_output_prefix(),
7✔
151
                step
152
            );
153
            info!("Saving model snapshot {}", oname);
7✔
154

155
            let snapshot = handler.model_snapshot()?;
7✔
156
            save_snapshot(&snapshot, &oname)?;
7✔
157
        }
5✔
158
    }
159

160
    handler.post_training_hook()?;
4✔
161

162
    Ok(())
4✔
163
}
4✔
164

165
#[cfg(test)]
166
mod tests {
167
    use super::*;
168
    use crate::test_utils::{RecordingTrainingHandler, TempDir, single_thread_train_config};
169

170
    #[test]
171
    fn training_loop_preserves_hook_order_and_report_interval() {
1✔
172
        let temp_dir: TempDir = TempDir::new("training_loop_hooks");
1✔
173
        let output_prefix: String = temp_dir.path().join("model").display().to_string();
1✔
174
        let mut handler: RecordingTrainingHandler =
1✔
175
            RecordingTrainingHandler::new(single_thread_train_config(3, 2, 0), output_prefix);
1✔
176

177
        run_supervised_training_loop(&mut handler).unwrap();
1✔
178

179
        assert_eq!(
1✔
180
            handler.events,
181
            vec![
1✔
182
                String::from("pre_training"),
1✔
183
                String::from("pre_step:0"),
1✔
184
                String::from("train_step:0"),
1✔
185
                String::from("post_step:0"),
1✔
186
                String::from("report:0"),
1✔
187
                String::from("pre_step:1"),
1✔
188
                String::from("train_step:1"),
1✔
189
                String::from("post_step:1"),
1✔
190
                String::from("pre_step:2"),
1✔
191
                String::from("train_step:2"),
1✔
192
                String::from("post_step:2"),
1✔
193
                String::from("report:2"),
1✔
194
                String::from("post_training"),
1✔
195
            ]
196
        );
197
    }
1✔
198

199
    #[test]
200
    fn training_loop_saves_snapshots_at_configured_steps() {
1✔
201
        let temp_dir = TempDir::new("training_loop_snapshots");
1✔
202
        let output_prefix = temp_dir.path().join("model").display().to_string();
1✔
203
        let mut handler: RecordingTrainingHandler = RecordingTrainingHandler::new(
1✔
204
            single_thread_train_config(5, 0, 2),
1✔
205
            output_prefix.clone(),
1✔
206
        );
207

208
        run_supervised_training_loop(&mut handler).unwrap();
1✔
209

210
        assert!(temp_dir.path().join("model_snapshot_step_0.json").exists());
1✔
211
        assert!(!temp_dir.path().join("model_snapshot_step_1.json").exists());
1✔
212
        assert!(temp_dir.path().join("model_snapshot_step_2.json").exists());
1✔
213
        assert!(temp_dir.path().join("model_snapshot_step_4.json").exists());
1✔
214
        assert!(temp_dir.path().join("model_final_model.json").exists());
1✔
215
    }
1✔
216
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc