• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 22821825018

08 Mar 2026 01:14PM UTC coverage: 84.927% (-2.2%) from 87.11%
22821825018

Pull #10

github

web-flow
Merge 4b9278865 into 5f8ad5902
Pull Request #10: Generalise the data boundary

111 of 164 new or added lines in 9 files covered. (67.68%)

2 existing lines in 2 files now uncovered.

1279 of 1506 relevant lines covered (84.93%)

9.11 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.64
/src/training/train_handler.rs
1
use tracing::info;
2

3
#[cfg(test)]
4
use std::sync::Arc;
5

6
use crate::{
7
  data_handling::data_handler,
8
  error::Result,
9
  model_structure::{
10
    model::{PredictiveCodingModel, PredictiveCodingModelConfig},
11
    configuration::{save_model_config, save_model_snapshot}
12
  },
13
  training::configuration::{
14
    TrainConfig,
15
    save_training_config
16
  }
17
};
18

19
pub trait TrainingHandler {
20
  fn get_config(&self) -> &TrainConfig;
21
  fn get_model(&mut self) -> &mut PredictiveCodingModel;
22
  fn get_data(&self) -> &dyn data_handler::TrainingDataset;
23
  fn get_file_output_prefix(&self) -> &String;
24

25
  fn pre_training_hook(&mut self) -> Result<()> {
×
26
    Ok(())
×
27
  }
×
28

29
  fn train_step(&mut self, _step: u32) -> Result<()>;
30
  fn report_hook(&mut self, _step: u32) -> Result<()> {
×
31
    Ok(())
×
32
  }
×
33

34
  /// Once the model has completed training, this will be called. By default, it saves the final model to "{file_output_prefix}_final.json"
35
  fn post_training_hook(&mut self) -> Result<()> {
4✔
36
    let final_output_path: String = format!("{}_{}", self.get_file_output_prefix(), "final_model.json");
4✔
37

38
    info!("Finished training, saving final model to {}", final_output_path);
4✔
39
    save_model_snapshot(
4✔
40
      self.get_model(),
4✔
41
      &final_output_path
4✔
42
    )
43
  }
4✔
44

45
  // Any actions that need to be called with an awareness of the training step can use these hooks. By default, they do nothing.
46
  // e.g. if you want to anneal the learning rate, that would go here
47
  fn pre_step_hook(&mut self, _step: u32) -> Result<()> {
8✔
48
    Ok(())
8✔
49
  }
8✔
50
  fn post_step_hook(&mut self, _step: u32) -> Result<()> {
8✔
51
    Ok(())
8✔
52
  }
8✔
53
}
54

55

56
pub fn run_supervised_training_loop(handler: &mut dyn TrainingHandler) -> Result<()> {
4✔
57
  handler.pre_training_hook()?;
4✔
58

59
  // Supervised learning
60
  handler.get_model().pin_input();
4✔
61
  handler.get_model().pin_output();
4✔
62

63
  let training_config: &TrainConfig = handler.get_config();
4✔
64
  let training_steps: u32 = training_config.training_steps;
4✔
65
  let report_interval: u32 = training_config.report_interval;
4✔
66
  let snapshot_interval: u32 = training_config.snapshot_interval;
4✔
67

68
  info!(
4✔
69
    "Beginning training loop for {} steps with report interval {} and snapshot interval {}",
70
    training_steps, report_interval, snapshot_interval
71
  );
72

73
  let model_config: &PredictiveCodingModelConfig = &handler.get_model().get_config();
4✔
74
  info!(
4✔
75
    "Model architecture:\n\tlayer sizes: {:?}\n\tgamma: {}\n\talpha: {}\n\tactivation function: {:?}\n\tconvergence steps: {}\n\tconvergence threshold: {}",
76
    model_config.layer_sizes,
77
    model_config.gamma,
78
    model_config.alpha,
79
    model_config.activation_function,
80
    model_config.convergence_steps,
81
    model_config.convergence_threshold
82
  );
83

84
  // Write the config and training params to a file
85
  save_model_config(
4✔
86
    model_config,
4✔
87
    &format!("{}_model_config.json", &handler.get_file_output_prefix())
4✔
88
  )?;
×
89
  save_training_config(
4✔
90
    handler.get_config(),
4✔
91
    &format!("{}_training_config.json", &handler.get_file_output_prefix())
4✔
92
  )?;
×
93

94
  // Main loop
95
  for step in 0..training_steps {
12✔
96
    handler.pre_step_hook(step)?;
12✔
97
    handler.train_step(step)?;
12✔
98
    handler.post_step_hook(step)?;
12✔
99

100
    if (report_interval > 0) && (step % report_interval == 0) {
12✔
101
      handler.report_hook(step)?;
6✔
102
    }
6✔
103

104
    if (snapshot_interval > 0) && (step % snapshot_interval == 0) {
12✔
105
      let oname: String = format!("{}_snapshot_step_{}.json", handler.get_file_output_prefix(), step);
7✔
106
      info!("Saving model snapshot {}", oname);
7✔
107

108
      save_model_snapshot(handler.get_model(), &oname)?;
7✔
109
    }
5✔
110
  }
111

112
  handler.post_training_hook()?;
4✔
113

114
  Ok(())
4✔
115
}
4✔
116

117
#[cfg(test)]
mod tests {
  use super::*;

  use crate::{
    model_structure::{
      model::PredictiveCodingModelConfig,
      maths::ActivationFunction
    },
    training::configuration::{
      DataSetSource,
      ModelSource,
      TrainingStrategy
    }
  };
  use ndarray::{Array1, Array2};
  use std::{
    fs,
    path::{Path, PathBuf},
    sync::atomic::{AtomicUsize, Ordering},
    time::{SystemTime, UNIX_EPOCH}
  };

  /// Scratch directory under the system temp dir, removed (best-effort) on drop.
  struct TempDir {
    path: PathBuf,
  }

  impl TempDir {
    /// Creates a fresh directory whose name combines `prefix`, the process id, the
    /// current time in nanoseconds and a per-process sequence number.
    fn new(prefix: &str) -> Self {
      // The timestamp alone can repeat when tests start in parallel on a coarse clock;
      // the counter guarantees distinct names within this process.
      static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
      let sequence = NEXT_ID.fetch_add(1, Ordering::Relaxed);
      let unique_id = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos();
      let path = std::env::temp_dir().join(format!(
        "predictive_coding_{prefix}_{}_{}_{}",
        std::process::id(),
        unique_id,
        sequence
      ));
      fs::create_dir_all(&path).unwrap();
      TempDir { path }
    }

    fn path(&self) -> &Path {
      &self.path
    }
  }

  impl Drop for TempDir {
    fn drop(&mut self) {
      // Best-effort cleanup: a leftover temp dir must not fail the test.
      let _ = fs::remove_dir_all(&self.path);
    }
  }

  /// Minimal dataset: every accessor returns the first (and only) row.
  struct DummyTrainingDataset {
    dataset_size: usize,
    input_size: usize,
    output_size: usize,
    inputs: Array2<f32>,
    labels: Array2<f32>,
  }

  impl data_handler::TrainingDataset for DummyTrainingDataset {
    fn get_dataset_size(&self) -> usize {self.dataset_size}
    fn get_input_size(&self) -> usize {self.input_size}
    fn get_output_size(&self) -> usize {self.output_size}
    fn get_inputs(&self) -> &Array2<f32> {&self.inputs}
    fn get_labels(&self) -> &Array2<f32> {&self.labels}

    fn get_random_input(&self) -> Array1<f32> {
      self.get_input(0)
    }

    fn get_random_input_and_output(&self) -> (Array1<f32>, Array1<f32>) {
      (self.get_input(0), self.get_output(0))
    }

    fn get_input(&self, _index: usize) -> Array1<f32> {
      self.inputs.row(0).to_owned()
    }

    fn get_output(&self, _index: usize) -> Array1<f32> {
      self.labels.row(0).to_owned()
    }
  }

  struct RecordingHandler {
    config: TrainConfig,
    model: PredictiveCodingModel,
    data: Arc<dyn data_handler::TrainingDataset>,
    file_output_prefix: String,
    events: Vec<String>,
  }

  impl RecordingHandler {
    fn new(config: TrainConfig, output_prefix: String) -> Self {
      let model: PredictiveCodingModel = PredictiveCodingModel::new(&PredictiveCodingModelConfig {
        layer_sizes: vec![4, 10],
        alpha: 0.01,
        gamma: 0.05,
        convergence_threshold: 0.0,
        convergence_steps: 1,
        activation_function: ActivationFunction::Relu,
      });

      let data: Arc<dyn data_handler::TrainingDataset> = Arc::new(DummyTrainingDataset {
        dataset_size: 1,
        input_size: 4,
        output_size: 10,
        inputs: Array2::zeros((1, 4)),
        labels: Array2::zeros((1, 10)),
      });

      RecordingHandler {
        config,
        model,
        data,
        file_output_prefix: output_prefix,
        events: Vec::new(),
      }
    }
  }

  // Construct a dummy handler which just records what it's done, in what order.
  impl TrainingHandler for RecordingHandler {
    fn get_config(&self) -> &TrainConfig {
      &self.config
    }

    fn get_model(&mut self) -> &mut PredictiveCodingModel {
      &mut self.model
    }

    fn get_data(&self) -> &dyn data_handler::TrainingDataset {
      self.data.as_ref()
    }

    fn get_file_output_prefix(&self) -> &String {
      &self.file_output_prefix
    }

    fn pre_training_hook(&mut self) -> Result<()> {
      self.events.push(String::from("pre_training"));
      Ok(())
    }

    fn train_step(&mut self, step: u32) -> Result<()> {
      self.events.push(format!("train_step:{step}"));
      Ok(())
    }

    fn report_hook(&mut self, step: u32) -> Result<()> {
      self.events.push(format!("report:{step}"));
      Ok(())
    }

    fn pre_step_hook(&mut self, step: u32) -> Result<()> {
      self.events.push(format!("pre_step:{step}"));
      Ok(())
    }

    fn post_step_hook(&mut self, step: u32) -> Result<()> {
      self.events.push(format!("post_step:{step}"));
      Ok(())
    }

    fn post_training_hook(&mut self) -> Result<()> {
      self.events.push(String::from("post_training"));
      let final_output_path = format!("{}_final_model.json", self.get_file_output_prefix());
      save_model_snapshot(self.get_model(), &final_output_path)
    }
  }

  /// Implements only the required trait methods, so every optional hook falls back to
  /// the trait's default implementation.
  struct DefaultHooksHandler {
    inner: RecordingHandler,
  }

  impl TrainingHandler for DefaultHooksHandler {
    fn get_config(&self) -> &TrainConfig {
      &self.inner.config
    }

    fn get_model(&mut self) -> &mut PredictiveCodingModel {
      &mut self.inner.model
    }

    fn get_data(&self) -> &dyn data_handler::TrainingDataset {
      self.inner.data.as_ref()
    }

    fn get_file_output_prefix(&self) -> &String {
      &self.inner.file_output_prefix
    }

    fn train_step(&mut self, step: u32) -> Result<()> {
      self.inner.events.push(format!("train_step:{step}"));
      Ok(())
    }
  }

  fn test_config(training_steps: u32, report_interval: u32, snapshot_interval: u32) -> TrainConfig {
    TrainConfig {
      model_source: ModelSource::Config(String::from("unused.json")),
      dataset: DataSetSource::MNIST {
        input_idx_file: String::from("unused-images.idx"),
        output_idx_file: String::from("unused-labels.idx"),
      },
      training_strategy: TrainingStrategy::SingleThread,
      training_steps,
      report_interval,
      snapshot_interval,
    }
  }

  #[test]
  fn training_loop_preserves_hook_order_and_report_interval() {
    let temp_dir: TempDir = TempDir::new("training_loop_hooks");
    let output_prefix: String = temp_dir.path().join("model").display().to_string();
    let mut handler: RecordingHandler = RecordingHandler::new(test_config(3, 2, 0), output_prefix);

    run_supervised_training_loop(&mut handler).unwrap();

    assert_eq!(
      handler.events,
      vec![
        String::from("pre_training"),
        String::from("pre_step:0"),
        String::from("train_step:0"),
        String::from("post_step:0"),
        String::from("report:0"),
        String::from("pre_step:1"),
        String::from("train_step:1"),
        String::from("post_step:1"),
        String::from("pre_step:2"),
        String::from("train_step:2"),
        String::from("post_step:2"),
        String::from("report:2"),
        String::from("post_training"),
      ]
    );
  }

  #[test]
  fn training_loop_saves_snapshots_at_configured_steps() {
    let temp_dir = TempDir::new("training_loop_snapshots");
    let output_prefix = temp_dir.path().join("model").display().to_string();
    let mut handler = RecordingHandler::new(test_config(5, 0, 2), output_prefix);

    run_supervised_training_loop(&mut handler).unwrap();

    // Steps run 0..5 with snapshot_interval 2, so only steps 0, 2 and 4 are due.
    for (step, expected) in [(0, true), (1, false), (2, true), (3, false), (4, true)] {
      let snapshot = temp_dir.path().join(format!("model_snapshot_step_{step}.json"));
      assert_eq!(snapshot.exists(), expected, "unexpected snapshot state for step {step}");
    }
    assert!(temp_dir.path().join("model_final_model.json").exists());

    // report_interval is 0, so the report hook must never fire.
    assert!(!handler.events.iter().any(|event| event.starts_with("report:")));
  }

  #[test]
  fn training_loop_default_hooks_write_config_and_final_model() {
    let temp_dir = TempDir::new("training_loop_defaults");
    let output_prefix = temp_dir.path().join("model").display().to_string();
    let mut handler = DefaultHooksHandler {
      inner: RecordingHandler::new(test_config(2, 1, 0), output_prefix),
    };

    run_supervised_training_loop(&mut handler).unwrap();

    // Default hooks record nothing, so only the training steps show up.
    assert_eq!(
      handler.inner.events,
      vec![String::from("train_step:0"), String::from("train_step:1")]
    );
    assert!(temp_dir.path().join("model_model_config.json").exists());
    assert!(temp_dir.path().join("model_training_config.json").exists());
    // The default post_training_hook writes "{prefix}_final_model.json".
    assert!(temp_dir.path().join("model_final_model.json").exists());
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc