• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wildjames / predictive_coding_rs / 26039969057

18 May 2026 02:30PM UTC coverage: 89.655% (+1.0%) from 88.629%
26039969057

Pull #18

github

web-flow
Merge a5b98f021 into ab068c7f4
Pull Request #18: Implement the remaining GPU kernels

268 of 288 new or added lines in 12 files covered. (93.06%)

6 existing lines in 2 files now uncovered.

2080 of 2320 relevant lines covered (89.66%)

14.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.33
/src/model/cpu.rs
1
use ndarray::{Array1, Array2};
2

3
use crate::error::{PredictiveCodingError, Result};
4

5
use super::{
6
    ExecutionBackend, Layer, ModelRuntime, ModelSnapshot, PredictiveCodingModel,
7
    PredictiveCodingModelConfig, TrainableModelRuntime, WeightUpdateSet, maths::outer_product,
8
};
9

10
/// CPU execution backend: owns a [`PredictiveCodingModel`] and runs prediction, error,
/// relaxation and weight-update maths directly on the model's `ndarray` buffers.
pub struct CpuModelRuntime {
    // The wrapped model; every runtime operation reads and mutates it in place.
    model: PredictiveCodingModel,
}
13

14
impl CpuModelRuntime {
15
    pub fn new(config: &PredictiveCodingModelConfig) -> Self {
1✔
16
        CpuModelRuntime {
1✔
17
            model: PredictiveCodingModel::new(config),
1✔
18
        }
1✔
19
    }
1✔
20

21
    pub fn from_model(model: PredictiveCodingModel) -> Self {
38✔
22
        CpuModelRuntime { model }
38✔
23
    }
38✔
24

25
    pub fn from_snapshot(snapshot: &ModelSnapshot) -> Result<Self> {
1✔
26
        Ok(CpuModelRuntime {
27
            model: PredictiveCodingModel::from_snapshot(snapshot)?,
1✔
28
        })
29
    }
1✔
30

31
    pub fn model(&self) -> &PredictiveCodingModel {
31✔
32
        &self.model
31✔
33
    }
31✔
34

35
    pub fn model_mut(&mut self) -> &mut PredictiveCodingModel {
50✔
36
        &mut self.model
50✔
37
    }
50✔
38

39
    pub fn into_model(self) -> PredictiveCodingModel {
×
40
        self.model
×
41
    }
×
42

43
    fn validate_layer_width(actual: usize, expected: usize, label: &str) -> Result<()> {
21✔
44
        if actual != expected {
21✔
45
            return Err(PredictiveCodingError::validation(format!(
1✔
46
                "{label} length {actual} does not match expected size {expected}"
1✔
47
            )));
1✔
48
        }
20✔
49

50
        Ok(())
20✔
51
    }
21✔
52

53
    fn compute_predictions_for_layer(lower_layer: &mut Layer, upper_layer: &Layer) {
33✔
54
        let preactivation: Array1<f32> = upper_layer.weights.dot(&upper_layer.values);
33✔
55
        lower_layer.predictions = preactivation.mapv(|a| upper_layer.activation_function.apply(a));
132✔
56
    }
33✔
57

58
    fn compute_errors_for_layer(layer: &mut Layer) {
66✔
59
        layer.errors = &layer.values - &layer.predictions;
66✔
60
    }
66✔
61

62
    fn total_error_for_layer(layer: &Layer) -> f32 {
×
63
        layer.errors.iter().copied().sum::<f32>()
×
64
    }
×
65

66
    fn total_energy_for_layer(layer: &Layer) -> f32 {
10✔
67
        layer.errors.mapv(|x| x.powi(2)).sum()
70✔
68
    }
10✔
69

70
    fn values_timestep(
65✔
71
        layer: &mut Layer,
65✔
72
        is_top_level: bool,
65✔
73
        gamma: f32,
65✔
74
        lower_layer: Option<&Layer>,
65✔
75
    ) -> f32 {
65✔
76
        if layer.pinned {
65✔
77
            return 0.0;
63✔
78
        }
2✔
79

80
        let rhs: Array1<f32> = if let Some(lower_layer) = lower_layer {
2✔
81
            let preactivation: Array1<f32> = layer.weights.dot(&layer.values);
2✔
82
            let activation_function_derivative: Array1<f32> =
2✔
83
                preactivation.mapv(|a| layer.activation_function.derivative(a));
5✔
84
            let gain_modulated_errors: Array1<f32> =
2✔
85
                activation_function_derivative * &lower_layer.errors;
2✔
86

87
            layer.weights.t().dot(&gain_modulated_errors)
2✔
88
        } else {
89
            Array1::zeros(layer.values.len())
×
90
        };
91

92
        let value_changes: Array1<f32> = if is_top_level {
2✔
93
            rhs * gamma
1✔
94
        } else {
95
            (-&layer.errors + rhs) * gamma
1✔
96
        };
97

98
        layer.values += &value_changes;
2✔
99
        value_changes.mapv(|x| x.abs()).sum()
11✔
100
    }
65✔
101

102
    fn compute_weight_updates_for_layer(
28✔
103
        alpha: f32,
28✔
104
        upper_layer: &Layer,
28✔
105
        lower_layer: &Layer,
28✔
106
    ) -> Array2<f32> {
28✔
107
        let preactivation: Array1<f32> = upper_layer.weights.dot(&upper_layer.values);
28✔
108
        let activation_function_derivative: Array1<f32> =
28✔
109
            preactivation.mapv(|a| upper_layer.activation_function.derivative(a));
112✔
110
        let gain_modulated_errors: Array1<f32> =
28✔
111
            &activation_function_derivative * &lower_layer.errors;
28✔
112

113
        alpha * outer_product(&gain_modulated_errors, &upper_layer.values)
28✔
114
    }
28✔
115

116
    fn compute_predictions_internal(&mut self) {
33✔
117
        let num_layers = self.model.layers.len();
33✔
118
        if num_layers < 2 {
33✔
119
            return;
×
120
        }
33✔
121

122
        for index in (0..num_layers - 1).rev() {
33✔
123
            let (lower, upper) = self.model.layers.split_at_mut(index + 1);
33✔
124
            let lower_layer = &mut lower[index];
33✔
125
            let upper_layer = &upper[0];
33✔
126

33✔
127
            Self::compute_predictions_for_layer(lower_layer, upper_layer);
33✔
128
        }
33✔
129
    }
33✔
130

131
    fn compute_errors_internal(&mut self) {
33✔
132
        for layer in &mut self.model.layers {
66✔
133
            Self::compute_errors_for_layer(layer);
66✔
134
        }
66✔
135
    }
33✔
136

137
    fn timestep_internal(&mut self) -> f32 {
32✔
138
        if self.model.layers.is_empty() {
32✔
139
            return 0.0;
×
140
        }
32✔
141

142
        let mut total_value_changes =
32✔
143
            Self::values_timestep(&mut self.model.layers[0], false, self.model.gamma, None);
32✔
144

145
        let num_layers = self.model.layers.len();
32✔
146
        for index in 1..num_layers {
33✔
147
            let (lower, upper) = self.model.layers.split_at_mut(index);
33✔
148
            let lower_layer = &lower[index - 1];
33✔
149
            let upper_layer = &mut upper[0];
33✔
150
            let is_top_level = index == num_layers - 1;
33✔
151

33✔
152
            total_value_changes += Self::values_timestep(
33✔
153
                upper_layer,
33✔
154
                is_top_level,
33✔
155
                self.model.gamma,
33✔
156
                Some(lower_layer),
33✔
157
            );
33✔
158
        }
33✔
159

160
        let total_num_nodes = self
32✔
161
            .model
32✔
162
            .layers
32✔
163
            .iter()
32✔
164
            .map(|layer| layer.values.len())
65✔
165
            .sum::<usize>() as f32;
32✔
166

167
        total_value_changes / total_num_nodes
32✔
168
    }
32✔
169

170
    fn converge_values_internal(&mut self) -> u32 {
31✔
171
        let mut converged = false;
31✔
172
        let mut convergence_count = 0;
31✔
173

174
        while !converged && convergence_count < self.model.convergence_steps {
62✔
175
            self.compute_predictions_internal();
31✔
176
            self.compute_errors_internal();
31✔
177

178
            if self.timestep_internal().abs() < self.model.convergence_threshold {
31✔
179
                converged = true;
×
180
            }
31✔
181

182
            convergence_count += 1;
31✔
183
        }
184

185
        convergence_count
31✔
186
    }
31✔
187

188
    fn total_error_internal(&self) -> f32 {
×
189
        self.model
×
190
            .layers
×
191
            .iter()
×
192
            .map(Self::total_error_for_layer)
×
193
            .sum()
×
194
    }
×
195

196
    fn total_energy_internal(&self) -> f32 {
5✔
197
        0.5 * self
5✔
198
            .model
5✔
199
            .layers
5✔
200
            .iter()
5✔
201
            .map(Self::total_energy_for_layer)
5✔
202
            .sum::<f32>()
5✔
203
    }
5✔
204

205
    fn compute_weight_updates_internal(&self) -> Vec<Array2<f32>> {
28✔
206
        let num_layers = self.model.layers.len();
28✔
207
        let mut weight_updates = Vec::with_capacity(num_layers.saturating_sub(1));
28✔
208

209
        for index in 0..num_layers.saturating_sub(1) {
28✔
210
            let mut update = Self::compute_weight_updates_for_layer(
28✔
211
                self.model.alpha,
28✔
212
                &self.model.layers[index + 1],
28✔
213
                &self.model.layers[index],
28✔
214
            );
215
            if self.model.weight_clip > 0.0 {
28✔
NEW
216
                let clip = self.model.weight_clip;
×
NEW
217
                update.mapv_inplace(|x| x.clamp(-clip, clip));
×
218
            }
28✔
219
            weight_updates.push(update);
28✔
220
        }
221

222
        weight_updates
28✔
223
    }
28✔
224

225
    fn apply_weight_updates_internal(&mut self, weight_updates: &[Array2<f32>]) {
12✔
226
        for (index, weights) in weight_updates.iter().enumerate() {
12✔
227
            self.model.layers[index + 1].weights += weights;
12✔
228
        }
12✔
229
    }
12✔
230
}
231

232
// The surface that the rest of the codebase will interact with
233
impl ModelRuntime for CpuModelRuntime {
234
    fn backend(&self) -> ExecutionBackend {
2✔
235
        ExecutionBackend::Cpu
2✔
236
    }
2✔
237

238
    fn config(&self) -> PredictiveCodingModelConfig {
8✔
239
        self.model.get_config()
8✔
240
    }
8✔
241

242
    fn layer_sizes(&self) -> Vec<usize> {
×
243
        self.model.get_layer_sizes()
×
244
    }
×
245

246
    fn snapshot(&mut self) -> Result<ModelSnapshot> {
18✔
247
        Ok(self.model.to_snapshot())
18✔
248
    }
18✔
249

250
    fn set_input(&mut self, input_values: &[f32]) -> Result<()> {
11✔
251
        Self::validate_layer_width(input_values.len(), self.model.get_input().len(), "input")?;
11✔
252
        self.model
10✔
253
            .set_input(Array1::from_vec(input_values.to_vec()));
10✔
254
        Ok(())
10✔
255
    }
11✔
256

257
    fn set_output(&mut self, output_values: &[f32]) -> Result<()> {
10✔
258
        Self::validate_layer_width(output_values.len(), self.model.get_output().len(), "output")?;
10✔
259
        self.model
10✔
260
            .set_output(Array1::from_vec(output_values.to_vec()));
10✔
261
        Ok(())
10✔
262
    }
10✔
263

264
    fn pin_input(&mut self) -> Result<()> {
4✔
265
        self.model.pin_input();
4✔
266
        Ok(())
4✔
267
    }
4✔
268

269
    fn unpin_input(&mut self) -> Result<()> {
×
270
        self.model.unpin_input();
×
271
        Ok(())
×
272
    }
×
273

274
    fn pin_output(&mut self) -> Result<()> {
4✔
275
        self.model.pin_output();
4✔
276
        Ok(())
4✔
277
    }
4✔
278

279
    fn unpin_output(&mut self) -> Result<()> {
1✔
280
        self.model.unpin_output();
1✔
281
        Ok(())
1✔
282
    }
1✔
283

284
    fn reinitialise_latents(&mut self) -> Result<()> {
30✔
285
        self.model.reinitialise_latents();
30✔
286
        Ok(())
30✔
287
    }
30✔
288

289
    fn compute_predictions_and_errors(&mut self) -> Result<()> {
2✔
290
        self.compute_predictions_internal();
2✔
291
        self.compute_errors_internal();
2✔
292
        Ok(())
2✔
293
    }
2✔
294

295
    fn timestep(&mut self) -> Result<f32> {
1✔
296
        Ok(self.timestep_internal())
1✔
297
    }
1✔
298

299
    fn converge_values(&mut self) -> Result<u32> {
31✔
300
        Ok(self.converge_values_internal())
31✔
301
    }
31✔
302

303
    fn total_error(&mut self) -> Result<f32> {
×
304
        Ok(self.total_error_internal())
×
305
    }
×
306

307
    fn total_energy(&mut self) -> Result<f32> {
5✔
308
        Ok(self.total_energy_internal())
5✔
309
    }
5✔
310

311
    fn input_values(&mut self) -> Result<Vec<f32>> {
×
312
        Ok(self.model.get_input().to_vec())
×
313
    }
×
314

315
    fn output_values(&mut self) -> Result<Vec<f32>> {
×
316
        Ok(self.model.get_output().to_vec())
×
317
    }
×
318
}
319

320
impl TrainableModelRuntime for CpuModelRuntime {
321
    fn compute_weight_updates(&mut self) -> Result<WeightUpdateSet> {
28✔
322
        let arrays: Vec<Array2<f32>> = self.compute_weight_updates_internal();
28✔
323

324
        Ok(WeightUpdateSet {
325
            shapes: arrays.iter().map(|array| array.dim()).collect(),
28✔
326
            updates: arrays
28✔
327
                .into_iter()
28✔
328
                .map(|array| array.iter().copied().collect())
28✔
329
                .collect(),
28✔
330
        })
331
    }
28✔
332

333
    fn apply_weight_updates(&mut self, updates: &WeightUpdateSet) -> Result<()> {
12✔
334
        if updates.updates.len() != updates.shapes.len() {
12✔
335
            return Err(PredictiveCodingError::validation(
×
336
                "weight update payload has mismatched update and shape counts",
×
337
            ));
×
338
        }
12✔
339

340
        let expected_layer_count: usize = self.model.get_layers().len().saturating_sub(1);
12✔
341
        if updates.updates.len() != expected_layer_count {
12✔
342
            return Err(PredictiveCodingError::validation(format!(
×
343
                "weight update payload contains {} layers but model expects {}",
×
344
                updates.updates.len(),
×
345
                expected_layer_count
×
346
            )));
×
347
        }
12✔
348

349
        let mut arrays: Vec<Array2<f32>> = Vec::with_capacity(updates.updates.len());
12✔
350
        for (index, (update_values, (rows, cols))) in updates
12✔
351
            .updates
12✔
352
            .iter()
12✔
353
            .zip(updates.shapes.iter().copied())
12✔
354
            .enumerate()
12✔
355
        {
356
            let expected_shape = self.model.get_layer(index + 1).weights.dim();
12✔
357
            if expected_shape != (rows, cols) {
12✔
358
                return Err(PredictiveCodingError::validation(format!(
×
359
                    "weight update shape {:?} does not match model layer {} shape {:?}",
×
360
                    (rows, cols),
×
361
                    index + 1,
×
362
                    expected_shape
×
363
                )));
×
364
            }
12✔
365

366
            let array: Array2<f32> = Array2::from_shape_vec((rows, cols), update_values.clone())
12✔
367
                .map_err(|_| {
12✔
368
                    PredictiveCodingError::validation(format!(
×
369
                        "weight update layer {} contains {} values but expected {}",
370
                        index + 1,
×
371
                        update_values.len(),
×
372
                        rows * cols
×
373
                    ))
374
                })?;
×
375
            arrays.push(array);
12✔
376
        }
377

378
        self.apply_weight_updates_internal(&arrays);
12✔
379
        Ok(())
12✔
380
    }
12✔
381
}
382

383
#[cfg(test)]
mod tests {
    use super::*;

    use crate::model::{ModelRuntime, TrainableModelRuntime, maths::ActivationFunction};
    use crate::test_utils::tiny_relu_model;
    use ndarray::array;

    /// A snapshot taken from a runtime rebuilds a runtime with the same config and layer sizes.
    #[test]
    fn cpu_runtime_snapshot_round_trips_through_model() {
        let original = tiny_relu_model();
        let expected_config = original.get_config();
        let expected_sizes = original.get_layer_sizes();

        let mut runtime = CpuModelRuntime::from_model(original);
        let snapshot = runtime.snapshot().unwrap();
        let restored = CpuModelRuntime::from_snapshot(&snapshot).unwrap();

        assert_eq!(restored.model().get_config(), expected_config);
        assert_eq!(restored.model().get_layer_sizes(), expected_sizes);
    }

    /// An input slice narrower than the input layer is rejected with a descriptive message.
    #[test]
    fn cpu_runtime_validates_input_size() {
        let mut runtime = CpuModelRuntime::from_model(tiny_relu_model());

        let message = runtime
            .set_input(&[1.0, 2.0])
            .expect_err("a two-element input must be rejected by the four-wide input layer")
            .to_string();

        assert_eq!(
            message,
            "validation error: input length 2 does not match expected size 4"
        );
    }

    /// The tiny model has one connection, so exactly one update with the weight shape comes back.
    #[test]
    fn cpu_runtime_weight_updates_match_model_layer_count() {
        let mut runtime = CpuModelRuntime::from_model(tiny_relu_model());
        runtime.compute_predictions_and_errors().unwrap();

        let payload = runtime.compute_weight_updates().unwrap();

        assert_eq!(payload.updates.len(), 1);
        assert_eq!(payload.shapes, vec![(4, 10)]);
    }

    /// A hidden (non-top) layer must include its own error term in the value update.
    #[test]
    fn cpu_runtime_timestep_uses_hidden_layer_error_term_for_non_top_layers() {
        let config = PredictiveCodingModelConfig {
            layer_sizes: vec![1, 1, 1],
            alpha: 0.05,
            gamma: 0.5,
            convergence_threshold: 0.0,
            convergence_steps: 1,
            activation_function: ActivationFunction::Relu,
            weight_clip: 0.0,
        };
        let mut runtime = CpuModelRuntime::new(&config);

        {
            let layers = &mut runtime.model_mut().layers;
            layers[0].pinned = true;
            layers[0].errors = array![0.0];
            layers[1].values = array![1.0];
            layers[1].errors = array![0.25];
            layers[1].weights = array![[1.0]];
            layers[2].pinned = true;
        }

        runtime.timestep().unwrap();

        // The lower error is zero, so the bottom-up term vanishes:
        // 1.0 + 0.5 * (-0.25 + 0.0) = 0.875.
        assert_eq!(runtime.model().get_layer(1).values, array![0.875]);
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc