• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine / 15296353110

28 May 2025 09:16AM UTC coverage: 89.915%. First build
15296353110

Pull #1000

github

web-flow
Merge 4296ac453 into ee3a9c265
Pull Request #1000: feat: Add ml model input and output shape to allow models run on entire tiles

494 of 591 new or added lines in 11 files covered. (83.59%)

127066 of 141318 relevant lines covered (89.91%)

59687.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

75.81
/datatypes/src/machine_learning.rs
1
use crate::{
2
    dataset::{SYSTEM_NAMESPACE, is_invalid_name_char},
3
    raster::{GridShape2D, GridSize, RasterDataType},
4
};
5
use serde::{Deserialize, Serialize, de::Visitor};
6
use snafu::Snafu;
7
use std::path::PathBuf;
8
use std::str::FromStr;
9
use strum::IntoStaticStr;
10

11
/// Character that separates the namespace from the model name in the textual
/// form of an `MlModelName` (used both when serializing and when parsing).
const NAME_DELIMITER: char = ':';
12

13
/// Name of a machine learning model, optionally qualified by a namespace.
///
/// The textual form is `namespace:name`, or just `name` when `namespace`
/// is `None`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MlModelName {
    /// Namespace qualifying the model; `None` means the name is unqualified.
    pub namespace: Option<String>,
    /// The model's own name, without any namespace prefix.
    pub name: String,
}
18

19
#[derive(Snafu, IntoStaticStr, Debug)]
×
20
#[snafu(visibility(pub(crate)))]
21
#[snafu(context(suffix(false)))] // disables default `Snafu` suffix
22
pub enum MlModelNameError {
23
    #[snafu(display("MlModelName is empty"))]
24
    IsEmpty,
25
    #[snafu(display("invalid character '{invalid_char}' in named model"))]
26
    InvalidCharacter { invalid_char: String },
27
    #[snafu(display("ml model name must consist of at most two parts"))]
28
    TooManyParts,
29
}
30

31
impl MlModelName {
32
    /// Canonicalize a name that reflects the system namespace and model.
33
    fn canonicalize<S: Into<String> + PartialEq<&'static str>>(
2✔
34
        name: S,
2✔
35
        system_name: &'static str,
2✔
36
    ) -> Option<String> {
2✔
37
        if name == system_name {
2✔
38
            None
1✔
39
        } else {
40
            Some(name.into())
1✔
41
        }
42
    }
2✔
43

44
    pub fn new<S: Into<String>>(namespace: Option<String>, name: S) -> Self {
2✔
45
        Self {
2✔
46
            namespace,
2✔
47
            name: name.into(),
2✔
48
        }
2✔
49
    }
2✔
50
}
51

52
impl Serialize for MlModelName {
53
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
4✔
54
    where
4✔
55
        S: serde::Serializer,
4✔
56
    {
4✔
57
        let d = NAME_DELIMITER;
4✔
58
        let serialized = match (&self.namespace, &self.name) {
4✔
59
            (None, name) => name.to_string(),
4✔
60
            (Some(namespace), name) => {
×
61
                format!("{namespace}{d}{name}")
×
62
            }
63
        };
64

65
        serializer.serialize_str(&serialized)
4✔
66
    }
4✔
67
}
68

69
impl<'de> Deserialize<'de> for MlModelName {
70
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
×
71
    where
×
72
        D: serde::Deserializer<'de>,
×
73
    {
×
74
        deserializer.deserialize_str(MlModelNameDeserializeVisitor)
×
75
    }
×
76
}
77

78
impl FromStr for MlModelName {
79
    type Err = MlModelNameError;
80

81
    fn from_str(s: &str) -> Result<Self, Self::Err> {
5✔
82
        let mut strings = [None, None];
5✔
83
        let mut split = s.split(NAME_DELIMITER);
5✔
84

85
        for (buffer, part) in strings.iter_mut().zip(&mut split) {
7✔
86
            if part.is_empty() {
7✔
87
                return Err(MlModelNameError::IsEmpty);
×
88
            }
7✔
89

90
            if let Some(c) = part.matches(is_invalid_name_char).next() {
7✔
91
                return Err(MlModelNameError::InvalidCharacter {
×
92
                    invalid_char: c.to_string(),
×
93
                });
×
94
            }
7✔
95

7✔
96
            *buffer = Some(part.to_string());
7✔
97
        }
98

99
        if split.next().is_some() {
5✔
100
            return Err(MlModelNameError::TooManyParts);
×
101
        }
5✔
102

103
        match strings {
5✔
104
            [Some(namespace), Some(name)] => Ok(MlModelName {
2✔
105
                namespace: MlModelName::canonicalize(namespace, SYSTEM_NAMESPACE),
2✔
106
                name,
2✔
107
            }),
2✔
108
            [Some(name), None] => Ok(MlModelName {
3✔
109
                namespace: None,
3✔
110
                name,
3✔
111
            }),
3✔
112
            _ => Err(MlModelNameError::IsEmpty),
×
113
        }
114
    }
5✔
115
}
116

117
/// Serde visitor that turns a string into an `MlModelName` by parsing it
/// with `MlModelName::from_str`.
struct MlModelNameDeserializeVisitor;
118

119
impl Visitor<'_> for MlModelNameDeserializeVisitor {
120
    type Value = MlModelName;
121

122
    /// always keep in sync with [`is_allowed_name_char`]
123
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
×
124
        write!(
×
125
            formatter,
×
126
            "a string consisting of a namespace and name name, separated by a colon, only using alphanumeric characters, underscores & dashes"
×
127
        )
×
128
    }
×
129

130
    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
×
131
    where
×
132
        E: serde::de::Error,
×
133
    {
×
134
        MlModelName::from_str(s).map_err(|e| E::custom(e.to_string()))
×
135
    }
×
136
}
137

138
/// A struct describing tensor shape for `MlModelMetadata`
139
#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize, Serialize)]
140
pub struct MlTensorShape3D {
141
    pub y: u32,
142
    pub x: u32,
143
    pub bands: u32, // TODO: named attributes?
144
}
145

146
impl MlTensorShape3D {
147
    pub fn new_y_x_bands(y: u32, x: u32, bands: u32) -> Self {
3✔
148
        Self { y, x, bands }
3✔
149
    }
3✔
150

151
    pub fn new_single_pixel_bands(bands: u32) -> Self {
6✔
152
        Self { y: 1, x: 1, bands }
6✔
153
    }
6✔
154

155
    pub fn new_single_pixel_single_band() -> Self {
2✔
156
        Self::new_single_pixel_bands(1)
2✔
157
    }
2✔
158

159
    pub fn axis_size_y(&self) -> u32 {
12✔
160
        self.y
12✔
161
    }
12✔
162

163
    pub fn axis_size_x(&self) -> u32 {
12✔
164
        self.x
12✔
165
    }
12✔
166

167
    pub fn yx_matches_tile_shape(&self, tile_shape: &GridShape2D) -> bool {
4✔
168
        self.axis_size_x() as usize == tile_shape.axis_size_x()
4✔
169
            && self.axis_size_y() as usize == tile_shape.axis_size_y()
4✔
170
    }
4✔
171
}
172

173
// For now we assume all models are pixel-wise, i.e., they take a single pixel with multiple bands as input and produce a single output value.
174
// To support different inputs, we would need a more sophisticated logic to produce the inputs for the model.
175
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)]
176
pub struct MlModelMetadata {
177
    pub file_path: PathBuf,
178
    pub input_type: RasterDataType,
179
    pub output_type: RasterDataType,
180
    pub input_shape: MlTensorShape3D,
181
    pub output_shape: MlTensorShape3D, // TODO: output measurement, e.g. classification or regression, label names for classification. This would have to be provided by the model creator along the model file as it cannot be extracted from the model file(?)
182
}
183

184
impl MlModelMetadata {
NEW
185
    pub fn num_input_bands(&self) -> u32 {
×
NEW
186
        self.input_shape.bands
×
NEW
187
    }
×
188

189
    pub fn num_output_bands(&self) -> u32 {
4✔
190
        self.output_shape.bands
4✔
191
    }
4✔
192

193
    pub fn input_is_single_pixel(&self) -> bool {
13✔
194
        self.input_shape.x == 1 && self.input_shape.y == 1
13✔
195
    }
13✔
196

197
    pub fn output_is_single_pixel(&self) -> bool {
4✔
198
        self.output_shape.x == 1 && self.output_shape.y == 1
4✔
199
    }
4✔
200

201
    pub fn output_is_single_attribute(&self) -> bool {
4✔
202
        self.num_output_bands() == 1
4✔
203
    }
4✔
204
}
205

206
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare name parses into a name without namespace.
    #[test]
    fn ml_model_name_from_str() {
        let input = "myModelName";
        let parsed: MlModelName = input.parse().unwrap();

        assert_eq!(parsed.name, input);
        assert_eq!(parsed.namespace, None);
    }

    /// A `namespace:name` string keeps both parts.
    #[test]
    fn ml_model_name_from_str_prefixed() {
        let parsed: MlModelName = "d5328854-6190-4af9-ad69-4e74b0961ac9:myModelName"
            .parse()
            .unwrap();

        assert_eq!(parsed.name, "myModelName");
        assert_eq!(
            parsed.namespace.as_deref(),
            Some("d5328854-6190-4af9-ad69-4e74b0961ac9")
        );
    }

    /// The `_` namespace prefix is dropped, leaving no namespace.
    #[test]
    fn ml_model_name_from_str_system() {
        let parsed: MlModelName = "_:myModelName".parse().unwrap();

        assert_eq!(parsed.name, "myModelName");
        assert_eq!(parsed.namespace, None);
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc