• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 22830269200

08 Mar 2026 09:28PM UTC coverage: 90.452% (+3.3%) from 87.11%
22830269200

Pull #10

github

web-flow
Merge 3abad3f15 into 5f8ad5902
Pull Request #10: Generalise the data boundary

457 of 502 new or added lines in 11 files covered. (91.04%)

1 existing line in 1 file now uncovered.

1601 of 1770 relevant lines covered (90.45%)

15.35 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.35
/src/evaluate.rs
1
//! This program takes a trained model and evaluates it against the MNIST test dataset, reporting its accuracy, mean convergence time, and mean confidence when correct. It also saves these results to a file in the same directory as the model.
2

3
use std::{path::Path, time::Instant};
4
use ndarray::Array1;
5
use serde::{Deserialize, Serialize};
6
use tracing::info;
7

8
use predictive_coding::{
9
  data_handling::{data_handler::TrainingDataset, mnist::{MnistDataset, load_mnist}},
10
  error::{PredictiveCodingError, Result},
11
  model_structure::{configuration::load_model_snapshot, model::PredictiveCodingModel},
12
  utils::logging
13
};
14

15
use clap::Parser;
16

17
/// Evaluate a trained model against the MNIST test dataset, and report its accuracy.
18
#[derive(Parser)]
19
struct EvalArgs {
20
  /// The model file to evaluate
21
  #[arg()]
22
  model_file: String,
23

24
  /// IDX image file to evaluate against.
25
  #[arg(long, default_value_t = String::from("data/mnist/t10k-images-idx3-ubyte"))]
26
  input_idx_file: String,
27

28
  /// IDX label file to evaluate against.
29
  #[arg(long, default_value_t = String::from("data/mnist/t10k-labels-idx1-ubyte"))]
30
  output_idx_file: String,
31
}
32

33
#[derive(Clone, Copy, Deserialize, Serialize, Debug, PartialEq)]
34
struct EvaluationSummary {
35
  accuracy: f32,
36
  mean_convergence_time_ms: f32,
37
  mean_confidence_when_correct: f32,
38
  correct_predictions: usize,
39
  total_predictions: usize,
40
}
41

42
fn summarise_evaluation_samples<F>(dataset_size: usize, mut evaluate_sample: F) -> Result<EvaluationSummary>
4✔
43
where
4✔
44
  F: FnMut(usize) -> Result<(usize, usize, f32, f32)>,
4✔
45
{
46
  if dataset_size == 0 {
4✔
47
    return Err(PredictiveCodingError::invalid_data("evaluation dataset is empty"));
1✔
48
  }
3✔
49

50
  let mut correct_predictions: usize = 0;
3✔
51
  let mut total_predictions: usize = 0;
3✔
52
  let mut confidence_sum: f32 = 0.0;
3✔
53
  let mut sum_covnvergence_time: f32 = 0.0;
3✔
54

55
  for i in 0..dataset_size {
1,004✔
56
    let (output_label, predicted_label, predicted_confidence, elapsed_time_ms) = evaluate_sample(i)?;
1,004✔
57

58
    if (i > 0) && (i % 1000 == 0) {
1,004✔
59
      let accuracy_percent = correct_predictions as f32 / total_predictions as f32 * 100.0;
1✔
60
      info!(
1✔
61
        "Current accuracy after {} samples: {:.2}%",
62
        i,
63
        accuracy_percent
64
      );
65
    }
1,003✔
66

67
    if predicted_label == output_label {
1,004✔
68
      correct_predictions += 1;
1,001✔
69
      confidence_sum += predicted_confidence;
1,001✔
70
    }
1,001✔
71
    total_predictions += 1;
1,004✔
72
    sum_covnvergence_time += elapsed_time_ms;
1,004✔
73
  }
74

75
  let accuracy: f32 = correct_predictions as f32 / total_predictions as f32;
3✔
76
  let mean_convergence_time: f32 = sum_covnvergence_time / total_predictions as f32;
3✔
77
  let mean_confidence: f32 = if correct_predictions > 0 {
3✔
78
    confidence_sum / correct_predictions as f32
1✔
79
  } else {
80
    0.0
2✔
81
  };
82

83
  Ok(EvaluationSummary {
3✔
84
    accuracy,
3✔
85
    mean_convergence_time_ms: mean_convergence_time,
3✔
86
    mean_confidence_when_correct: mean_confidence,
3✔
87
    correct_predictions,
3✔
88
    total_predictions,
3✔
89
  })
3✔
90
}
4✔
91

92

93
fn evaluate_sample(model: &mut PredictiveCodingModel, data: &dyn TrainingDataset, i: usize) -> Result<(usize, usize, f32, f32)> {
1✔
94
    let input_values: Array1<f32> = data.get_input(i);
1✔
95
    let output_values: Array1<f32> = data.get_output(i);
1✔
96

97
    let output_label: usize = output_values
1✔
98
      .iter()
1✔
99
      .enumerate()
1✔
100
      .max_by(|a, b| a.1.total_cmp(b.1))
9✔
101
      .map(|(index, _)| index)
1✔
102
      .ok_or_else(|| PredictiveCodingError::invalid_data("dataset produced an empty output label"))?;
1✔
103

104
    model.reinitialise_latents();
1✔
105
    model.set_input(input_values);
1✔
106
    let start_time = Instant::now();
1✔
107
    model.converge_values();
1✔
108
    let elapsed_time = start_time.elapsed();
1✔
109

110
    let output_activations: &Array1<f32> = model.get_output();
1✔
111
    let predicted_label: usize = output_activations
1✔
112
      .iter()
1✔
113
      .enumerate()
1✔
114
      .max_by(|a, b| a.1.total_cmp(b.1))
9✔
115
      .map(|(index, _)| index)
1✔
116
      .ok_or_else(|| PredictiveCodingError::invalid_data("model produced an empty output layer"))?;
1✔
117

118
    Ok((
1✔
119
      output_label,
1✔
120
      predicted_label,
1✔
121
      output_activations[predicted_label],
1✔
122
      elapsed_time.as_millis() as f32,
1✔
123
    ))
1✔
124
  }
1✔
125

126

127
fn main() -> Result<()> {
1✔
128
  logging::setup_tracing(false);
1✔
129

130
  let args = EvalArgs::parse();
1✔
131

132
  let mut model = load_model_snapshot(&args.model_file)?;
1✔
133
  info!("Loaded model from {}", args.model_file);
1✔
134

135

136
  let data: MnistDataset = load_mnist(
1✔
137
      &args.input_idx_file,
1✔
138
      &args.output_idx_file
1✔
NEW
139
  )?;
×
140
  info!(
1✔
141
    "Loaded the MNIST testing dataset. I have {} images",
142
    data.get_dataset_size()
1✔
143
  );
144

145
  // The output must be unpinned for evaluation
146
  // It was probably pinned during training, so just check
147
  model.unpin_output();
1✔
148
  let summary: EvaluationSummary = summarise_evaluation_samples(
1✔
149
    data.get_dataset_size(),
1✔
150
    |i| evaluate_sample(&mut model, &data, i)
1✔
NEW
151
  )?;
×
152

153
  info!(
1✔
154
    "Evaluation complete. Accuracy: {:.2}%, convergence time on average is {:.0}ms. When correct, average confidence: {:.3}",
155
    summary.accuracy * 100.0,
1✔
156
    summary.mean_convergence_time_ms,
157
    summary.mean_confidence_when_correct
158
  );
159
  // Write to a file in the model directory
160
  let output_dir = Path::new(&args.model_file)
1✔
161
    .parent()
1✔
162
    .filter(|path| !path.as_os_str().is_empty())
1✔
163
    .unwrap_or_else(|| Path::new("./evaluation_results")); // fallback
1✔
164

165
  let output_path = output_dir.join("evaluation_results.json");
1✔
166
  let output_file = std::fs::File::create(&output_path)
1✔
167
    .map_err(|source| PredictiveCodingError::io("create evaluation results", &output_path, source))?;
1✔
168
  serde_json::to_writer_pretty(
1✔
169
    output_file,
1✔
170
    &serde_json::json!({
1✔
171
      "summary": summary,
1✔
172
      "model_file": args.model_file,
1✔
173
      "config": model.get_config(),
1✔
174
    }),
1✔
175
  ).map_err(|source| PredictiveCodingError::json_serialize(&output_path, source))?;
1✔
176

177
  Ok(())
1✔
178
}
1✔
179

180
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn summarise_evaluation_samples_rejects_empty_datasets() {
    let result = summarise_evaluation_samples(0, |_| Ok((0, 0, 0.0, 0.0)));

    assert!(matches!(
      result,
      Err(PredictiveCodingError::InvalidData { message })
        if message == "evaluation dataset is empty"
    ));
  }

  #[test]
  fn summarise_evaluation_samples_accumulates_accuracy_confidence_and_time() {
    // 1001 samples also crosses the 1000-sample progress-logging checkpoint.
    let summary = summarise_evaluation_samples(1001, |_| Ok((1, 1, 0.8, 2.0))).unwrap();

    assert_eq!(summary.correct_predictions, 1001);
    assert_eq!(summary.total_predictions, 1001);
    assert_eq!(summary.accuracy, 1.0);
    assert!((summary.mean_confidence_when_correct - 0.8).abs() < 1e-5);
    assert_eq!(summary.mean_convergence_time_ms, 2.0);
  }

  #[test]
  fn summarise_evaluation_samples_returns_zero_confidence_when_all_predictions_are_wrong() {
    let summary = summarise_evaluation_samples(2, |i| Ok((i, i + 1, 0.9, 3.0))).unwrap();

    assert_eq!(summary.correct_predictions, 0);
    assert_eq!(summary.total_predictions, 2);
    assert_eq!(summary.accuracy, 0.0);
    assert_eq!(summary.mean_confidence_when_correct, 0.0);
    assert_eq!(summary.mean_convergence_time_ms, 3.0);
  }

  #[test]
  fn summarise_evaluation_samples_averages_confidence_over_correct_predictions_only() {
    // Three correct predictions at confidence 0.5, then one wrong prediction at 0.9.
    // The wrong prediction's confidence must not affect the mean; its time still counts.
    let summary = summarise_evaluation_samples(4, |i| {
      let predicted_label = if i < 3 { 0 } else { 1 };
      let confidence = if i < 3 { 0.5 } else { 0.9 };
      Ok((0, predicted_label, confidence, i as f32))
    })
    .unwrap();

    assert_eq!(summary.correct_predictions, 3);
    assert_eq!(summary.total_predictions, 4);
    assert_eq!(summary.accuracy, 0.75);
    assert_eq!(summary.mean_confidence_when_correct, 0.5);
    assert_eq!(summary.mean_convergence_time_ms, 1.5);
  }

  #[test]
  fn summarise_evaluation_samples_stops_at_first_sample_error() {
    let mut calls = 0;
    let result = summarise_evaluation_samples(5, |i| {
      calls += 1;
      if i == 2 {
        Err(PredictiveCodingError::invalid_data("sample failed"))
      } else {
        Ok((0, 0, 1.0, 1.0))
      }
    });

    // Samples 0, 1 and the failing sample 2 are evaluated; 3 and 4 are not.
    assert_eq!(calls, 3);
    assert!(matches!(
      result,
      Err(PredictiveCodingError::InvalidData { message })
        if message == "sample failed"
    ));
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc