wildjames / predictive_coding_rs / build 22819940947

08 Mar 2026 11:14AM UTC. Coverage: 87.11% (remained the same)

Pull Request #9: Split training utils out into validation and configuration files
Merge 27ab4fcf4 into 15dfb054d (github / web-flow)

277 of 298 new or added lines in 5 files covered (92.95%)
4 existing lines in 1 file now uncovered
1257 of 1443 relevant lines covered (87.11%)
9.46 hits per line

Source file: /src/model_structure/model.rs (80.77% covered)
//! Predictive coding model implementation.
//!
//! Defines a layered model with local prediction errors and weight updates.

use crate::model_structure::maths::{ActivationFunction, outer_product};

use serde::{Deserialize, Serialize};
use ndarray::{Array1, Array2};
use rand::RngExt;


/// A single predictive coding layer with values, predictions, errors, and weights.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Layer {
  pub values: Array1<f32>, // Node activation values for this layer, x^l
  pub predictions: Array1<f32>, // Predictions for the value of nodes in this layer, according to the layer above. u^l = f(x^{l+1}, w^{l+1})
  pub errors: Array1<f32>, // Errors for this layer, e^l
  pub weights: Array2<f32>, // Weights to predict the layer below, w^l
  pub pinned: bool, // If a layer is pinned, its values are not updated during time evolution (e.g. input layers in unsupervised learning, or input and output layers in supervised learning)
  pub activation_function: ActivationFunction,
  pub size: usize, // The number of nodes in this layer, for easy reference. Should be the same as values.len(), predictions.len(), and errors.len()
  xavier_limit: f32,
}

impl Layer {
  /// Initialises a layer of the given size.
  /// Populates the values if given, and pins the layer against changing the values during compute iterations if specified.
  /// If values are not given, they're set to random values between 0 and 1.
  /// Weights are randomly initialised, and predictions and errors are initialised to 0.0.
  /// Takes ownership of the given values, if they are given, so that we can update them in place later.
  fn new(
    size: usize,
    lower_size: Option<usize>,
    activation_function: ActivationFunction,
    values: Option<Array1<f32>>,
    pinned: Option<bool>
  ) -> Self {
    let mut rng = rand::rng();

    // Use provided values if we have them, otherwise random data 0..1
    let values: Array1<f32> = match values {
      Some(v) => v,
      None => Array1::from_shape_fn(size, |_| rng.random_range(0.0..1.0)),
    };

    // Generate random weights for a blank model layer.
    // Shape is (lower_size, size) to map from this layer to the one below.
    let weights_shape = match lower_size {
      Some(lower) => (lower, size),
      None => (0, size),
    };
    // Xavier initialization: U(-limit, limit) where limit = sqrt(6 / (fan_in + fan_out))
    let xavier_limit: f32 = if weights_shape.0 + weights_shape.1 > 0 {
      (6.0_f32 / (weights_shape.0 + weights_shape.1) as f32).sqrt()
    } else {
      1.0
    };
    let weights = Array2::from_shape_fn(weights_shape, |_| rng.random_range(-xavier_limit..xavier_limit));

    Layer {
      values,
      predictions: Array1::zeros(size),
      errors: Array1::zeros(size),
      weights,
      pinned: pinned.unwrap_or(false),
      activation_function,
      size,
      xavier_limit,
    }
  }

  /// Randomise weights between -xavier_limit and xavier_limit for all nodes in this layer.
  pub fn randomise_weights(&mut self) {
    let mut rng = rand::rng();
    self.weights = Array2::from_shape_fn(self.weights.dim(), |_| rng.random_range(-self.xavier_limit..self.xavier_limit));
  }

  /// Randomise values between 0..1 for all nodes in this layer.
  pub fn randomise_values(&mut self, rng: &mut rand::prelude::ThreadRng) {
    self.values = Array1::from_shape_fn(self.values.len(), |_| rng.random_range(0.0..1.0));
  }

  /// Replace the layer values and pin them to avoid updates during inference.
  fn pin_values(&mut self, values: Array1<f32>) {
    self.values = values;
    self.pinned = true;
  }

  /// Unpin the layer values to allow updates during inference.
  fn unpin_values(&mut self) {
    self.pinned = false;
  }

  /// Update the predictions for this layer based on the values of the layer above it.
  fn compute_predictions(&mut self, upper_layer: &Layer) {
    // Note that the prediction computation should *never* be run for an output layer, but making sure of this is the responsibility of the model, not the layer.
    // Besides, since an output layer has no upper layer to pass in, this function would not be callable.

    // u^l = phi(W^{l+1} * x^{l+1})
    // Preactivation first, then apply the nonlinearity
    let preactivation: Array1<f32> = upper_layer.weights.dot(&upper_layer.values);
    self.predictions = preactivation.mapv(|a| upper_layer.activation_function.apply(a));
  }

  /// Update the errors for this layer based on the predictions and values of this layer.
  fn compute_errors(&mut self) {
    self.errors = &self.values - &self.predictions;
  }

  /// Sum the signed error values for all nodes in this layer.
  fn read_total_error(&self) -> f32 {
    self.errors.iter().sum()
  }

  /// Sum the squared error values for all nodes in this layer.
  fn read_total_energy(&self) -> f32 {
    // Returns sum(err^2); the model applies the 1/2 factor in E = 1/2 * sum(err^2)
    self.errors.mapv(|x| x.powi(2)).iter().sum()
  }

  /// Compute the change in node values under a single timestep of PC.
  /// Returns the summed absolute change in node values across this layer.
  /// For the input layer, there is no lower layer and None should be passed in instead.
  fn values_timestep(&mut self, is_top_level: bool, gamma: f32, lower_layer: Option<&Layer>) -> f32 {
    if self.pinned {
      return 0.0
    }

    let rhs: Array1<f32> = if let Some(lower_layer) = lower_layer {
      // RHS: W^{l,T} * (phi'(a^{l-1}) (Hadamard) e^{l-1})
      // where a^{l-1} = W^l * x^l is the preactivation for the layer below (see 2506.06332)

      // a^{l-1} = W^l * x^l
      let preactivation: Array1<f32> = self.weights.dot(&self.values);

      // phi'(a^{l-1})
      let activation_function_derivative: Array1<f32> = preactivation.mapv(|a| self.activation_function.derivative(a));

      // phi'(a^{l-1}) (Hadamard) e^{l-1}
      let gain_modulated_errors: Array1<f32> = activation_function_derivative * &lower_layer.errors;

      // W^{l,T} * (phi'(a^{l-1}) (Hadamard) e^{l-1})
      self.weights.t().dot(&gain_modulated_errors)
    } else {
      Array1::zeros(self.values.len())
    };

    // Note that the top layer receives no predictions from above, so the error term in the parentheses is dropped from its update.
    let value_changes: Array1<f32> = if is_top_level {
      gamma * rhs
    } else {
      gamma * (-&self.errors + rhs)
    };

    // Update my values and sum the changes to return
    self.values += &value_changes;
    value_changes.mapv(|x| x.abs()).sum()
  }

  /// Compute this layer's weight update from the lower-layer errors, without applying it.
  fn compute_weight_updates(&mut self, alpha: f32, lower_layer: &Layer) -> Array2<f32> {
    // W^{l+1} += alpha * (phi'(a^l) (Hadamard) e^l) * x^{l+1,T}
    // where a^l = W^{l+1} * x^{l+1} is the preactivation for the layer below
    let preactivation: Array1<f32> = self.weights.dot(&self.values);
    let activation_function_result: Array1<f32> = preactivation.mapv(|a| self.activation_function.derivative(a));
    let gain_modulated_errors: Array1<f32> = &activation_function_result * &lower_layer.errors;

    // The outer product yields (lower_size, upper_size)
    alpha * outer_product(&gain_modulated_errors, &self.values)
  }

  /// Update prediction weights after convergence based on lower-layer errors.
  fn update_weights(&mut self, alpha: f32, lower_layer: &Layer) {
    let weight_changes: Array2<f32> = self.compute_weight_updates(alpha, lower_layer);
    self.weights += &weight_changes;
  }
}

/// A multi-layer predictive coding model with value and weight updates.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictiveCodingModel {
  layers: Vec<Layer>,
  alpha: f32, // synaptic learning rate
  gamma: f32, // neural learning rate
  convergence_threshold: f32,
  convergence_steps: u32,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PredictiveCodingModelConfig {
  pub layer_sizes: Vec<usize>,
  pub alpha: f32,
  pub gamma: f32,
  pub convergence_threshold: f32,
  pub convergence_steps: u32,
  pub activation_function: ActivationFunction,
}

impl PredictiveCodingModel {
  /// Construct a model with the given layer sizes and learning rates.
  ///
  /// alpha is the synaptic learning rate, which controls how much the weights are updated after each inference step.
  /// gamma is the neural learning rate, which controls how much the node values are updated during inference.
  /// activation_function is applied to the node values when computing predictions for the layer below.
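  ///
  /// A minimal construction sketch; the sizes and rates below are hypothetical,
  /// and the ActivationFunction variant is assumed for illustration:
  ///
  /// ```ignore
  /// let config = PredictiveCodingModelConfig {
  ///   layer_sizes: vec![784, 128, 10],
  ///   alpha: 0.01,                 // synaptic learning rate
  ///   gamma: 0.1,                  // neural learning rate
  ///   convergence_threshold: 1e-4,
  ///   convergence_steps: 100,
  ///   activation_function: ActivationFunction::Sigmoid, // assumed variant
  /// };
  /// let model = PredictiveCodingModel::new(&config);
  /// assert_eq!(model.get_layer_sizes(), vec![784, 128, 10]);
  /// ```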
  pub fn new(config: &PredictiveCodingModelConfig) -> Self {
    let mut layers = Vec::new();
    for (index, layer_size) in config.layer_sizes.iter().enumerate() {
      let lower_size = if index == 0 { None } else { Some(config.layer_sizes[index - 1]) };

      layers.push(Layer::new(
        *layer_size,
        lower_size,
        config.activation_function,
        None,
        None
      ));
    }

    PredictiveCodingModel {
      layers,
      alpha: config.alpha,
      gamma: config.gamma,
      convergence_threshold: config.convergence_threshold,
      convergence_steps: config.convergence_steps,
    }
  }

  // Getters for model properties, so I don't have to expose the model fields directly
  pub fn get_config(&self) -> PredictiveCodingModelConfig {
    PredictiveCodingModelConfig {
      layer_sizes: self.layers.iter().map(|l| l.size).collect(),
      alpha: self.alpha,
      gamma: self.gamma,
      convergence_steps: self.convergence_steps,
      convergence_threshold: self.convergence_threshold,
      // I only allow all layers to have the same activation function
      activation_function: self.layers.first().unwrap().activation_function,
    }
  }
  pub fn get_layers(&self) -> &Vec<Layer> {
    &self.layers
  }
  pub fn get_layer(&self, index: usize) -> &Layer {
    &self.layers[index]
  }
  pub fn get_layer_sizes(&self) -> Vec<usize> {
    self.layers.iter().map(|l| l.size).collect()
  }
  pub fn get_alpha(&self) -> f32 {
    self.alpha
  }
  pub fn get_gamma(&self) -> f32 {
    self.gamma
  }
  pub fn get_activation_function(&self) -> ActivationFunction {
    self.layers.first().unwrap().activation_function
  }

  /// Returns a reference to the input layer values.
  pub fn get_input(&self) -> &Array1<f32> {
    &self.layers[0].values
  }

  /// Sets the values of the input layer to the given input values, and pins the input layer.
  pub fn set_input(&mut self, input_values: Array1<f32>) {
    self.layers[0].pin_values(input_values);
  }

  /// Prevent the input layer values from being updated during inference, by pinning the input layer.
  pub fn pin_input(&mut self) {
    self.layers[0].pinned = true;
  }

  /// Allow the input layer values to be updated during inference, by unpinning the input layer.
  pub fn unpin_input(&mut self) {
    self.layers[0].unpin_values();
  }

  /// Randomise the input layer values between 0..1 and unpin the input layer to allow updates during inference.
  pub fn randomise_input(&mut self) {
    let input_layer = &mut self.layers[0];
    let mut rng = rand::rng();
    input_layer.randomise_values(&mut rng);
    input_layer.unpin_values();
  }

  /// Returns a reference to the output layer values.
  pub fn get_output(&self) -> &Array1<f32> {
    &self.layers.last().unwrap().values
  }

  /// Sets the values of the output layer to the given output values, and pins the output layer.
  pub fn set_output(&mut self, output_values: Array1<f32>) {
    self.layers.last_mut().unwrap().pin_values(output_values);
  }

  /// Prevent the output layer values from being updated during inference, by pinning the output layer.
  pub fn pin_output(&mut self) {
    self.layers.last_mut().unwrap().pinned = true;
  }

  /// Allow the output layer values to be updated during inference, by unpinning the output layer.
  pub fn unpin_output(&mut self) {
    self.layers.last_mut().unwrap().unpin_values();
  }

  /// Randomise the output layer values between 0..1 and unpin the output layer to allow updates during inference.
  pub fn randomise_output(&mut self) {
    let output_layer = self.layers.last_mut().unwrap();
    let mut rng = rand::rng();
    output_layer.randomise_values(&mut rng);
    output_layer.unpin_values();
  }

  /// Reinitialise all unpinned (latent) layer values to small random values.
  /// Should be called before each new training sample to avoid carrying over
  /// converged state from a previous sample.
  pub fn reinitialise_latents(&mut self) {
    let mut rng = rand::rng();
    for layer in &mut self.layers {
      if !layer.pinned {
        layer.randomise_values(&mut rng);
      }
    }
  }

  /// Evolves node values until convergence, recomputing predictions and errors each step.
  /// Returns the number of steps taken to converge.
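  ///
  /// A sketch of a single inference pass; the input values are hypothetical and
  /// assume a model whose input layer has three nodes:
  ///
  /// ```ignore
  /// model.set_input(Array1::from(vec![0.2, 0.8, 0.5])); // pins the input layer
  /// model.reinitialise_latents();                       // re-randomise the unpinned layers
  /// let steps = model.converge_values();
  /// println!("converged in {steps} steps; output = {:?}", model.get_output());
  /// ```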
  pub fn converge_values(&mut self) -> u32 {
    let mut converged: bool = false;
    let mut convergence_count: u32 = 0;

    while !converged && (convergence_count < self.convergence_steps) {
      self.compute_predictions_and_errors();

      if self.timestep().abs() < self.convergence_threshold {
        converged = true;
      }
      convergence_count += 1;
    }

    convergence_count
  }

  /// Compute predictions for each layer and then update errors.
  pub fn compute_predictions_and_errors(&mut self) {
    self.compute_predictions();
    self.compute_errors();
  }

  /// Compute predictions for all layers from top to bottom.
  pub fn compute_predictions(&mut self) {
    let num_layers = self.layers.len();
    for i in (0..num_layers - 1).rev() { // iterate backwards through the layers
      // Since the target layer needs to be mutable to update the predictions, I need to split the vector.
      // Luckily, this is not a transformative operation, so split_at_mut is still fast.
      let (lower, upper) = self.layers.split_at_mut(i + 1);
      let lower_layer = &mut lower[i];
      let upper_layer = &upper[0];

      lower_layer.compute_predictions(upper_layer);
    }
  }

  /// Compute prediction errors for all layers.
  pub fn compute_errors(&mut self) {
    for i in 0..self.layers.len() {
      self.layers[i].compute_errors();
    }
  }

  /// Sum signed errors across all layers.
  pub fn read_total_error(&self) -> f32 {
    // Sum the errors of all nodes in all layers
    let mut total_error = 0.0;
    for layer in &self.layers {
      total_error += layer.read_total_error();
    }
    total_error
  }

  /// Sum squared errors across all layers.
  pub fn read_total_energy(&self) -> f32 {
    // Sum the energy of all nodes in all layers
    let mut total_energy = 0.0;
    for layer in &self.layers {
      total_energy += layer.read_total_energy();
    }

    0.5 * total_energy
  }

  /// Compute the change in node values under a single timestep of PC.
  /// Returns the mean change in node values across all layers.
  pub fn timestep(&mut self) -> f32 {
    let mut total_value_changes = 0.0;

    // Update the input layer, which has no lower layer
    total_value_changes += self.layers[0].values_timestep(false, self.gamma, None);

    // The update of a node value depends on the errors of the layer below it.
    let num_layers: usize = self.layers.len();
    for i in 0..num_layers - 1 { // in rust, the range is exclusive of the upper bound
      let (lower, upper) = self.layers.split_at_mut(i + 1);
      let lower_layer: &Layer = &lower[i];
      let upper_layer: &mut Layer = &mut upper[0];

      // The last layer is handled differently.
      if i == num_layers - 2 { // the last i to be processed updates the top layer
        total_value_changes += upper_layer.values_timestep(true, self.gamma, Some(lower_layer));
      } else {
        total_value_changes += upper_layer.values_timestep(false, self.gamma, Some(lower_layer));
      }
    }

    // Mean
    let total_num_nodes = self.layers.iter().map(|layer| layer.values.len()).sum::<usize>() as f32;
    total_value_changes / total_num_nodes
  }

  /// Compute and apply prediction weight updates for all layers after inference.
  pub fn update_weights(&mut self) {
    let num_layers = self.layers.len();
    for i in 0..num_layers - 1 {
      let (lower, upper) = self.layers.split_at_mut(i + 1);
      let lower_layer: &Layer = &lower[i];
      let upper_layer: &mut Layer = &mut upper[0];

      upper_layer.update_weights(self.alpha, lower_layer);
    }
  }

  /// Compute the change in weights based on the current errors and values, without applying the changes to the model.
  /// Returns a Vec where index i contains the weight updates for layers[i+1].weights.
  pub fn compute_weight_updates(&mut self) -> Vec<Array2<f32>> {
    let mut weight_updates: Vec<Array2<f32>> = Vec::new();

    let num_layers = self.layers.len();
    for i in 0..num_layers - 1 {
      let (lower, upper) = self.layers.split_at_mut(i + 1);
      let lower_layer: &Layer = &lower[i];
      let upper_layer: &mut Layer = &mut upper[0];

      weight_updates.push(
        upper_layer.compute_weight_updates(self.alpha, lower_layer)
      );
    }

    weight_updates
  }

  /// Apply precomputed weight updates, as returned by compute_weight_updates.
  pub fn apply_weight_updates(&mut self, weight_updates: Vec<Array2<f32>>) {
    // weight_updates[i] corresponds to layers[i+1].weights
    for (i, weights) in weight_updates.iter().enumerate() {
      self.layers[i + 1].weights += weights;
    }
  }

}
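
Putting the pieces together: a typical supervised training step pins both ends of the network, relaxes the latent values, and then takes one synaptic step. The sketch below is illustrative only; the `data` samples and the `load_samples` helper are hypothetical, and it assumes a `model` built from a config as in the constructor's doc example.

// Hedged sketch of one training epoch over hypothetical (input, target) pairs.
let data: Vec<(Array1<f32>, Array1<f32>)> = load_samples(); // hypothetical helper
for (input, target) in data {
  model.set_input(input);       // pin the input layer to the sample
  model.set_output(target);     // pin the output layer to the label
  model.reinitialise_latents(); // re-randomise the unpinned hidden layers
  model.converge_values();      // relax node values towards equilibrium
  model.update_weights();       // then apply one weight update
}
println!("energy after epoch: {}", model.read_total_energy());

// To inspect or defer updates instead of applying them immediately, use the split API:
let updates = model.compute_weight_updates();
model.apply_weight_updates(updates);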