• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tspooner / lfa / 102

pending completion
102

push

travis-ci

web-flow
Merge pull request #17 from tspooner/ft-serde

New manifest feature: "serialize"

19 of 19 new or added lines in 15 files covered. (100.0%)

1313 of 1748 relevant lines covered (75.11%)

0.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.29
/src/eval/pair.rs
1
use crate::{
2
    core::*,
3
    geometry::{Matrix, MatrixView, MatrixViewMut},
4
};
5

6
/// Weight-`Projection` evaluator with pair `[f64; 2]` output.
7
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
8
#[derive(Clone, Debug)]
×
9
pub struct PairFunction {
10
    pub weights: Matrix<f64>,
×
11
}
12

13
impl PairFunction {
14
    pub fn new(weights: Matrix<f64>) -> Self {
1✔
15
        PairFunction { weights, }
1✔
16
    }
1✔
17

18
    pub fn zeros(n_features: usize) -> Self {
1✔
19
        PairFunction::new(Matrix::zeros((n_features, 2)))
1✔
20
    }
1✔
21
}
22

23
impl Parameterised for PairFunction {
24
    fn weights(&self) -> Matrix<f64> { self.weights.clone() }
×
25
    fn weights_view(&self) -> MatrixView<f64> { self.weights.view() }
×
26
    fn weights_view_mut(&mut self) -> MatrixViewMut<f64> { self.weights.view_mut() }
×
27
}
28

29
impl Approximator for PairFunction {
30
    type Output = [f64; 2];
31

32
    fn n_outputs(&self) -> usize { 2 }
1✔
33

34
    fn evaluate(&self, features: &Features) -> EvaluationResult<Self::Output> {
1✔
35
        apply_to_features!(features => activations, {
1✔
36
            Ok([
1✔
37
                self.weights.column(0).dot(activations),
1✔
38
                self.weights.column(1).dot(activations),
1✔
39
            ])
40
        }; indices, {
41
            Ok(indices.iter().fold([0.0; 2], |acc, idx| [
1✔
42
                acc[0] + self.weights[(*idx, 0)],
1✔
43
                acc[1] + self.weights[(*idx, 1)],
1✔
44
            ]))
1✔
45
        })
46
    }
1✔
47

48
    fn jacobian(&self, features: &Features) -> Matrix<f64> {
×
49
        let dim = self.weights_dim();
×
50
        let phi = features.expanded(dim.0);
×
51

52
        let mut g = Matrix::zeros(dim);
×
53

54
        g.column_mut(0).assign(&phi);
×
55
        g.column_mut(1).assign(&phi);
×
56

57
        g
×
58
    }
×
59

60
    fn update_grad(&mut self, grad: &Matrix<f64>, update: Self::Output) -> UpdateResult<()> {
×
61
        Ok({
×
62
            self.weights.column_mut(0).scaled_add(update[0], &grad.column(0));
×
63
            self.weights.column_mut(1).scaled_add(update[1], &grad.column(1));
×
64
        })
65
    }
×
66

67
    fn update(&mut self, features: &Features, errors: Self::Output) -> UpdateResult<()> {
1✔
68
        apply_to_features!(features => activations, {
1✔
69
            Ok({
1✔
70
                self.weights.column_mut(0).scaled_add(errors[0], activations);
1✔
71
                self.weights.column_mut(1).scaled_add(errors[1], activations);
1✔
72
            })
73
        }; indices, {
74
            let z = indices.len() as f64;
1✔
75

76
            let se1 = errors[0] / z;
1✔
77
            let se2 = errors[1] / z;
1✔
78

79
            Ok(indices.into_iter().for_each(|idx| {
1✔
80
                self.weights[(*idx, 0)] += se1;
1✔
81
                self.weights[(*idx, 1)] += se2;
1✔
82
            }))
1✔
83
        })
84
    }
1✔
85
}
86

87
#[cfg(test)]
mod tests {
    extern crate seahash;

    use crate::{
        composition::Composable,
        core::*,
        basis::{Projector, fixed::{Fourier, TileCoding}},
        geometry::Space,
    };
    use std::hash::BuildHasherDefault;
    use super::PairFunction;

    type SHBuilder = BuildHasherDefault<seahash::SeaHasher>;

    /// Absolute tolerance used when comparing evaluator outputs.
    const TOL: f64 = 1e-6;

    /// Asserts that `actual` lies within `TOL` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}", expected, actual
        );
    }

    #[test]
    fn test_sparse_update_eval() {
        let projector = TileCoding::new(SHBuilder::default(), 4, 100);
        let mut evaluator = PairFunction::zeros(projector.dim());

        assert_eq!(evaluator.n_outputs(), 2);
        assert_eq!(evaluator.weights.len(), 200);

        let features = projector.project(&vec![5.0]);

        // Previously `let _ = ...` silently discarded a failed update; make it fail the test.
        assert!(evaluator.update(&features, [20.0, 50.0]).is_ok(), "sparse update failed");
        let out = evaluator.evaluate(&features).unwrap();

        assert_close(out[0], 20.0);
        assert_close(out[1], 50.0);
    }

    #[test]
    fn test_dense_update_eval() {
        let projector = Fourier::new(3, vec![(0.0, 10.0)]).normalise_l2();
        let mut evaluator = PairFunction::zeros(projector.dim());

        assert_eq!(evaluator.n_outputs(), 2);
        assert_eq!(evaluator.weights.len(), 6);

        let features = projector.project(&vec![5.0]);

        // Previously `let _ = ...` silently discarded a failed update; make it fail the test.
        assert!(evaluator.update(&features, [20.0, 50.0]).is_ok(), "dense update failed");
        let out = evaluator.evaluate(&features).unwrap();

        assert_close(out[0], 20.0);
        assert_close(out[1], 50.0);
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2023 Coveralls, Inc