• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine / 13195982647

07 Feb 2025 08:25AM CUT coverage: 90.031% (-0.04%) from 90.073%
13195982647

Pull #1000

github

web-flow
Merge 2410e0a4e into 22fe2f6bb
Pull Request #1000: add tensor shape to ml model

109 of 194 new or added lines in 10 files covered. (56.19%)

3 existing lines in 3 files now uncovered.

126035 of 139990 relevant lines covered (90.03%)

57499.02 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.48
/datatypes/src/machine_learning.rs
1
use crate::{
2
    dataset::{is_invalid_name_char, SYSTEM_NAMESPACE},
3
    raster::{GridShape2D, GridSize, RasterDataType},
4
};
5
use serde::{de::Visitor, Deserialize, Serialize};
6
use snafu::Snafu;
7
use std::path::PathBuf;
8
use std::str::FromStr;
9
use strum::IntoStaticStr;
10

11
/// Separator between the optional namespace and the name in the textual form of an
/// [`MlModelName`], i.e. `name` or `namespace:name`.
const NAME_DELIMITER: char = ':';
12

13
/// Name of a machine learning model, optionally qualified by a namespace.
///
/// Its textual form is `name` or `namespace:name`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MlModelName {
    /// Namespace qualifying the model name; `None` when the name is unqualified.
    pub namespace: Option<String>,
    /// The model's name.
    pub name: String,
}
18

19
#[derive(Snafu, IntoStaticStr, Debug)]
×
20
#[snafu(visibility(pub(crate)))]
21
#[snafu(context(suffix(false)))] // disables default `Snafu` suffix
22
pub enum MlModelNameError {
23
    #[snafu(display("MlModelName is empty"))]
24
    IsEmpty,
25
    #[snafu(display("invalid character '{invalid_char}' in named model"))]
26
    InvalidCharacter { invalid_char: String },
27
    #[snafu(display("ml model name must consist of at most two parts"))]
28
    TooManyParts,
29
}
30

31
impl MlModelName {
32
    /// Canonicalize a name that reflects the system namespace and model.
33
    fn canonicalize<S: Into<String> + PartialEq<&'static str>>(
2✔
34
        name: S,
2✔
35
        system_name: &'static str,
2✔
36
    ) -> Option<String> {
2✔
37
        if name == system_name {
2✔
38
            None
1✔
39
        } else {
40
            Some(name.into())
1✔
41
        }
42
    }
2✔
43

44
    pub fn new<S: Into<String>>(namespace: Option<String>, name: S) -> Self {
2✔
45
        Self {
2✔
46
            namespace,
2✔
47
            name: name.into(),
2✔
48
        }
2✔
49
    }
2✔
50
}
51

52
impl Serialize for MlModelName {
53
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
3✔
54
    where
3✔
55
        S: serde::Serializer,
3✔
56
    {
3✔
57
        let d = NAME_DELIMITER;
3✔
58
        let serialized = match (&self.namespace, &self.name) {
3✔
59
            (None, name) => name.to_string(),
3✔
60
            (Some(namespace), name) => {
×
61
                format!("{namespace}{d}{name}")
×
62
            }
63
        };
64

65
        serializer.serialize_str(&serialized)
3✔
66
    }
3✔
67
}
68

69
impl<'de> Deserialize<'de> for MlModelName {
70
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
×
71
    where
×
72
        D: serde::Deserializer<'de>,
×
73
    {
×
74
        deserializer.deserialize_str(MlModelNameDeserializeVisitor)
×
75
    }
×
76
}
77

78
impl FromStr for MlModelName {
79
    type Err = MlModelNameError;
80

81
    fn from_str(s: &str) -> Result<Self, Self::Err> {
5✔
82
        let mut strings = [None, None];
5✔
83
        let mut split = s.split(NAME_DELIMITER);
5✔
84

85
        for (buffer, part) in strings.iter_mut().zip(&mut split) {
7✔
86
            if part.is_empty() {
7✔
87
                return Err(MlModelNameError::IsEmpty);
×
88
            }
7✔
89

90
            if let Some(c) = part.matches(is_invalid_name_char).next() {
7✔
91
                return Err(MlModelNameError::InvalidCharacter {
×
92
                    invalid_char: c.to_string(),
×
93
                });
×
94
            }
7✔
95

7✔
96
            *buffer = Some(part.to_string());
7✔
97
        }
98

99
        if split.next().is_some() {
5✔
100
            return Err(MlModelNameError::TooManyParts);
×
101
        }
5✔
102

103
        match strings {
5✔
104
            [Some(namespace), Some(name)] => Ok(MlModelName {
2✔
105
                namespace: MlModelName::canonicalize(namespace, SYSTEM_NAMESPACE),
2✔
106
                name,
2✔
107
            }),
2✔
108
            [Some(name), None] => Ok(MlModelName {
3✔
109
                namespace: None,
3✔
110
                name,
3✔
111
            }),
3✔
112
            _ => Err(MlModelNameError::IsEmpty),
×
113
        }
114
    }
5✔
115
}
116

117
struct MlModelNameDeserializeVisitor;
118

119
impl Visitor<'_> for MlModelNameDeserializeVisitor {
120
    type Value = MlModelName;
121

122
    /// always keep in sync with [`is_allowed_name_char`]
123
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
×
124
        write!(
×
125
            formatter,
×
126
            "a string consisting of a namespace and name name, separated by a colon, only using alphanumeric characters, underscores & dashes"
×
127
        )
×
128
    }
×
129

130
    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
×
131
    where
×
132
        E: serde::de::Error,
×
133
    {
×
134
        MlModelName::from_str(s).map_err(|e| E::custom(e.to_string()))
×
135
    }
×
136
}
137

138
/// A struct describing tensor shape for `MlModelMetadata`
139
#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)]
140
pub struct TensorShape3D {
141
    pub y: u32,
142
    pub x: u32,
143
    pub bands: u32, // TODO: named attributes?
144
}
145

146
impl TensorShape3D {
147
    pub fn new_y_x_bands(y: u32, x: u32, bands: u32) -> Self {
4✔
148
        Self { y, x, bands }
4✔
149
    }
4✔
150

NEW
151
    pub fn new_single_pixel_bands(bands: u32) -> Self {
×
NEW
152
        Self { y: 1, x: 1, bands }
×
NEW
153
    }
×
154

NEW
155
    pub fn new_single_pixel_single_band() -> Self {
×
NEW
156
        Self::new_single_pixel_bands(1)
×
NEW
157
    }
×
158

NEW
159
    pub fn axis_size_y(&self) -> u32 {
×
NEW
160
        self.y
×
NEW
161
    }
×
162

NEW
163
    pub fn axis_size_x(&self) -> u32 {
×
NEW
164
        self.x
×
NEW
165
    }
×
166

NEW
167
    pub fn yx_matches_tile_shape(&self, tile_shape: &GridShape2D) -> bool {
×
NEW
168
        self.axis_size_x() as usize == tile_shape.axis_size_x()
×
NEW
169
            && self.axis_size_y() as usize == tile_shape.axis_size_y()
×
NEW
170
    }
×
171
}
172

173
// For now we assume all models are pixel-wise, i.e., they take a single pixel with multiple bands as input and produce a single output value.
174
// To support different inputs, we would need a more sophisticated logic to produce the inputs for the model.
175
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)]
176
pub struct MlModelMetadata {
177
    pub file_path: PathBuf,
178
    pub input_type: RasterDataType,
179
    pub output_type: RasterDataType,
180
    pub input_shape: TensorShape3D,
181
    pub output_shape: TensorShape3D, // TODO: output measurement, e.g. classification or regression, label names for classification. This would have to be provided by the model creator along the model file as it cannot be extracted from the model file(?)
182
}
183

184
impl MlModelMetadata {
185
    pub fn num_input_bands(&self) -> u32 {
2✔
186
        self.input_shape.bands
2✔
187
    }
2✔
188

NEW
189
    pub fn mun_output_bands(&self) -> u32 {
×
NEW
190
        self.output_shape.bands
×
NEW
191
    }
×
192

193
    pub fn input_is_single_pixel(&self) -> bool {
6✔
194
        self.input_shape.x == 1 && self.input_shape.y == 1
6✔
195
    }
6✔
196

197
    pub fn output_is_single_pixel(&self) -> bool {
2✔
198
        self.output_shape.x == 1 && self.output_shape.y == 1
2✔
199
    }
2✔
200

NEW
201
    pub fn output_is_single_attribute(&self) -> bool {
×
NEW
202
        self.mun_output_bands() == 1
×
NEW
203
    }
×
204
}
205

206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ml_model_name_from_str() {
        const ML_MODEL_NAME: &str = "myModelName";
        let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
        assert_eq!(mln.name, ML_MODEL_NAME);
        assert!(mln.namespace.is_none());
    }

    #[test]
    fn ml_model_name_from_str_prefixed() {
        const ML_MODEL_NAME: &str = "d5328854-6190-4af9-ad69-4e74b0961ac9:myModelName";
        let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
        assert_eq!(mln.name, "myModelName".to_string());
        assert_eq!(
            mln.namespace,
            Some("d5328854-6190-4af9-ad69-4e74b0961ac9".to_string())
        );
    }

    #[test]
    fn ml_model_name_from_str_system() {
        const ML_MODEL_NAME: &str = "_:myModelName";
        let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
        assert_eq!(mln.name, "myModelName".to_string());
        assert!(mln.namespace.is_none());
    }

    #[test]
    fn ml_model_name_from_str_rejects_empty_parts() {
        // Each input has an empty part: the whole string, the namespace, the name, or both.
        for input in ["", ":myModelName", "myNamespace:", ":"] {
            assert!(
                matches!(MlModelName::from_str(input), Err(MlModelNameError::IsEmpty)),
                "expected `IsEmpty` for {input:?}"
            );
        }
    }

    #[test]
    fn ml_model_name_from_str_rejects_too_many_parts() {
        assert!(matches!(
            MlModelName::from_str("myNamespace:myModelName:extra"),
            Err(MlModelNameError::TooManyParts)
        ));
    }

    #[test]
    fn ml_model_name_new_keeps_namespace_verbatim() {
        let mln = MlModelName::new(Some("myNamespace".to_string()), "myModelName");
        assert_eq!(mln.namespace.as_deref(), Some("myNamespace"));
        assert_eq!(mln.name, "myModelName");
    }

    #[test]
    fn tensor_shape_single_pixel_constructors() {
        assert_eq!(
            TensorShape3D::new_single_pixel_single_band(),
            TensorShape3D::new_y_x_bands(1, 1, 1)
        );

        let shape = TensorShape3D::new_single_pixel_bands(4);
        assert_eq!(shape.axis_size_y(), 1);
        assert_eq!(shape.axis_size_x(), 1);
        assert_eq!(shape.bands, 4);
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc