• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 22809140016

07 Mar 2026 11:04PM UTC coverage: 87.11% (-6.0%) from 93.084%
22809140016

push

github

web-flow
Add an error handler and use Result<> (#8)

* Add an Error handler

* Wherever an operation can fail, return a Result, using the appropriate error variant from the new error types.

281 of 391 new or added lines in 10 files covered. (71.87%)

1 existing line in 1 file now uncovered.

1257 of 1443 relevant lines covered (87.11%)

9.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.15
/src/model_structure/model_utils.rs
1
//! Math utilities for predictive coding models.
2

3
use crate::{
4
  data_handling::data_handler::TrainingDataset,
5
  error::{PredictiveCodingError, Result},
6
  model_structure::model::{
7
    PredictiveCodingModel,
8
    PredictiveCodingModelConfig
9
  }
10
};
11

12
use std::path::Path;
13

14
use ndarray::{Array1, Array2, ArrayBase, Data, Dimension};
15
use serde::{Deserialize, Serialize};
16

17
/// Choose a random index using the rng threadlocal generator, and set the model I/O accordingly.
18
pub fn set_rand_input_and_output(
27✔
19
  model: &mut PredictiveCodingModel,
27✔
20
  data: &TrainingDataset
27✔
21
) {
27✔
22
  let rand_index: usize = usize::from_ne_bytes(rand::random()) % data.dataset_size;
27✔
23

24
  // Normalise to the range 0..1
25
  let input_values: Array1<f32> = data.inputs
27✔
26
    .row(rand_index)
27✔
27
    .to_owned();
27✔
28

29
  // One-hot output row with label value set to 1.0
30
  let output_values: Array1<f32> = data.labels
27✔
31
    .row(rand_index)
27✔
32
    .to_owned();
27✔
33

34
  model.set_input(input_values);
27✔
35
  model.set_output(output_values);
27✔
36
}
27✔
37

38
fn ensure_parent_dir(filename: &str) -> Result<()> {
20✔
39
  if let Some(parent) = Path::new(filename).parent() && !parent.as_os_str().is_empty() {
20✔
40
      std::fs::create_dir_all(parent)
20✔
41
        .map_err(|source| PredictiveCodingError::io("create directory", parent, source))?;
20✔
NEW
42
    }
×
43

44
  Ok(())
20✔
45
}
20✔
46

47
pub fn load_model_config(fname: &str) -> Result<PredictiveCodingModelConfig> {
3✔
48
  let file = std::fs::File::open(fname)
3✔
49
    .map_err(|source| PredictiveCodingError::io("open model config", fname, source))?;
3✔
50

51
  serde_json::from_reader(file)
2✔
52
    .map_err(|source| PredictiveCodingError::json_deserialize(fname, source))
2✔
53
}
3✔
54

NEW
55
pub fn create_from_config(fname: &str) -> Result<PredictiveCodingModel> {
×
NEW
56
  let config = load_model_config(fname)?;
×
NEW
57
  Ok(PredictiveCodingModel::new(&config))
×
UNCOV
58
}
×
59

60
pub fn save_model_config(
6✔
61
  config: &PredictiveCodingModelConfig,
6✔
62
  filename: &str
6✔
63
) -> Result<()> {
6✔
64
  ensure_parent_dir(filename)?;
6✔
65

66
  let config_ser = serde_json::to_string(config)
6✔
67
    .map_err(|source| PredictiveCodingError::json_serialize(filename, source))?;
6✔
68
  std::fs::write(filename, config_ser)
6✔
69
    .map_err(|source| PredictiveCodingError::io("write model config", filename, source))?;
6✔
70

71
  Ok(())
6✔
72
}
6✔
73

74
pub fn save_model_snapshot(
14✔
75
  model: &PredictiveCodingModel,
14✔
76
  filename: &str
14✔
77
) -> Result<()> {
14✔
78
  ensure_parent_dir(filename)?;
14✔
79

80
  let model_ser = serde_json::to_string(&model)
14✔
81
    .map_err(|source| PredictiveCodingError::json_serialize(filename, source))?;
14✔
82
  std::fs::write(filename, model_ser)
14✔
83
    .map_err(|source| PredictiveCodingError::io("write model snapshot", filename, source))?;
14✔
84

85
  Ok(())
14✔
86
}
14✔
87

88
pub fn load_model_snapshot(filename: &str) -> Result<PredictiveCodingModel> {
6✔
89
  let model_ser = std::fs::read_to_string(filename)
6✔
90
    .map_err(|source| PredictiveCodingError::io("read model snapshot", filename, source))?;
6✔
91

92
  serde_json::from_str(&model_ser)
6✔
93
    .map_err(|source| PredictiveCodingError::json_deserialize(filename, source))
6✔
94
}
6✔
95

96
/// Activation function identifiers for serialization-friendly models.
///
/// Serialized by variant name (e.g. `"activation_function": "Tanh"` in a JSON
/// model config). Evaluate with [`ActivationFunction::apply`] and
/// [`ActivationFunction::derivative`].
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ActivationFunction {
  /// Rectified linear unit: `x` for positive inputs, `0.0` otherwise (see [`relu`]).
  Relu,
  /// Logistic sigmoid `1 / (1 + e^-x)` (see [`sigmoid`]).
  Sigmoid,
  /// Hyperbolic tangent (see [`tanh`]).
  Tanh
}
103

104
impl ActivationFunction {
105
  /// Apply the activation function.
106
  pub fn apply(&self, x: f32) -> f32 {
116✔
107
    match self {
116✔
108
      ActivationFunction::Relu => relu(x),
116✔
109
      ActivationFunction::Sigmoid => sigmoid(x),
×
110
      ActivationFunction::Tanh => tanh(x),
×
111
    }
112
  }
116✔
113

114
  /// Apply the activation function derivative.
115
  pub fn derivative(&self, x: f32) -> f32 {
184✔
116
    match self {
184✔
117
      ActivationFunction::Relu => relu_derivitive(x),
184✔
118
      ActivationFunction::Sigmoid => sigmoid_derivitive(x),
×
119
      ActivationFunction::Tanh => tanh_derivative(x),
×
120
    }
121
  }
184✔
122
}
123

124

125
/// Compute the outer product of two arbitrary-dimensional arrays flattened
126
/// in iteration order.
127
///
128
/// Returns a matrix with shape `(a.len(), b.len())` where each element is
129
/// `a[i] * b[j]`.
130
pub fn outer_product<SA, DA, SB, DB>(
45✔
131
  a: &ArrayBase<SA, DA>,
45✔
132
  b: &ArrayBase<SB, DB>
45✔
133
) -> Array2<f32>
45✔
134
where
45✔
135
  SA: Data<Elem = f32>,
45✔
136
  SB: Data<Elem = f32>,
45✔
137
  DA: Dimension,
45✔
138
  DB: Dimension,
45✔
139
{
140
  let a_values: Vec<f32> = a.iter().copied().collect();
45✔
141
  let b_values: Vec<f32> = b.iter().copied().collect();
45✔
142
  let rows = a_values.len();
45✔
143
  let cols = b_values.len();
45✔
144

145
  Array2::from_shape_fn((rows, cols), |(i, j)| a_values[i] * b_values[j])
1,800✔
146
}
45✔
147

148
/// Rectified linear unit: returns `x` when it is strictly positive, otherwise
/// `0.0` (this includes `-0.0` and NaN inputs).
pub fn relu(x: f32) -> f32 {
  match x {
    positive if positive > 0.0 => positive,
    _ => 0.0,
  }
}
156

157
/// Derivative of [`relu`]: `1.0` for strictly positive `x`, `0.0` otherwise
/// (the kink at zero and NaN inputs both map to `0.0`).
pub fn relu_derivitive(x: f32) -> f32 {
  f32::from(u8::from(x > 0.0))
}
164

165
/// Logistic sigmoid `1 / (1 + e^-x)`, mapping any finite input into `[0, 1]`.
pub fn sigmoid(x:f32) -> f32 {
  let denominator = 1.0 + (-x).exp();
  denominator.recip()
}
169

170
/// Derivative of the logistic sigmoid, `e^-x / (1 + e^-x)^2`.
///
/// The derivative is an even function of `x`, so it is evaluated as
/// `e / (1 + e)^2` with `e = e^-|x|`. That keeps the exponent non-positive:
/// `e` lies in `[0, 1]` and never overflows. The naive form computes
/// `inf / inf = NaN` once `-x` exceeds ~88.7 (the `f32` `exp` overflow point).
pub fn sigmoid_derivitive(x: f32) -> f32 {
  let e = (-x.abs()).exp();
  e / (1.0 + e).powi(2)
}
173

174
/// Hyperbolic tangent activation; thin wrapper over [`f32::tanh`].
pub fn tanh(x: f32) -> f32 {
  f32::tanh(x)
}
178

179
/// Derivative of `tanh`: `sech^2(x) = 4 / (e^x + e^-x)^2`.
///
/// For large `|x|` the exponential sum overflows to infinity and the result
/// cleanly evaluates to `0.0`, the correct limit.
pub fn tanh_derivative(x: f32) -> f32 {
  let exp_sum = x.exp() + (-x).exp();
  4.0 / (exp_sum * exp_sum)
}
182

183
// Unit tests for the activation math and for the JSON config/snapshot
// load/save helpers. File-system tests run inside a per-test temp directory.
#[cfg(test)]
mod tests {
  use super::*;

  use ndarray::array;
  use std::{
    fs,
    path::{Path, PathBuf},
    time::{SystemTime, UNIX_EPOCH}
  };

  /// Scratch directory under `std::env::temp_dir()`, removed when dropped.
  struct TempDir {
    path: PathBuf,
  }

  impl TempDir {
    /// Create a fresh directory named from `prefix`, the process id and the
    /// current time in nanoseconds, so concurrent tests do not collide.
    fn new(prefix: &str) -> Self {
      let unique_id = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos();
      let path = std::env::temp_dir().join(format!(
        "predictive_coding_{prefix}_{}_{}",
        std::process::id(),
        unique_id
      ));
      fs::create_dir_all(&path).unwrap();
      TempDir { path }
    }

    /// Path of `filename` inside this temp directory (the file is not created).
    fn join(&self, filename: &str) -> PathBuf {
      self.path.join(filename)
    }
  }

  impl Drop for TempDir {
    // Best-effort cleanup: a failure to remove the directory is ignored so it
    // never masks the real test outcome.
    fn drop(&mut self) {
      let _ = fs::remove_dir_all(&self.path);
    }
  }

  /// Write `contents` to `path`, panicking on failure (test-only helper).
  fn write_file(path: &Path, contents: &str) {
    fs::write(path, contents).unwrap();
  }

  /// Assert `|expected - actual| < tol` (strict), with a readable message.
  fn assert_within_tol(expected: f32, actual: f32, tol: f32) {
    assert!(
      (expected - actual).abs() < tol,
      "Expected {}, got {}, which is outside the tolerance of {}",
      expected,
      actual,
      tol
    );
  }

  #[test]
  fn test_sigmoid() {
    // Reference values at +/-0.5; note sigmoid(x) + sigmoid(-x) == 1.
    let x = 0.5;
    let expected = 0.622459;
    let actual = sigmoid(x);
    assert_within_tol(expected, actual, 1e-6);

    let x = -0.5;
    let expected = 0.377541;
    let actual = sigmoid(x);
    assert_within_tol(expected, actual, 1e-6);
  }

  #[test]
  fn test_sigmoid_derivative() {
    // The sigmoid derivative is even, so +/-0.5 share one expected value.
    let x = 0.5;
    let expected = 0.235004;
    let actual = sigmoid_derivitive(x);
    assert_within_tol(expected, actual, 1e-6);

    let x = -0.5;
    let expected = 0.235004;
    let actual = sigmoid_derivitive(x);
    assert_within_tol(expected, actual, 1e-6);
  }

  #[test]
  fn load_model_config_parses_expected_json_shape() {
    // A config file with every field present must deserialize field-for-field.
    let temp_dir = TempDir::new("model_config_parse");
    let config_path = temp_dir.join("model_config.json");
    write_file(
      &config_path,
      r#"{
  "layer_sizes": [4, 10],
  "alpha": 0.01,
  "gamma": 0.05,
  "convergence_threshold": 0.0,
  "convergence_steps": 2,
  "activation_function": "Tanh"
}"#
    );

    let actual = load_model_config(config_path.to_str().unwrap()).unwrap();
    let expected = PredictiveCodingModelConfig {
      layer_sizes: vec![4, 10],
      alpha: 0.01,
      gamma: 0.05,
      convergence_threshold: 0.0,
      convergence_steps: 2,
      activation_function: ActivationFunction::Tanh,
    };

    assert_eq!(actual, expected);
  }

  #[test]
  fn load_model_config_rejects_missing_required_fields() {
    // Same shape as the parse test but without `activation_function`;
    // loading must fail rather than fall back to a default.
    let temp_dir = TempDir::new("model_config_missing_field");
    let config_path = temp_dir.join("model_config.json");
    write_file(
      &config_path,
      r#"{
  "layer_sizes": [4, 10],
  "alpha": 0.01,
  "gamma": 0.05,
  "convergence_threshold": 0.0,
  "convergence_steps": 2
}"#
    );

    let result = load_model_config(config_path.to_str().unwrap());

    assert!(result.is_err());
  }

  #[test]
  fn load_model_config_error_includes_path() {
    // The file is never written, so opening it fails; the error's Display
    // text must mention the offending path.
    let temp_dir = TempDir::new("missing_model_config");
    let config_path = temp_dir.join("missing_model_config.json");

    let error = load_model_config(config_path.to_str().unwrap()).unwrap_err();

    assert!(error.to_string().contains(&config_path.display().to_string()));
  }

  #[test]
  fn model_snapshot_round_trips_through_disk() {
    let temp_dir = TempDir::new("snapshot_round_trip");
    let snapshot_path = temp_dir.join("model_snapshot.json");
    let mut model = PredictiveCodingModel::new(&PredictiveCodingModelConfig {
      layer_sizes: vec![4, 10],
      alpha: 0.01,
      gamma: 0.05,
      convergence_threshold: 0.0,
      convergence_steps: 2,
      activation_function: ActivationFunction::Relu,
    });

    // Set I/O and run one prediction pass before saving; presumably this
    // populates prediction/error state so more than the freshly constructed
    // model is round-tripped (TODO confirm against PredictiveCodingModel).
    model.set_input(array![1.0, 0.0, 0.5, 0.25]);
    model.set_output(array![0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    model.compute_predictions_and_errors();

    save_model_snapshot(&model, snapshot_path.to_str().unwrap()).unwrap();
    let loaded_model = load_model_snapshot(snapshot_path.to_str().unwrap()).unwrap();

    // Compare through serde_json::Value rather than comparing the models
    // directly (NOTE(review): likely because the model type has no PartialEq).
    let original_json = serde_json::to_value(&model).unwrap();
    let loaded_json = serde_json::to_value(&loaded_model).unwrap();
    assert_eq!(loaded_json, original_json);
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc