• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 22864321797

09 Mar 2026 04:45PM UTC coverage: 91.047% (-0.6%) from 91.662%
22864321797

push

github

web-flow
Move eval and infer code out to its own module (#13)

* Move eval code out to library. Also, rename the training handler files.

* Move the data handler access to the mod file, to make import paths a bit shorter

* Restructure model access

* Make training imports cleaner

* Better import path for inference module

* Go through and check for super imports I missed

* Run rust fmt

1582 of 1742 new or added lines in 20 files covered. (90.82%)

2 existing lines in 2 files now uncovered.

1678 of 1843 relevant lines covered (91.05%)

16.14 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.58
/src/training/train_handler.rs
1
use chrono::{TimeDelta, Utc};
2
use tracing::info;
3

4
use crate::{
5
    data_handling::TrainingDataset,
6
    error::Result,
7
    model::{
8
        PredictiveCodingModel, PredictiveCodingModelConfig, save_model_config, save_model_snapshot,
9
    },
10
};
11

12
use super::{TrainConfig, save_training_config};
13

14
pub trait TrainingHandler {
15
    fn get_config(&self) -> &TrainConfig;
16
    fn get_model(&mut self) -> &mut PredictiveCodingModel;
17
    fn get_data(&self) -> &dyn TrainingDataset;
18
    fn get_file_output_prefix(&self) -> &String;
19

NEW
20
    fn pre_training_hook(&mut self) -> Result<()> {
×
NEW
21
        Ok(())
×
NEW
22
    }
×
23

24
    fn train_step(&mut self, _step: u32) -> Result<()>;
NEW
25
    fn report_hook(&mut self, _step: u32, _mean_step_time: TimeDelta) -> Result<()> {
×
NEW
26
        Ok(())
×
NEW
27
    }
×
28

29
    /// Once the model has completed training, this will be called. By default, it saves the final model to "{file_output_prefix}_final.json"
30
    fn post_training_hook(&mut self) -> Result<()> {
5✔
31
        let final_output_path: String =
5✔
32
            format!("{}_{}", self.get_file_output_prefix(), "final_model.json");
5✔
33

34
        info!(
5✔
35
            "Finished training, saving final model to {}",
36
            final_output_path
37
        );
38
        save_model_snapshot(self.get_model(), &final_output_path)
5✔
39
    }
5✔
40

41
    // Any actions that need to be called with an awareness of the training step can use these hooks. By default, they do nothing.
42
    // e.g. if you want to anneal the learning rate, that would go here
43
    fn pre_step_hook(&mut self, _step: u32) -> Result<()> {
10✔
44
        Ok(())
10✔
45
    }
10✔
46
    fn post_step_hook(&mut self, _step: u32) -> Result<()> {
10✔
47
        Ok(())
10✔
48
    }
10✔
49
}
50

51
pub fn run_supervised_training_loop(handler: &mut dyn TrainingHandler) -> Result<()> {
4✔
52
    handler.pre_training_hook()?;
4✔
53

54
    // Supervised learning
55
    handler.get_model().pin_input();
4✔
56
    handler.get_model().pin_output();
4✔
57

58
    let training_config: &TrainConfig = handler.get_config();
4✔
59
    let training_steps: u32 = training_config.training_steps;
4✔
60
    let report_interval: u32 = training_config.report_interval;
4✔
61
    let snapshot_interval: u32 = training_config.snapshot_interval;
4✔
62

63
    info!(
4✔
64
        "Beginning training loop for {} steps with report interval {} and snapshot interval {}",
65
        training_steps, report_interval, snapshot_interval
66
    );
67

68
    let model_config: &PredictiveCodingModelConfig = &handler.get_model().get_config();
4✔
69
    info!(
4✔
70
        "Model architecture:\n\tlayer sizes: {:?}\n\tgamma: {}\n\talpha: {}\n\tactivation function: {:?}\n\tconvergence steps: {}\n\tconvergence threshold: {}",
71
        model_config.layer_sizes,
72
        model_config.gamma,
73
        model_config.alpha,
74
        model_config.activation_function,
75
        model_config.convergence_steps,
76
        model_config.convergence_threshold
77
    );
78

79
    // Write the config and training params to a file
80
    save_model_config(
4✔
81
        model_config,
4✔
82
        &format!("{}_model_config.json", &handler.get_file_output_prefix()),
4✔
NEW
83
    )?;
×
84
    save_training_config(
4✔
85
        handler.get_config(),
4✔
86
        &format!("{}_training_config.json", &handler.get_file_output_prefix()),
4✔
NEW
87
    )?;
×
88

89
    // Track the time taken for each step in the current reporting epoch
90
    let mut step_times: Vec<TimeDelta> = Vec::new();
4✔
91

92
    // Main loop
93
    for step in 0..training_steps {
12✔
94
        let start_time = Utc::now();
12✔
95
        handler.pre_step_hook(step)?;
12✔
96
        handler.train_step(step)?;
12✔
97
        handler.post_step_hook(step)?;
12✔
98
        let elapsed: TimeDelta = Utc::now() - start_time;
12✔
99
        step_times.push(elapsed);
12✔
100

101
        if (report_interval > 0) && (step % report_interval == 0) {
12✔
102
            let mean_step_time: TimeDelta =
6✔
103
                step_times.iter().sum::<TimeDelta>() / step_times.len() as i32;
6✔
104
            handler.report_hook(step, mean_step_time)?;
6✔
105
            step_times.clear();
6✔
106
        }
6✔
107

108
        if (snapshot_interval > 0) && (step % snapshot_interval == 0) {
12✔
109
            let oname: String = format!(
7✔
110
                "{}_snapshot_step_{}.json",
111
                handler.get_file_output_prefix(),
7✔
112
                step
113
            );
114
            info!("Saving model snapshot {}", oname);
7✔
115

116
            save_model_snapshot(handler.get_model(), &oname)?;
7✔
117
        }
5✔
118
    }
119

120
    handler.post_training_hook()?;
4✔
121

122
    Ok(())
4✔
123
}
4✔
124

125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{RecordingTrainingHandler, TempDir, single_thread_train_config};

    /// Three steps with a report interval of 2: each step runs the
    /// pre/train/post hooks in order, and reports fire after steps 0 and 2.
    #[test]
    fn training_loop_preserves_hook_order_and_report_interval() {
        let temp_dir = TempDir::new("training_loop_hooks");
        let prefix = temp_dir.path().join("model").display().to_string();
        let config = single_thread_train_config(3, 2, 0);
        let mut handler = RecordingTrainingHandler::new(config, prefix);

        run_supervised_training_loop(&mut handler).unwrap();

        let expected: Vec<String> = [
            "pre_training",
            "pre_step:0",
            "train_step:0",
            "post_step:0",
            "report:0",
            "pre_step:1",
            "train_step:1",
            "post_step:1",
            "pre_step:2",
            "train_step:2",
            "post_step:2",
            "report:2",
            "post_training",
        ]
        .iter()
        .map(|event| event.to_string())
        .collect();

        assert_eq!(handler.events, expected);
    }

    /// Five steps with a snapshot interval of 2: snapshots exist for steps
    /// 0, 2 and 4 (but not 1), plus the final model written after training.
    #[test]
    fn training_loop_saves_snapshots_at_configured_steps() {
        let temp_dir = TempDir::new("training_loop_snapshots");
        let prefix = temp_dir.path().join("model").display().to_string();
        let config = single_thread_train_config(5, 0, 2);
        let mut handler = RecordingTrainingHandler::new(config, prefix);

        run_supervised_training_loop(&mut handler).unwrap();

        let written = |file_name: &str| temp_dir.path().join(file_name).exists();

        for file_name in [
            "model_snapshot_step_0.json",
            "model_snapshot_step_2.json",
            "model_snapshot_step_4.json",
            "model_final_model.json",
        ] {
            assert!(written(file_name));
        }
        assert!(!written("model_snapshot_step_1.json"));
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc