• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine / 15023660263

14 May 2025 02:39PM UTC coverage: 89.873%. First build
15023660263

Pull #1000

github

web-flow
Merge a8d61cd3e into 8c18b149b
Pull Request #1000: add tensor shape to ml model

296 of 436 new or added lines in 11 files covered. (67.89%)

126875 of 141171 relevant lines covered (89.87%)

57016.26 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

62.9
/datatypes/src/machine_learning.rs
1
use crate::{
2
    dataset::{SYSTEM_NAMESPACE, is_invalid_name_char},
3
    raster::{GridShape2D, GridSize, RasterDataType},
4
};
5
use serde::{Deserialize, Serialize, de::Visitor};
6
use snafu::Snafu;
7
use std::path::PathBuf;
8
use std::str::FromStr;
9
use strum::IntoStaticStr;
10

11
/// Separator between the optional namespace and the model name in the
/// textual form `namespace:name`.
const NAME_DELIMITER: char = ':';

/// Identifier of a machine-learning model: a name, optionally qualified by a
/// namespace.
///
/// When parsed via `FromStr`, a namespace equal to the system namespace is
/// stored as `None`, so `<system>:<name>` and `<name>` compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MlModelName {
    /// Optional namespace qualifying the model name.
    pub namespace: Option<String>,
    /// The model's name within its namespace.
    pub name: String,
}
18

19
#[derive(Snafu, IntoStaticStr, Debug)]
×
20
#[snafu(visibility(pub(crate)))]
21
#[snafu(context(suffix(false)))] // disables default `Snafu` suffix
22
pub enum MlModelNameError {
23
    #[snafu(display("MlModelName is empty"))]
24
    IsEmpty,
25
    #[snafu(display("invalid character '{invalid_char}' in named model"))]
26
    InvalidCharacter { invalid_char: String },
27
    #[snafu(display("ml model name must consist of at most two parts"))]
28
    TooManyParts,
29
}
30

31
impl MlModelName {
32
    /// Canonicalize a name that reflects the system namespace and model.
33
    fn canonicalize<S: Into<String> + PartialEq<&'static str>>(
2✔
34
        name: S,
2✔
35
        system_name: &'static str,
2✔
36
    ) -> Option<String> {
2✔
37
        if name == system_name {
2✔
38
            None
1✔
39
        } else {
40
            Some(name.into())
1✔
41
        }
42
    }
2✔
43

44
    pub fn new<S: Into<String>>(namespace: Option<String>, name: S) -> Self {
2✔
45
        Self {
2✔
46
            namespace,
2✔
47
            name: name.into(),
2✔
48
        }
2✔
49
    }
2✔
50
}
51

52
impl Serialize for MlModelName {
53
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
3✔
54
    where
3✔
55
        S: serde::Serializer,
3✔
56
    {
3✔
57
        let d = NAME_DELIMITER;
3✔
58
        let serialized = match (&self.namespace, &self.name) {
3✔
59
            (None, name) => name.to_string(),
3✔
60
            (Some(namespace), name) => {
×
61
                format!("{namespace}{d}{name}")
×
62
            }
63
        };
64

65
        serializer.serialize_str(&serialized)
3✔
66
    }
3✔
67
}
68

69
impl<'de> Deserialize<'de> for MlModelName {
70
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
×
71
    where
×
72
        D: serde::Deserializer<'de>,
×
73
    {
×
74
        deserializer.deserialize_str(MlModelNameDeserializeVisitor)
×
75
    }
×
76
}
77

78
impl FromStr for MlModelName {
79
    type Err = MlModelNameError;
80

81
    fn from_str(s: &str) -> Result<Self, Self::Err> {
5✔
82
        let mut strings = [None, None];
5✔
83
        let mut split = s.split(NAME_DELIMITER);
5✔
84

85
        for (buffer, part) in strings.iter_mut().zip(&mut split) {
7✔
86
            if part.is_empty() {
7✔
87
                return Err(MlModelNameError::IsEmpty);
×
88
            }
7✔
89

90
            if let Some(c) = part.matches(is_invalid_name_char).next() {
7✔
91
                return Err(MlModelNameError::InvalidCharacter {
×
92
                    invalid_char: c.to_string(),
×
93
                });
×
94
            }
7✔
95

7✔
96
            *buffer = Some(part.to_string());
7✔
97
        }
98

99
        if split.next().is_some() {
5✔
100
            return Err(MlModelNameError::TooManyParts);
×
101
        }
5✔
102

103
        match strings {
5✔
104
            [Some(namespace), Some(name)] => Ok(MlModelName {
2✔
105
                namespace: MlModelName::canonicalize(namespace, SYSTEM_NAMESPACE),
2✔
106
                name,
2✔
107
            }),
2✔
108
            [Some(name), None] => Ok(MlModelName {
3✔
109
                namespace: None,
3✔
110
                name,
3✔
111
            }),
3✔
112
            _ => Err(MlModelNameError::IsEmpty),
×
113
        }
114
    }
5✔
115
}
116

117
struct MlModelNameDeserializeVisitor;
118

119
impl Visitor<'_> for MlModelNameDeserializeVisitor {
120
    type Value = MlModelName;
121

122
    /// always keep in sync with [`is_allowed_name_char`]
123
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
×
124
        write!(
×
125
            formatter,
×
126
            "a string consisting of a namespace and name name, separated by a colon, only using alphanumeric characters, underscores & dashes"
×
127
        )
×
128
    }
×
129

130
    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
×
131
    where
×
132
        E: serde::de::Error,
×
133
    {
×
134
        MlModelName::from_str(s).map_err(|e| E::custom(e.to_string()))
×
135
    }
×
136
}
137

138
/// A struct describing tensor shape for `MlModelMetadata`
139
#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)]
140
pub struct MlTensorShape3D {
141
    pub y: u32,
142
    pub x: u32,
143
    pub bands: u32, // TODO: named attributes?
144
}
145

146
impl MlTensorShape3D {
NEW
147
    pub fn new_y_x_bands(y: u32, x: u32, bands: u32) -> Self {
×
NEW
148
        Self { y, x, bands }
×
NEW
149
    }
×
150

151
    pub fn new_single_pixel_bands(bands: u32) -> Self {
2✔
152
        Self { y: 1, x: 1, bands }
2✔
153
    }
2✔
154

NEW
155
    pub fn new_single_pixel_single_band() -> Self {
×
NEW
156
        Self::new_single_pixel_bands(1)
×
NEW
157
    }
×
158

159
    pub fn axis_size_y(&self) -> u32 {
6✔
160
        self.y
6✔
161
    }
6✔
162

163
    pub fn axis_size_x(&self) -> u32 {
6✔
164
        self.x
6✔
165
    }
6✔
166

NEW
167
    pub fn yx_matches_tile_shape(&self, tile_shape: &GridShape2D) -> bool {
×
NEW
168
        self.axis_size_x() as usize == tile_shape.axis_size_x()
×
NEW
169
            && self.axis_size_y() as usize == tile_shape.axis_size_y()
×
NEW
170
    }
×
171
}
172

173
// For now we assume all models are pixel-wise, i.e., they take a single pixel with multiple bands as input and produce a single output value.
174
// To support different inputs, we would need a more sophisticated logic to produce the inputs for the model.
175
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)]
176
pub struct MlModelMetadata {
177
    pub file_path: PathBuf,
178
    pub input_type: RasterDataType,
179
    pub output_type: RasterDataType,
180
    pub input_shape: MlTensorShape3D,
181
    pub output_shape: MlTensorShape3D, // TODO: output measurement, e.g. classification or regression, label names for classification. This would have to be provided by the model creator along the model file as it cannot be extracted from the model file(?)
182
}
183

184
impl MlModelMetadata {
NEW
185
    pub fn num_input_bands(&self) -> u32 {
×
NEW
186
        self.input_shape.bands
×
NEW
187
    }
×
188

NEW
189
    pub fn num_output_bands(&self) -> u32 {
×
NEW
190
        self.output_shape.bands
×
NEW
191
    }
×
192

193
    pub fn input_is_single_pixel(&self) -> bool {
9✔
194
        self.input_shape.x == 1 && self.input_shape.y == 1
9✔
195
    }
9✔
196

197
    pub fn output_is_single_pixel(&self) -> bool {
3✔
198
        self.output_shape.x == 1 && self.output_shape.y == 1
3✔
199
    }
3✔
200

NEW
201
    pub fn output_is_single_attribute(&self) -> bool {
×
NEW
202
        self.num_output_bands() == 1
×
NEW
203
    }
×
204
}
205

206
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as DeError, StrDeserializer};

    #[test]
    fn ml_model_name_from_str() {
        const ML_MODEL_NAME: &str = "myModelName";
        let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
        assert_eq!(mln.name, ML_MODEL_NAME);
        assert!(mln.namespace.is_none());
    }

    #[test]
    fn ml_model_name_from_str_prefixed() {
        const ML_MODEL_NAME: &str = "d5328854-6190-4af9-ad69-4e74b0961ac9:myModelName";
        let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
        assert_eq!(mln.name, "myModelName".to_string());
        assert_eq!(
            mln.namespace,
            Some("d5328854-6190-4af9-ad69-4e74b0961ac9".to_string())
        );
    }

    #[test]
    fn ml_model_name_from_str_system() {
        const ML_MODEL_NAME: &str = "_:myModelName";
        let mln = MlModelName::from_str(ML_MODEL_NAME).unwrap();
        assert_eq!(mln.name, "myModelName".to_string());
        assert!(mln.namespace.is_none());
    }

    /// Empty input, or an empty part on either side of the delimiter, is rejected.
    #[test]
    fn ml_model_name_from_str_rejects_empty_parts() {
        for input in ["", ":myModelName", "ns:", "ns::myModelName"] {
            assert!(
                matches!(MlModelName::from_str(input), Err(MlModelNameError::IsEmpty)),
                "expected IsEmpty for {input:?}"
            );
        }
    }

    /// More than one delimiter is rejected.
    #[test]
    fn ml_model_name_from_str_rejects_too_many_parts() {
        assert!(matches!(
            MlModelName::from_str("a:b:c"),
            Err(MlModelNameError::TooManyParts)
        ));
    }

    /// Characters outside the allowed set are reported.
    #[test]
    fn ml_model_name_from_str_rejects_invalid_characters() {
        assert!(matches!(
            MlModelName::from_str("ns:my!model"),
            Err(MlModelNameError::InvalidCharacter { invalid_char }) if invalid_char == "!"
        ));
    }

    /// Deserialization goes through the visitor and therefore `from_str`.
    #[test]
    fn ml_model_name_deserialize() {
        let deserializer: StrDeserializer<'_, DeError> = "ns:myModelName".into_deserializer();
        let mln = MlModelName::deserialize(deserializer).unwrap();
        assert_eq!(
            mln,
            MlModelName::new(Some("ns".to_string()), "myModelName")
        );

        let deserializer: StrDeserializer<'_, DeError> = "a:b:c".into_deserializer();
        assert!(MlModelName::deserialize(deserializer).is_err());
    }

    #[test]
    fn ml_tensor_shape_constructors() {
        assert_eq!(
            MlTensorShape3D::new_single_pixel_single_band(),
            MlTensorShape3D::new_y_x_bands(1, 1, 1)
        );
        assert_eq!(
            MlTensorShape3D::new_single_pixel_bands(4),
            MlTensorShape3D { y: 1, x: 1, bands: 4 }
        );
        let shape = MlTensorShape3D::new_y_x_bands(2, 3, 5);
        assert_eq!(shape.axis_size_y(), 2);
        assert_eq!(shape.axis_size_x(), 3);
    }

    #[test]
    fn ml_model_metadata_shape_helpers() {
        let metadata = MlModelMetadata {
            file_path: PathBuf::from("model.onnx"),
            input_type: RasterDataType::F32,
            output_type: RasterDataType::F32,
            input_shape: MlTensorShape3D::new_single_pixel_bands(3),
            output_shape: MlTensorShape3D::new_single_pixel_single_band(),
        };
        assert_eq!(metadata.num_input_bands(), 3);
        assert_eq!(metadata.num_output_bands(), 1);
        assert!(metadata.input_is_single_pixel());
        assert!(metadata.output_is_single_pixel());
        assert!(metadata.output_is_single_attribute());
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc