• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 22809140016

07 Mar 2026 11:04PM UTC coverage: 87.11% (-6.0%) from 93.084%
22809140016

push

github

web-flow
Add an error handler, and use Result<>  (#8)

* Add an Error handler

* Wherever an error is thrown, return a Result. When errors are thrown, use appropriate ones from the new classes.

281 of 391 new or added lines in 10 files covered. (71.87%)

1 existing line in 1 file now uncovered.

1257 of 1443 relevant lines covered (87.11%)

9.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.82
/src/training/train_handler.rs
1
use tracing::info;
2

3
use crate::{
4
  data_handling::data_handler,
5
  error::Result,
6
  model_structure::{
7
    model::{PredictiveCodingModel, PredictiveCodingModelConfig},
8
    model_utils::{save_model_config, save_model_snapshot}
9
  },
10
  training::utils::{TrainConfig, save_training_config}
11
};
12

13
pub trait TrainingHandler {
14
  fn get_config(&self) -> &TrainConfig;
15
  fn get_model(&mut self) -> &mut PredictiveCodingModel;
16
  fn get_data(&self) -> &data_handler::TrainingDataset;
17
  fn get_file_output_prefix(&self) -> &String;
18

NEW
19
  fn pre_training_hook(&mut self) -> Result<()> {
×
NEW
20
    Ok(())
×
NEW
21
  }
×
22

23
  fn train_step(&mut self, _step: u32) -> Result<()>;
NEW
24
  fn report_hook(&mut self, _step: u32) -> Result<()> {
×
NEW
25
    Ok(())
×
NEW
26
  }
×
27

28
  /// Once the model has completed training, this will be called. By default, it saves the final model to "{file_output_prefix}_final_model.json"
29
  fn post_training_hook(&mut self) -> Result<()> {
4✔
30
    let final_output_path: String = format!("{}_{}", self.get_file_output_prefix(), "final_model.json");
4✔
31

32
    info!("Finished training, saving final model to {}", final_output_path);
4✔
33
    save_model_snapshot(
4✔
34
      self.get_model(),
4✔
35
      &final_output_path
4✔
36
    )
37
  }
4✔
38

39
  // Any actions that need to be called with an awareness of the training step can use these hooks. By default, they do nothing.
40
  // e.g. if you want to anneal the learning rate, that would go here
41
  fn pre_step_hook(&mut self, _step: u32) -> Result<()> {
8✔
42
    Ok(())
8✔
43
  }
8✔
44
  fn post_step_hook(&mut self, _step: u32) -> Result<()> {
8✔
45
    Ok(())
8✔
46
  }
8✔
47
}
48

49

50
pub fn run_supervised_training_loop(handler: &mut dyn TrainingHandler) -> Result<()> {
4✔
51
  handler.pre_training_hook()?;
4✔
52

53
  // Supervised learning
54
  handler.get_model().pin_input();
4✔
55
  handler.get_model().pin_output();
4✔
56

57
  let training_config: &TrainConfig = handler.get_config();
4✔
58
  let training_steps: u32 = training_config.training_steps;
4✔
59
  let report_interval: u32 = training_config.report_interval;
4✔
60
  let snapshot_interval: u32 = training_config.snapshot_interval;
4✔
61

62
  info!(
4✔
63
    "Beginning training loop for {} steps with report interval {} and snapshot interval {}",
64
    training_steps, report_interval, snapshot_interval
65
  );
66

67
  let model_config: &PredictiveCodingModelConfig = &handler.get_model().get_config();
4✔
68
  info!(
4✔
69
    "Model architecture:\n\tlayer sizes: {:?}\n\tgamma: {}\n\talpha: {}\n\tactivation function: {:?}\n\tconvergence steps: {}\n\tconvergence threshold: {}",
70
    model_config.layer_sizes,
71
    model_config.gamma,
72
    model_config.alpha,
73
    model_config.activation_function,
74
    model_config.convergence_steps,
75
    model_config.convergence_threshold
76
  );
77

78
  // Write the model config and the training config to separate JSON files ("{prefix}_model_config.json" and "{prefix}_training_config.json")
79
  save_model_config(
4✔
80
    model_config,
4✔
81
    &format!("{}_model_config.json", &handler.get_file_output_prefix())
4✔
NEW
82
  )?;
×
83
  save_training_config(
4✔
84
    handler.get_config(),
4✔
85
    &format!("{}_training_config.json", &handler.get_file_output_prefix())
4✔
NEW
86
  )?;
×
87

88
  // Main loop
89
  for step in 0..training_steps {
12✔
90
    handler.pre_step_hook(step)?;
12✔
91
    handler.train_step(step)?;
12✔
92
    handler.post_step_hook(step)?;
12✔
93

94
    if (report_interval > 0) && (step % report_interval == 0) {
12✔
95
      handler.report_hook(step)?;
6✔
96
    }
6✔
97

98
    if (snapshot_interval > 0) && (step % snapshot_interval == 0) {
12✔
99
      let oname: String = format!("{}_snapshot_step_{}.json", handler.get_file_output_prefix(), step);
7✔
100
      info!("Saving model snapshot {}", oname);
7✔
101

102
      save_model_snapshot(handler.get_model(), &oname)?;
7✔
103
    }
5✔
104
  }
105

106
  handler.post_training_hook()?;
4✔
107

108
  Ok(())
4✔
109
}
4✔
110

111
#[cfg(test)]
112
mod tests {
113
  use super::*;
114

115
  use crate::{
116
    model_structure::{
117
      model::PredictiveCodingModelConfig,
118
      model_utils::ActivationFunction
119
    },
120
    training::utils::{
121
      DataSetSource,
122
      ModelSource,
123
      TrainingStrategy
124
    }
125
  };
126
  use ndarray::Array2;
127
  use std::{
128
    fs,
129
    path::PathBuf,
130
    time::{SystemTime, UNIX_EPOCH}
131
  };
132

133
  struct TempDir {
134
    path: PathBuf,
135
  }
136

137
  impl TempDir {
138
    fn new(prefix: &str) -> Self {
2✔
139
      let unique_id = SystemTime::now()
2✔
140
        .duration_since(UNIX_EPOCH)
2✔
141
        .unwrap()
2✔
142
        .as_nanos();
2✔
143
      let path = std::env::temp_dir().join(format!(
2✔
144
        "predictive_coding_{prefix}_{}_{}",
2✔
145
        std::process::id(),
2✔
146
        unique_id
2✔
147
      ));
2✔
148
      fs::create_dir_all(&path).unwrap();
2✔
149
      TempDir { path }
2✔
150
    }
2✔
151

152
    fn path(&self) -> &PathBuf {
7✔
153
      &self.path
7✔
154
    }
7✔
155
  }
156

157
  impl Drop for TempDir {
158
    fn drop(&mut self) {
2✔
159
      let _ = fs::remove_dir_all(&self.path);
2✔
160
    }
2✔
161
  }
162

163
  struct RecordingHandler {
164
    config: TrainConfig,
165
    model: PredictiveCodingModel,
166
    data: data_handler::TrainingDataset,
167
    file_output_prefix: String,
168
    events: Vec<String>,
169
  }
170

171
  impl RecordingHandler {
172
    fn new(config: TrainConfig, output_prefix: String) -> Self {
2✔
173
      let model: PredictiveCodingModel = PredictiveCodingModel::new(&PredictiveCodingModelConfig {
2✔
174
        layer_sizes: vec![4, 10],
2✔
175
        alpha: 0.01,
2✔
176
        gamma: 0.05,
2✔
177
        convergence_threshold: 0.0,
2✔
178
        convergence_steps: 1,
2✔
179
        activation_function: ActivationFunction::Relu,
2✔
180
      });
2✔
181
      let data: data_handler::TrainingDataset = data_handler::TrainingDataset {
2✔
182
        dataset_size: 1,
2✔
183
        input_size: 4,
2✔
184
        output_size: 10,
2✔
185
        inputs: Array2::zeros((1, 4)),
2✔
186
        labels: Array2::zeros((1, 10)),
2✔
187
      };
2✔
188

189
      RecordingHandler {
2✔
190
        config,
2✔
191
        model,
2✔
192
        data,
2✔
193
        file_output_prefix: output_prefix,
2✔
194
        events: Vec::new(),
2✔
195
      }
2✔
196
    }
2✔
197
  }
198

199
  // Construct a dummy handler which just records what it's done, in what order.
200
  impl TrainingHandler for RecordingHandler {
201
    fn get_config(&self) -> &TrainConfig {
4✔
202
      &self.config
4✔
203
    }
4✔
204

205
    fn get_model(&mut self) -> &mut PredictiveCodingModel {
11✔
206
      &mut self.model
11✔
207
    }
11✔
208

209
    fn get_data(&self) -> &data_handler::TrainingDataset {
×
210
      &self.data
×
211
    }
×
212

213
    fn get_file_output_prefix(&self) -> &String {
9✔
214
      &self.file_output_prefix
9✔
215
    }
9✔
216

217
    fn pre_training_hook(&mut self) -> Result<()> {
2✔
218
      self.events.push(String::from("pre_training"));
2✔
219
      Ok(())
2✔
220
    }
2✔
221

222
    fn train_step(&mut self, step: u32) -> Result<()> {
8✔
223
      self.events.push(format!("train_step:{step}"));
8✔
224
      Ok(())
8✔
225
    }
8✔
226

227
    fn report_hook(&mut self, step: u32) -> Result<()> {
2✔
228
      self.events.push(format!("report:{step}"));
2✔
229
      Ok(())
2✔
230
    }
2✔
231

232
    fn pre_step_hook(&mut self, step: u32) -> Result<()> {
8✔
233
      self.events.push(format!("pre_step:{step}"));
8✔
234
      Ok(())
8✔
235
    }
8✔
236

237
    fn post_step_hook(&mut self, step: u32) -> Result<()> {
8✔
238
      self.events.push(format!("post_step:{step}"));
8✔
239
      Ok(())
8✔
240
    }
8✔
241

242
    fn post_training_hook(&mut self) -> Result<()> {
2✔
243
      self.events.push(String::from("post_training"));
2✔
244
      let final_output_path = format!("{}_final_model.json", self.get_file_output_prefix());
2✔
245
      save_model_snapshot(self.get_model(), &final_output_path)
2✔
246
    }
2✔
247
  }
248

249
  fn test_config(training_steps: u32, report_interval: u32, snapshot_interval: u32) -> TrainConfig {
2✔
250
    TrainConfig {
2✔
251
      model_source: ModelSource::Config(String::from("unused.json")),
2✔
252
      dataset: DataSetSource::MNIST {
2✔
253
        input_idx_file: String::from("unused-images.idx"),
2✔
254
        output_idx_file: String::from("unused-labels.idx"),
2✔
255
      },
2✔
256
      training_strategy: TrainingStrategy::SingleThread,
2✔
257
      training_steps,
2✔
258
      report_interval,
2✔
259
      snapshot_interval,
2✔
260
    }
2✔
261
  }
2✔
262

263
  #[test]
264
  fn training_loop_preserves_hook_order_and_report_interval() {
1✔
265
    let temp_dir: TempDir = TempDir::new("training_loop_hooks");
1✔
266
    let output_prefix: String = temp_dir.path().join("model").display().to_string();
1✔
267
    let mut handler: RecordingHandler = RecordingHandler::new(test_config(3, 2, 0), output_prefix);
1✔
268

269
    run_supervised_training_loop(&mut handler).unwrap();
1✔
270

271
    assert_eq!(
1✔
272
      handler.events,
273
      vec![
1✔
274
        String::from("pre_training"),
1✔
275
        String::from("pre_step:0"),
1✔
276
        String::from("train_step:0"),
1✔
277
        String::from("post_step:0"),
1✔
278
        String::from("report:0"),
1✔
279
        String::from("pre_step:1"),
1✔
280
        String::from("train_step:1"),
1✔
281
        String::from("post_step:1"),
1✔
282
        String::from("pre_step:2"),
1✔
283
        String::from("train_step:2"),
1✔
284
        String::from("post_step:2"),
1✔
285
        String::from("report:2"),
1✔
286
        String::from("post_training"),
1✔
287
      ]
288
    );
289
  }
1✔
290

291
  #[test]
292
  fn training_loop_saves_snapshots_at_configured_steps() {
1✔
293
    let temp_dir = TempDir::new("training_loop_snapshots");
1✔
294
    let output_prefix = temp_dir.path().join("model").display().to_string();
1✔
295
    let mut handler = RecordingHandler::new(test_config(5, 0, 2), output_prefix.clone());
1✔
296

297
    run_supervised_training_loop(&mut handler).unwrap();
1✔
298

299
    assert!(temp_dir.path().join("model_snapshot_step_0.json").exists());
1✔
300
    assert!(!temp_dir.path().join("model_snapshot_step_1.json").exists());
1✔
301
    assert!(temp_dir.path().join("model_snapshot_step_2.json").exists());
1✔
302
    assert!(temp_dir.path().join("model_snapshot_step_4.json").exists());
1✔
303
    assert!(temp_dir.path().join("model_final_model.json").exists());
1✔
304
  }
1✔
305
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc