• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 22806481110

07 Mar 2026 08:18PM UTC coverage: 92.84%. First build
22806481110

Pull #6

github

web-flow
Merge 87ce3c80b into 7e8f877af
Pull Request #6: Add unit test coverage, and smoke tests.

414 of 428 new or added lines in 7 files covered. (96.73%)

1154 of 1243 relevant lines covered (92.84%)

10.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.17
/src/model_structure/model_utils.rs
1
//! Math utilities for predictive coding models.
2

3
use crate::{
4
  data_handling::data_handler::TrainingDataset,
5
  model_structure::model::{
6
    PredictiveCodingModel,
7
    PredictiveCodingModelConfig
8
  }
9
};
10

11
use std::path::Path;
12

13
use ndarray::{Array1, Array2, ArrayBase, Data, Dimension};
14
use serde::{Deserialize, Serialize};
15

16
/// Choose a random index using the rng threadlocal generator, and set the model I/O accordingly.
17
pub fn set_rand_input_and_output(
27✔
18
  model: &mut PredictiveCodingModel,
27✔
19
  data: &TrainingDataset
27✔
20
) {
27✔
21
  let rand_index: usize = usize::from_ne_bytes(rand::random()) % data.dataset_size;
27✔
22

23
  // Normalise to the range 0..1
24
  let input_values: Array1<f32> = data.inputs
27✔
25
    .row(rand_index)
27✔
26
    .to_owned();
27✔
27

28
  // One-hot output row with label value set to 1.0
29
  let output_values: Array1<f32> = data.labels
27✔
30
    .row(rand_index)
27✔
31
    .to_owned();
27✔
32

33
  model.set_input(input_values);
27✔
34
  model.set_output(output_values);
27✔
35

36
}
27✔
37

38
pub fn load_model_config(fname: &str) -> PredictiveCodingModelConfig {
2✔
39
  serde_json::from_reader(
2✔
40
    std::fs::File::open(fname).unwrap()
2✔
41
  ).unwrap()
2✔
42
}
2✔
43

NEW
44
pub fn create_from_config(fname: &str) -> PredictiveCodingModel {
×
NEW
45
  let config = load_model_config(fname);
×
46
  PredictiveCodingModel::new(&config)
×
47
}
×
48

49
pub fn save_model_config(
4✔
50
  config: &PredictiveCodingModelConfig,
4✔
51
  filename: &str
4✔
52
) {
4✔
53
  if let Some(parent) = Path::new(filename).parent()
4✔
54
    && !parent.as_os_str().is_empty() {
4✔
55
      std::fs::create_dir_all(parent).unwrap();
4✔
56
    }
4✔
57
  let config_ser = serde_json::to_string(config).unwrap();
4✔
58
  std::fs::write(filename, config_ser).unwrap();
4✔
59
}
4✔
60

61
pub fn save_model_snapshot(
14✔
62
  model: &PredictiveCodingModel,
14✔
63
  filename: &str
14✔
64
) {
14✔
65
  if let Some(parent) = Path::new(filename).parent()
14✔
66
    && !parent.as_os_str().is_empty() {
14✔
67
      std::fs::create_dir_all(parent).unwrap();
14✔
68
    }
14✔
69
  let model_ser = serde_json::to_string(&model).unwrap();
14✔
70
  std::fs::write(filename, model_ser).unwrap();
14✔
71
}
14✔
72

73
pub fn load_model_snapshot(filename: &str) -> PredictiveCodingModel {
8✔
74
  let model_ser = std::fs::read_to_string(filename).unwrap();
8✔
75
  serde_json::from_str(&model_ser).unwrap()
8✔
76
}
8✔
77

78
/// Activation function identifiers for serialization-friendly models.
79
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
80
pub enum ActivationFunction {
81
  Relu,
82
  Sigmoid,
83
  Tanh
84
}
85

86
impl ActivationFunction {
87
  /// Apply the activation function.
88
  pub fn apply(&self, x: f32) -> f32 {
116✔
89
    match self {
116✔
90
      ActivationFunction::Relu => relu(x),
116✔
91
      ActivationFunction::Sigmoid => sigmoid(x),
×
92
      ActivationFunction::Tanh => tanh(x),
×
93
    }
94
  }
116✔
95

96
  /// Apply the activation function derivative.
97
  pub fn derivative(&self, x: f32) -> f32 {
184✔
98
    match self {
184✔
99
      ActivationFunction::Relu => relu_derivitive(x),
184✔
100
      ActivationFunction::Sigmoid => sigmoid_derivitive(x),
×
101
      ActivationFunction::Tanh => tanh_derivative(x),
×
102
    }
103
  }
184✔
104
}
105

106

107
/// Compute the outer product of two arbitrary-dimensional arrays flattened
108
/// in iteration order.
109
///
110
/// Returns a matrix with shape `(a.len(), b.len())` where each element is
111
/// `a[i] * b[j]`.
112
pub fn outer_product<SA, DA, SB, DB>(
45✔
113
  a: &ArrayBase<SA, DA>,
45✔
114
  b: &ArrayBase<SB, DB>
45✔
115
) -> Array2<f32>
45✔
116
where
45✔
117
  SA: Data<Elem = f32>,
45✔
118
  SB: Data<Elem = f32>,
45✔
119
  DA: Dimension,
45✔
120
  DB: Dimension,
45✔
121
{
122
  let a_values: Vec<f32> = a.iter().copied().collect();
45✔
123
  let b_values: Vec<f32> = b.iter().copied().collect();
45✔
124
  let rows = a_values.len();
45✔
125
  let cols = b_values.len();
45✔
126

127
  Array2::from_shape_fn((rows, cols), |(i, j)| a_values[i] * b_values[j])
1,800✔
128
}
45✔
129

130
/// Rectified linear unit: strictly positive inputs pass through unchanged;
/// zero, negative and NaN inputs map to `0.0`.
pub fn relu(x: f32) -> f32 {
  let is_active = x > 0.0;
  if !is_active {
    return 0.0;
  }
  x
}
138

139
/// Derivative of the ReLU: `1.0` for strictly positive inputs, `0.0`
/// otherwise.
///
/// The kink at `x == 0` is assigned `0.0`, and NaN inputs also yield `0.0`.
pub fn relu_derivitive(x: f32) -> f32 {
  f32::from(u8::from(x > 0.0))
}
146

147
/// Logistic sigmoid `1 / (1 + e^-x)`.
///
/// Output lies in `[0, 1]`; in `f32` it saturates to exactly `0.0` or `1.0`
/// for inputs of large magnitude.
pub fn sigmoid(x: f32) -> f32 {
  let denominator = 1.0 + (-x).exp();
  denominator.recip()
}
151

152
/// Derivative of the logistic sigmoid, `σ'(x) = e^-x / (1 + e^-x)^2`.
///
/// `σ'` is an even function, so it is evaluated as `e^-|x| / (1 + e^-|x|)^2`.
/// Keeping the exponent non-positive avoids overflow: the naive form turns
/// `e^-x` into +inf for `x` below roughly -88.7 in `f32` and returns
/// `inf / inf = NaN` instead of a value near zero.
pub fn sigmoid_derivitive(x: f32) -> f32 {
  let decay = (-x.abs()).exp();
  decay / (1.0 + decay).powi(2)
}
155

156
/// Hyperbolic tangent activation; a thin wrapper over [`f32::tanh`].
pub fn tanh(x: f32) -> f32 {
  f32::tanh(x)
}
160

161
/// Derivative of tanh: `sech^2(x) = 4 / (e^x + e^-x)^2`.
///
/// For large `|x|` the denominator overflows to infinity, so the result
/// becomes `0.0` rather than NaN.
pub fn tanh_derivative(x: f32) -> f32 {
  let cosh_doubled = x.exp() + (-x).exp();
  4.0 / cosh_doubled.powi(2)
}
164

165
#[cfg(test)]
mod tests {
  use super::*;

  use ndarray::array;
  use std::{
    fs,
    path::{Path, PathBuf},
    time::{SystemTime, UNIX_EPOCH},
  };

  /// Scratch directory under the system temp dir, removed (best effort) on drop.
  struct TempDir {
    path: PathBuf,
  }

  impl TempDir {
    /// Create a directory whose name combines `prefix`, the process id and the
    /// current time in nanoseconds, so separate tests don't share a directory.
    fn new(prefix: &str) -> Self {
      let unique_id = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos();
      let dir_name = format!(
        "predictive_coding_{prefix}_{}_{}",
        std::process::id(),
        unique_id
      );
      let path = std::env::temp_dir().join(dir_name);
      fs::create_dir_all(&path).unwrap();
      Self { path }
    }

    /// Path of `filename` inside this temporary directory.
    fn join(&self, filename: &str) -> PathBuf {
      self.path.join(filename)
    }
  }

  impl Drop for TempDir {
    fn drop(&mut self) {
      // Cleanup is best effort; a leftover temp dir must not fail a test.
      let _ = fs::remove_dir_all(&self.path);
    }
  }

  /// Write `contents` to `path`, panicking on any I/O error.
  fn write_file(path: &Path, contents: &str) {
    fs::write(path, contents).unwrap();
  }

  /// Assert that `actual` lies strictly within `tol` of `expected`.
  fn assert_within_tol(expected: f32, actual: f32, tol: f32) {
    let diff = (expected - actual).abs();
    assert!(
      diff < tol,
      "Expected {}, got {}, which is outside the tolerance of {}",
      expected,
      actual,
      tol
    );
  }

  #[test]
  fn test_sigmoid() {
    for (x, expected) in [(0.5, 0.622459), (-0.5, 0.377541)] {
      assert_within_tol(expected, sigmoid(x), 1e-6);
    }
  }

  #[test]
  fn test_sigmoid_derivative() {
    // The derivative is symmetric about zero, so both points share a value.
    for (x, expected) in [(0.5, 0.235004), (-0.5, 0.235004)] {
      assert_within_tol(expected, sigmoid_derivitive(x), 1e-6);
    }
  }

  #[test]
  fn load_model_config_parses_expected_json_shape() {
    const CONFIG_JSON: &str = r#"{
  "layer_sizes": [4, 10],
  "alpha": 0.01,
  "gamma": 0.05,
  "convergence_threshold": 0.0,
  "convergence_steps": 2,
  "activation_function": "Tanh"
}"#;

    let temp_dir = TempDir::new("model_config_parse");
    let config_path = temp_dir.join("model_config.json");
    write_file(&config_path, CONFIG_JSON);

    let expected = PredictiveCodingModelConfig {
      layer_sizes: vec![4, 10],
      alpha: 0.01,
      gamma: 0.05,
      convergence_threshold: 0.0,
      convergence_steps: 2,
      activation_function: ActivationFunction::Tanh,
    };
    let actual = load_model_config(config_path.to_str().unwrap());

    assert_eq!(actual, expected);
  }

  #[test]
  fn load_model_config_rejects_missing_required_fields() {
    // Same shape as a valid config, minus "activation_function".
    const CONFIG_JSON: &str = r#"{
  "layer_sizes": [4, 10],
  "alpha": 0.01,
  "gamma": 0.05,
  "convergence_threshold": 0.0,
  "convergence_steps": 2
}"#;

    let temp_dir = TempDir::new("model_config_missing_field");
    let config_path = temp_dir.join("model_config.json");
    write_file(&config_path, CONFIG_JSON);

    let outcome = std::panic::catch_unwind(|| {
      load_model_config(config_path.to_str().unwrap())
    });

    assert!(outcome.is_err());
  }

  #[test]
  fn model_snapshot_round_trips_through_disk() {
    let temp_dir = TempDir::new("snapshot_round_trip");
    let snapshot_path = temp_dir.join("model_snapshot.json");

    let config = PredictiveCodingModelConfig {
      layer_sizes: vec![4, 10],
      alpha: 0.01,
      gamma: 0.05,
      convergence_threshold: 0.0,
      convergence_steps: 2,
      activation_function: ActivationFunction::Relu,
    };
    let mut model = PredictiveCodingModel::new(&config);
    model.set_input(array![1.0, 0.0, 0.5, 0.25]);
    model.set_output(array![0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    model.compute_predictions_and_errors();

    let snapshot_str = snapshot_path.to_str().unwrap();
    save_model_snapshot(&model, snapshot_str);
    let loaded_model = load_model_snapshot(snapshot_str);

    // Compare through JSON values; presumably the model type does not
    // implement PartialEq itself.
    let original_json = serde_json::to_value(&model).unwrap();
    let loaded_json = serde_json::to_value(&loaded_model).unwrap();
    assert_eq!(loaded_json, original_json);
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc