• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 22806511360

07 Mar 2026 08:20PM UTC coverage: 92.679%. First build
22806511360

push

github

web-flow
Add unit test coverage, and smoke tests.  (#6)

* Ask GPT 5.4 for an architecture review. Some of it might seem a bit overbuilt to me, but I'll still consider the advice in absence of human reviewers

* End to end tests for the benchmark, train, and evaluate binaries

* Unit test coverage

* Set up debugger properly for the three binaries

* Run unit tests as a workflow

* use more advanced rust install action

* Try and upload coverage to coveralls

* Use job correctly

414 of 428 new or added lines in 7 files covered. (96.73%)

1152 of 1243 relevant lines covered (92.68%)

10.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.01
/src/training/train_handler.rs
1
use tracing::info;
2

3
use crate::{
4
  data_handling::data_handler,
5
  model_structure::{
6
    model::PredictiveCodingModel,
7
    model_utils::save_model_snapshot
8
  },
9
  training::utils::{TrainConfig}
10
};
11

12
pub trait TrainingHandler {
13
  fn get_config(&self) -> &TrainConfig;
14
  fn get_model(&mut self) -> &mut PredictiveCodingModel;
15
  fn get_data(&self) -> &data_handler::TrainingDataset;
16
  fn get_file_output_prefix(&self) -> &String;
17

18
  fn pre_training_hook(&mut self);
19

20
  fn train_step(&mut self, _step: u32);
21
  fn report_hook(&mut self, _step: u32);
22

23
  /// Once the model has completed training, this will be called. By default, it saves the final model to "{file_output_prefix}_final.json"
24
  fn post_training_hook(&mut self) {
4✔
25
    let final_output_path: String = format!("{}_{}", self.get_file_output_prefix(), "final_model.json");
4✔
26

27
    info!("Finished training, saving final model to {}", final_output_path);
4✔
28
    save_model_snapshot(
4✔
29
      self.get_model(),
4✔
30
      &final_output_path
4✔
31
    )
32
  }
4✔
33

34
  // Any actions that need to be called with an awareness of the training step can use these hooks. By default, they do nothing.
35
  // e.g. if you want to anneal the learning rate, that would go here
36
  fn pre_step_hook(&mut self, _step: u32) {}
8✔
37
  fn post_step_hook(&mut self, _step: u32) {}
8✔
38
}
39

40

41
pub fn run_supervised_training_loop(handler: &mut dyn TrainingHandler) {
4✔
42
  handler.pre_training_hook();
4✔
43

44
  // Supervised learning
45
  handler.get_model().pin_input();
4✔
46
  handler.get_model().pin_output();
4✔
47

48
  let training_config: &TrainConfig = handler.get_config();
4✔
49
  let training_steps: u32 = training_config.training_steps;
4✔
50
  let report_interval: u32 = training_config.report_interval;
4✔
51
  let snapshot_interval: u32 = training_config.snapshot_interval;
4✔
52

53
  for step in 0..training_steps {
12✔
54
    handler.pre_step_hook(step);
12✔
55
    handler.train_step(step);
12✔
56
    handler.post_step_hook(step);
12✔
57

58
    if (report_interval > 0) && (step % report_interval == 0) {
12✔
59
      handler.report_hook(step);
6✔
60
    }
6✔
61

62
    if (snapshot_interval > 0) && (step % snapshot_interval == 0) {
12✔
63
      let oname: String = format!("{}_snapshot_step_{}.json", handler.get_file_output_prefix(), step);
7✔
64
      info!("Saving model snapshot {}", oname);
7✔
65

66
      save_model_snapshot(handler.get_model(), &oname);
7✔
67
    }
5✔
68
  }
69

70
  handler.post_training_hook();
4✔
71
}
4✔
72

73
#[cfg(test)]
74
mod tests {
75
  use super::*;
76

77
  use crate::{
78
    model_structure::{
79
      model::PredictiveCodingModelConfig,
80
      model_utils::ActivationFunction
81
    },
82
    training::utils::{
83
      DataSetSource,
84
      ModelSource,
85
      TrainingStrategy
86
    }
87
  };
88
  use ndarray::Array2;
89
  use std::{
90
    fs,
91
    path::PathBuf,
92
    time::{SystemTime, UNIX_EPOCH}
93
  };
94

95
  /// Scratch directory under the system temp dir, removed again when the value is dropped.
  struct TempDir {
    path: PathBuf,
  }

  impl TempDir {
    /// Creates a new directory named
    /// `predictive_coding_{prefix}_{pid}_{nanos}_{sequence}` inside the system temp dir.
    ///
    /// Panics if the system clock is before the Unix epoch or the directory cannot be
    /// created.
    fn new(prefix: &str) -> Self {
      // The timestamp alone cannot guarantee uniqueness: two calls with the same prefix
      // inside one clock tick would share a directory, and whichever guard dropped first
      // would delete the other's files. A process-wide counter rules that out.
      static NEXT_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
      // Relaxed ordering is enough: only the distinctness of the returned values matters.
      let sequence = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

      let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_nanos();
      let path = std::env::temp_dir().join(format!(
        "predictive_coding_{prefix}_{}_{timestamp}_{sequence}",
        std::process::id()
      ));
      fs::create_dir_all(&path).expect("failed to create temporary test directory");
      TempDir { path }
    }

    /// Returns the location of the directory.
    fn path(&self) -> &PathBuf {
      &self.path
    }
  }

  impl Drop for TempDir {
    /// Deletes the directory and everything in it.
    fn drop(&mut self) {
      // Best-effort cleanup: a failure here must not panic while a test is unwinding.
      let _ = fs::remove_dir_all(&self.path);
    }
  }
124

125
  struct RecordingHandler {
126
    config: TrainConfig,
127
    model: PredictiveCodingModel,
128
    data: data_handler::TrainingDataset,
129
    file_output_prefix: String,
130
    events: Vec<String>,
131
  }
132

133
  impl RecordingHandler {
134
    fn new(config: TrainConfig, output_prefix: String) -> Self {
2✔
135
      let model: PredictiveCodingModel = PredictiveCodingModel::new(&PredictiveCodingModelConfig {
2✔
136
        layer_sizes: vec![4, 10],
2✔
137
        alpha: 0.01,
2✔
138
        gamma: 0.05,
2✔
139
        convergence_threshold: 0.0,
2✔
140
        convergence_steps: 1,
2✔
141
        activation_function: ActivationFunction::Relu,
2✔
142
      });
2✔
143
      let data: data_handler::TrainingDataset = data_handler::TrainingDataset {
2✔
144
        dataset_size: 1,
2✔
145
        input_size: 4,
2✔
146
        output_size: 10,
2✔
147
        inputs: Array2::zeros((1, 4)),
2✔
148
        labels: Array2::zeros((1, 10)),
2✔
149
      };
2✔
150

151
      RecordingHandler {
2✔
152
        config,
2✔
153
        model,
2✔
154
        data,
2✔
155
        file_output_prefix: output_prefix,
2✔
156
        events: Vec::new(),
2✔
157
      }
2✔
158
    }
2✔
159
  }
160

161
  // Construct a dummy handler which just records what it's done, in what order.
162
  impl TrainingHandler for RecordingHandler {
163
    fn get_config(&self) -> &TrainConfig {
2✔
164
      &self.config
2✔
165
    }
2✔
166

167
    fn get_model(&mut self) -> &mut PredictiveCodingModel {
9✔
168
      &mut self.model
9✔
169
    }
9✔
170

NEW
171
    fn get_data(&self) -> &data_handler::TrainingDataset {
×
NEW
172
      &self.data
×
NEW
173
    }
×
174

175
    fn get_file_output_prefix(&self) -> &String {
5✔
176
      &self.file_output_prefix
5✔
177
    }
5✔
178

179
    fn pre_training_hook(&mut self) {
2✔
180
      self.events.push(String::from("pre_training"));
2✔
181
    }
2✔
182

183
    fn train_step(&mut self, step: u32) {
8✔
184
      self.events.push(format!("train_step:{step}"));
8✔
185
    }
8✔
186

187
    fn report_hook(&mut self, step: u32) {
2✔
188
      self.events.push(format!("report:{step}"));
2✔
189
    }
2✔
190

191
    fn pre_step_hook(&mut self, step: u32) {
8✔
192
      self.events.push(format!("pre_step:{step}"));
8✔
193
    }
8✔
194

195
    fn post_step_hook(&mut self, step: u32) {
8✔
196
      self.events.push(format!("post_step:{step}"));
8✔
197
    }
8✔
198

199
    fn post_training_hook(&mut self) {
2✔
200
      self.events.push(String::from("post_training"));
2✔
201
      let final_output_path = format!("{}_final_model.json", self.get_file_output_prefix());
2✔
202
      save_model_snapshot(self.get_model(), &final_output_path);
2✔
203
    }
2✔
204
  }
205

206
  fn test_config(training_steps: u32, report_interval: u32, snapshot_interval: u32) -> TrainConfig {
2✔
207
    TrainConfig {
2✔
208
      model_source: ModelSource::Config(String::from("unused.json")),
2✔
209
      dataset: DataSetSource::MNIST {
2✔
210
        input_idx_file: String::from("unused-images.idx"),
2✔
211
        output_idx_file: String::from("unused-labels.idx"),
2✔
212
      },
2✔
213
      training_strategy: TrainingStrategy::SingleThread,
2✔
214
      training_steps,
2✔
215
      report_interval,
2✔
216
      snapshot_interval,
2✔
217
    }
2✔
218
  }
2✔
219

220
  #[test]
  fn training_loop_preserves_hook_order_and_report_interval() {
    let scratch = TempDir::new("training_loop_hooks");
    let prefix = scratch.path().join("model").display().to_string();
    // 3 steps, report every 2nd step, snapshots disabled.
    let mut handler = RecordingHandler::new(test_config(3, 2, 0), prefix);

    run_supervised_training_loop(&mut handler);

    // Reports are expected after steps 0 and 2 only.
    let expected: Vec<String> = [
      "pre_training",
      "pre_step:0",
      "train_step:0",
      "post_step:0",
      "report:0",
      "pre_step:1",
      "train_step:1",
      "post_step:1",
      "pre_step:2",
      "train_step:2",
      "post_step:2",
      "report:2",
      "post_training",
    ]
    .iter()
    .map(|event| event.to_string())
    .collect();

    assert_eq!(handler.events, expected);
  }
1✔
247

248
  #[test]
  fn training_loop_saves_snapshots_at_configured_steps() {
    let scratch = TempDir::new("training_loop_snapshots");
    let prefix = scratch.path().join("model").display().to_string();
    // 5 steps, reporting disabled, snapshot every 2nd step.
    let mut handler = RecordingHandler::new(test_config(5, 0, 2), prefix);

    run_supervised_training_loop(&mut handler);

    let dir = scratch.path();

    // Snapshots for steps 0, 2 and 4, plus the final model, must have been written.
    let written = [
      "model_snapshot_step_0.json",
      "model_snapshot_step_2.json",
      "model_snapshot_step_4.json",
      "model_final_model.json",
    ];
    for file_name in written {
      assert!(dir.join(file_name).exists());
    }

    // Step 1 is not on the snapshot interval.
    assert!(!dir.join("model_snapshot_step_1.json").exists());
  }
1✔
262
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc