• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 26039594071

18 May 2026 02:23PM UTC coverage: 89.655% (+1.0%) from 88.629%
26039594071

Pull #18

github

web-flow
Merge 3555e47b1 into ab068c7f4
Pull Request #18: Implement the remaining GPU kernels

268 of 288 new or added lines in 12 files covered. (93.06%)

6 existing lines in 2 files now uncovered.

2080 of 2320 relevant lines covered (89.66%)

14.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.71
/src/training/train_handler.rs
1
use chrono::{TimeDelta, Utc};
2
use std::time::Duration;
3
use tracing::info;
4

5
use crate::{
6
    data_handling::TrainingDataset,
7
    error::Result,
8
    model::{ModelSnapshot, PredictiveCodingModelConfig, save_model_config, save_snapshot},
9
};
10

11
use super::{TrainConfig, save_training_config};
12

13
/// Timing breakdown of a single training step, as a sequence of named phases.
///
/// Phases are stored in the order in which they were recorded.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StepProfile {
    /// `(phase name, time spent in that phase)` pairs, in recording order.
    pub phases: Vec<(String, Duration)>,
}

impl StepProfile {
    /// Creates an empty profile with no recorded phases.
    ///
    /// Equivalent to [`StepProfile::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a phase called `name` that took `duration` to the profile.
    pub fn record(&mut self, name: impl Into<String>, duration: Duration) {
        self.phases.push((name.into(), duration));
    }

    /// Returns the sum of all recorded phase durations.
    ///
    /// An empty profile yields a zero duration.
    pub fn total(&self) -> Duration {
        self.phases.iter().map(|(_, duration)| *duration).sum()
    }
}
38

39
pub trait TrainingHandler {
40
    fn get_config(&self) -> &TrainConfig;
41
    fn model_snapshot(&mut self) -> Result<ModelSnapshot>;
42
    fn model_config(&self) -> PredictiveCodingModelConfig;
43
    fn pin_input(&mut self) -> Result<()>;
44
    fn pin_output(&mut self) -> Result<()>;
45
    fn get_data(&self) -> &dyn TrainingDataset;
46
    fn get_file_output_prefix(&self) -> &String;
47

48
    fn pre_training_hook(&mut self) -> Result<()> {
×
49
        Ok(())
×
50
    }
×
51

52
    /// Execute one training step, returning per-phase timing breakdown.
53
    /// Implement this in handlers to define the training logic with profiling.
54
    fn profiled_train_step(&mut self, step: u32) -> Result<StepProfile>;
55

56
    /// Execute one training step, discarding the profile. Delegates to
57
    /// `profiled_train_step` by default.
58
    fn train_step(&mut self, step: u32) -> Result<()> {
14✔
59
        self.profiled_train_step(step)?;
14✔
60
        Ok(())
14✔
61
    }
14✔
62

63
    fn report_hook(&mut self, _step: u32, _mean_step_time: TimeDelta) -> Result<()> {
×
64
        Ok(())
×
65
    }
×
66

67
    /// Once the model has completed training, this will be called. By default, it saves the final model to "{file_output_prefix}_final.json"
68
    fn post_training_hook(&mut self) -> Result<()> {
5✔
69
        let final_output_path: String =
5✔
70
            format!("{}_{}", self.get_file_output_prefix(), "final_model.json");
5✔
71

72
        info!(
5✔
73
            "Finished training, saving final model to {}",
74
            final_output_path
75
        );
76
        let snapshot = self.model_snapshot()?;
5✔
77
        save_snapshot(&snapshot, &final_output_path)
5✔
78
    }
5✔
79

80
    // Any actions that need to be called with an awareness of the training step can use these hooks. By default, they do nothing.
81
    // e.g. if you want to anneal the learning rate, that would go here
82
    fn pre_step_hook(&mut self, _step: u32) -> Result<()> {
10✔
83
        Ok(())
10✔
84
    }
10✔
85
    fn post_step_hook(&mut self, _step: u32) -> Result<()> {
10✔
86
        Ok(())
10✔
87
    }
10✔
88
}
89

90
/// Log training progress: ETA and current energy.
91
pub fn log_training_progress(
5✔
92
    step: u32,
5✔
93
    remaining_steps: u32,
5✔
94
    mean_step_time: TimeDelta,
5✔
95
    energy: f32,
5✔
96
) {
5✔
97
    let est_time_to_finish = mean_step_time * remaining_steps as i32;
5✔
98
    let est_finish_time = Utc::now() + est_time_to_finish;
5✔
99
    info!(
5✔
100
        "Step {}: Current model state: energy = {:.2}\tEstimated finish time: {}",
101
        step,
102
        energy,
103
        est_finish_time.format("%Y-%m-%d %H:%M:%S")
4✔
104
    );
105
}
5✔
106

107
pub fn run_supervised_training_loop(handler: &mut dyn TrainingHandler) -> Result<()> {
4✔
108
    handler.pre_training_hook()?;
4✔
109

110
    // Supervised learning
111
    handler.pin_input()?;
4✔
112
    handler.pin_output()?;
4✔
113

114
    let training_config: &TrainConfig = handler.get_config();
4✔
115
    let training_steps: u32 = training_config.training_steps;
4✔
116
    let report_interval: u32 = training_config.report_interval;
4✔
117
    let snapshot_interval: u32 = training_config.snapshot_interval;
4✔
118

119
    info!(
4✔
120
        "Beginning training loop for {} steps with report interval {} and snapshot interval {}",
121
        training_steps, report_interval, snapshot_interval
122
    );
123

124
    let model_config: PredictiveCodingModelConfig = handler.model_config();
4✔
125
    info!(
4✔
126
        "Model architecture:\n\tlayer sizes: {:?}\n\tgamma: {}\n\talpha: {}\n\tweight_clip: {}\n\tactivation function: {:?}\n\tconvergence steps: {}\n\tconvergence threshold: {}",
127
        model_config.layer_sizes,
128
        model_config.gamma,
129
        model_config.alpha,
130
        model_config.weight_clip,
131
        model_config.activation_function,
132
        model_config.convergence_steps,
133
        model_config.convergence_threshold
134
    );
135

136
    // Write the config and training params to a file
137
    save_model_config(
4✔
138
        &model_config,
4✔
139
        &format!("{}_model_config.json", &handler.get_file_output_prefix()),
4✔
140
    )?;
×
141
    save_training_config(
4✔
142
        handler.get_config(),
4✔
143
        &format!("{}_training_config.json", &handler.get_file_output_prefix()),
4✔
144
    )?;
×
145

146
    // Track the time taken for each step in the current reporting epoch
147
    let mut step_times: Vec<TimeDelta> = Vec::new();
4✔
148

149
    // Main loop
150
    for step in 0..training_steps {
12✔
151
        let start_time = Utc::now();
12✔
152
        handler.pre_step_hook(step)?;
12✔
153
        handler.train_step(step)?;
12✔
154
        handler.post_step_hook(step)?;
12✔
155
        let elapsed: TimeDelta = Utc::now() - start_time;
12✔
156
        step_times.push(elapsed);
12✔
157

158
        if (report_interval > 0) && (step % report_interval == 0) {
12✔
159
            let mean_step_time: TimeDelta =
6✔
160
                step_times.iter().sum::<TimeDelta>() / step_times.len() as i32;
6✔
161
            handler.report_hook(step, mean_step_time)?;
6✔
162
            step_times.clear();
6✔
163
        }
6✔
164

165
        if (snapshot_interval > 0) && (step % snapshot_interval == 0) {
12✔
166
            let oname: String = format!(
7✔
167
                "{}_snapshot_step_{}.json",
168
                handler.get_file_output_prefix(),
7✔
169
                step
170
            );
171
            info!("Saving model snapshot {}", oname);
7✔
172

173
            let snapshot = handler.model_snapshot()?;
7✔
174
            save_snapshot(&snapshot, &oname)?;
7✔
175
        }
5✔
176
    }
177

178
    handler.post_training_hook()?;
4✔
179

180
    Ok(())
4✔
181
}
4✔
182

183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{RecordingTrainingHandler, TempDir, single_thread_train_config};

    /// The loop must call the hooks in a fixed order around every step, and only invoke the
    /// report hook on steps that are multiples of the report interval.
    #[test]
    fn training_loop_preserves_hook_order_and_report_interval() {
        let scratch: TempDir = TempDir::new("training_loop_hooks");
        let prefix: String = scratch.path().join("model").display().to_string();
        let config = single_thread_train_config(3, 2, 0);
        let mut handler: RecordingTrainingHandler = RecordingTrainingHandler::new(config, prefix);

        run_supervised_training_loop(&mut handler).unwrap();

        // Reports are expected on steps 0 and 2 only.
        let expected_events: Vec<String> = [
            "pre_training",
            "pre_step:0",
            "train_step:0",
            "post_step:0",
            "report:0",
            "pre_step:1",
            "train_step:1",
            "post_step:1",
            "pre_step:2",
            "train_step:2",
            "post_step:2",
            "report:2",
            "post_training",
        ]
        .iter()
        .map(|event| event.to_string())
        .collect();

        assert_eq!(handler.events, expected_events);
    }

    /// Snapshots must be written on multiples of the snapshot interval, and the final model
    /// must be written once training has finished.
    #[test]
    fn training_loop_saves_snapshots_at_configured_steps() {
        let scratch = TempDir::new("training_loop_snapshots");
        let prefix = scratch.path().join("model").display().to_string();
        let config = single_thread_train_config(5, 0, 2);
        let mut handler: RecordingTrainingHandler = RecordingTrainingHandler::new(config, prefix);

        run_supervised_training_loop(&mut handler).unwrap();

        // True when the named file exists inside the scratch directory.
        let was_written = |file_name: &str| scratch.path().join(file_name).exists();

        assert!(was_written("model_snapshot_step_0.json"));
        assert!(!was_written("model_snapshot_step_1.json"));
        assert!(was_written("model_snapshot_step_2.json"));
        assert!(was_written("model_snapshot_step_4.json"));
        assert!(was_written("model_final_model.json"));
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc