• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine / 15296353110

28 May 2025 09:16AM UTC coverage: 89.915%. First build
15296353110

Pull #1000

github

web-flow
Merge 4296ac453 into ee3a9c265
Pull Request #1000: feat: Add ml model input and output shape to allow models run on entire tiles

494 of 591 new or added lines in 11 files covered. (83.59%)

127066 of 141318 relevant lines covered (89.91%)

59687.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.51
/operators/src/machine_learning/onnx_util.rs
1
use geoengine_datatypes::{
2
    machine_learning::{MlModelMetadata, MlTensorShape3D},
3
    raster::{GridShape2D, GridSize, RasterDataType},
4
};
5
use ort::session::Session;
6
use snafu::{ResultExt, ensure};
7

8
use crate::machine_learning::error::{
9
    InvalidInputPixelShape, InvalidInputTensorShape, InvalidInputType, InvalidOutputPixelShape,
10
    InvalidOutputType, MetadataModelInputShapeMismatch, MetadataModelInputTypeMismatch,
11
    MetadataModelOutputShapeMismatch, MultipleInputsNotSupported, UnsupportedInOutMapping,
12
    UnsupportedNumberOfOutputAttributes,
13
};
14

15
use super::{MachineLearningError, error::Ort};
16

17
pub fn load_onnx_model_from_metadata(
4✔
18
    ml_model_metadata: &MlModelMetadata,
4✔
19
) -> Result<Session, MachineLearningError> {
4✔
20
    ort::session::Session::builder()
4✔
21
        .context(Ort)?
4✔
22
        .commit_from_file(&ml_model_metadata.file_path)
4✔
23
        .context(Ort)
4✔
24
        .inspect_err(|e| {
4✔
NEW
25
            tracing::debug!(
×
NEW
26
                "Could not create ONNX session for {:?}. Error: {}",
×
NEW
27
                ml_model_metadata.file_path.file_name(),
×
28
                e
29
            );
30
        })
4✔
31
}
4✔
32

33
pub fn check_model_shape(
4✔
34
    model_metadata: &MlModelMetadata,
4✔
35
    tiling_shape: GridShape2D,
4✔
36
) -> Result<(), MachineLearningError> {
4✔
37
    check_model_input_shape_supported(model_metadata, tiling_shape)?;
4✔
38
    check_model_output_shape_supported(model_metadata, tiling_shape)?;
4✔
39
    check_input_output_mapping_supported(model_metadata)
4✔
40
}
4✔
41

42
pub fn check_model_input_shape_supported(
4✔
43
    model_metadata: &MlModelMetadata,
4✔
44
    tiling_shape: GridShape2D,
4✔
45
) -> Result<(), MachineLearningError> {
4✔
46
    // check that we can use the model input shape with the operator
4✔
47
    ensure!(
4✔
48
        model_metadata.input_is_single_pixel()
4✔
49
            || model_metadata
1✔
50
                .input_shape
1✔
51
                .yx_matches_tile_shape(&tiling_shape),
1✔
NEW
52
        InvalidInputPixelShape {
×
NEW
53
            tensor_shape: model_metadata.input_shape,
×
NEW
54
            tiling_shape
×
NEW
55
        }
×
56
    );
57

58
    Ok(())
4✔
59
}
4✔
60

61
pub fn check_model_output_shape_supported(
4✔
62
    model_metadata: &MlModelMetadata,
4✔
63
    tiling_shape: GridShape2D,
4✔
64
) -> Result<(), MachineLearningError> {
4✔
65
    // check that we can use the model output shape with the operator
4✔
66
    ensure!(
4✔
67
        model_metadata.output_is_single_pixel()
4✔
68
            || model_metadata
1✔
69
                .output_shape
1✔
70
                .yx_matches_tile_shape(&tiling_shape),
1✔
NEW
71
        InvalidOutputPixelShape {
×
NEW
72
            tensor_shape: model_metadata.output_shape,
×
NEW
73
            tiling_shape
×
NEW
74
        }
×
75
    );
76

77
    ensure!(
4✔
78
        model_metadata.output_is_single_attribute(),
4✔
NEW
79
        UnsupportedNumberOfOutputAttributes {
×
NEW
80
            output_attributes: model_metadata.num_output_bands()
×
NEW
81
        }
×
82
    );
83

84
    Ok(())
4✔
85
}
4✔
86

87
pub fn check_input_output_mapping_supported(
4✔
88
    model_metadata: &MlModelMetadata,
4✔
89
) -> Result<(), MachineLearningError> {
4✔
90
    ensure!(
4✔
91
        model_metadata.input_shape.axis_size_x() == model_metadata.output_shape.axis_size_x()
4✔
92
            && model_metadata.input_shape.axis_size_y()
4✔
93
                == model_metadata.output_shape.axis_size_y(),
4✔
NEW
94
        UnsupportedInOutMapping {
×
NEW
95
            in_shape: model_metadata.input_shape,
×
NEW
96
            out_shape: model_metadata.output_shape
×
NEW
97
        }
×
98
    );
99

100
    Ok(())
4✔
101
}
4✔
102

103
pub fn try_onnx_tensor_to_ml_tensorshape_3d(
2✔
104
    tensor_dimensions: &[i64],
2✔
105
) -> Result<MlTensorShape3D, MachineLearningError> {
2✔
NEW
106
    match *tensor_dimensions {
×
107
        [-1..=1] => Ok(MlTensorShape3D {
1✔
108
            x: 1,
1✔
109
            y: 1,
1✔
110
            bands: 1,
1✔
111
        }),
1✔
112
        [bands] | [-1..=1, bands] if bands > 0 => Ok(MlTensorShape3D {
1✔
113
            x: 1,
1✔
114
            y: 1,
1✔
115
            bands: (bands as u32),
1✔
116
        }),
1✔
NEW
117
        [x, y] | [-1..=1, x, y] if x > 0 && y > 0 => Ok(MlTensorShape3D {
×
NEW
118
            x: x as u32,
×
NEW
119
            y: y as u32,
×
NEW
120
            bands: 1,
×
NEW
121
        }),
×
NEW
122
        [x, y, bands] | [-1..=1, x, y, bands] if x > 0 && y > 0 && bands > 0 => {
×
NEW
123
            Ok(MlTensorShape3D {
×
NEW
124
                x: x as u32,
×
NEW
125
                y: y as u32,
×
NEW
126
                bands: bands as u32,
×
NEW
127
            })
×
128
        }
NEW
129
        _ => Err(MachineLearningError::InvalidDimensions {
×
NEW
130
            dimensions: tensor_dimensions.to_vec(),
×
NEW
131
        }),
×
132
    }
133
}
2✔
134

135
///
136
/// Check that the session input is a tensor with the dimension specified in the metadata.
137
///
138
/// # Panics
139
///
140
/// If the input is a tensor but no `tensor_dimension` is provided.
141
///
142
pub fn check_onnx_model_input_matches_metadata(
1✔
143
    session: &Session,
1✔
144
    metadata_input: MlTensorShape3D,
1✔
145
    metadata_input_type: RasterDataType,
1✔
146
) -> Result<(), MachineLearningError> {
1✔
147
    let inputs = &session.inputs;
1✔
148
    ensure!(
1✔
149
        inputs.len() == 1,
1✔
NEW
150
        MultipleInputsNotSupported {
×
NEW
151
            num_inputs: inputs.len()
×
NEW
152
        }
×
153
    );
154

155
    let input = &inputs[0];
1✔
156
    ensure!(
1✔
157
        input.input_type.is_tensor(),
1✔
NEW
158
        InvalidInputType {
×
NEW
159
            input_type: input.input_type.clone()
×
NEW
160
        }
×
161
    );
162
    let dimensions = input
1✔
163
        .input_type
1✔
164
        .tensor_dimensions()
1✔
165
        .expect("input must be a tensor. checked before!");
1✔
166

167
    let shape = try_onnx_tensor_to_ml_tensorshape_3d(dimensions)?;
1✔
168

169
    ensure!(
1✔
170
        shape == metadata_input,
1✔
NEW
171
        MetadataModelInputShapeMismatch {
×
NEW
172
            model_dimensions: dimensions.clone(),
×
NEW
173
            model_shape: shape,
×
NEW
174
            metadata_shape: metadata_input
×
NEW
175
        }
×
176
    );
177

178
    let input_tensor_type = input
1✔
179
        .input_type
1✔
180
        .tensor_type()
1✔
181
        .expect("input must be a tensor. ckecked above!");
1✔
182
    let input_raster_type = try_raster_datatype_from_tensor_element_type(input_tensor_type)?;
1✔
183

184
    ensure!(
1✔
185
        input_raster_type == metadata_input_type,
1✔
NEW
186
        MetadataModelInputTypeMismatch {
×
NEW
187
            model_tensor_type: input_tensor_type,
×
NEW
188
            model_raster_type: input_raster_type,
×
NEW
189
            metadata_type: metadata_input_type
×
NEW
190
        }
×
191
    );
192

193
    Ok(())
1✔
194
}
1✔
195

196
///
197
/// Check that the session output is a tensor with the dimension speified in the metadata.
198
///
199
/// # Panics
200
///
201
/// If the output is a tensor but no `tensor_dimension` is provided.
202
///
203
pub fn check_onnx_model_output_matches_metadata(
1✔
204
    session: &Session,
1✔
205
    metadata_output: MlTensorShape3D,
1✔
206
    metadata_output_type: RasterDataType,
1✔
207
) -> Result<(), MachineLearningError> {
1✔
208
    let outputs = &session.outputs;
1✔
209

1✔
210
    // we assume that the first output is the one to use
1✔
211
    // TODO: make this configurable?
1✔
212
    let output = &outputs[0];
1✔
213
    ensure!(
1✔
214
        output.output_type.is_tensor(),
1✔
NEW
215
        InvalidOutputType {
×
NEW
216
            output_type: output.output_type.clone()
×
NEW
217
        }
×
218
    );
219

220
    let dimensions = output
1✔
221
        .output_type
1✔
222
        .tensor_dimensions()
1✔
223
        .expect("input must be a tensor. checked before!");
1✔
224

225
    let shape = try_onnx_tensor_to_ml_tensorshape_3d(dimensions)?;
1✔
226

227
    ensure!(
1✔
228
        shape == metadata_output,
1✔
NEW
229
        MetadataModelOutputShapeMismatch {
×
NEW
230
            model_dimensions: dimensions.clone(),
×
NEW
231
            model_shape: shape,
×
NEW
232
            metadata_shape: metadata_output
×
NEW
233
        }
×
234
    );
235

236
    let output_tensor_type = output
1✔
237
        .output_type
1✔
238
        .tensor_type()
1✔
239
        .expect("output must be a tensor. ckecked above!");
1✔
240
    let output_raster_type = try_raster_datatype_from_tensor_element_type(output_tensor_type)?;
1✔
241

242
    ensure!(
1✔
243
        output_raster_type == metadata_output_type,
1✔
NEW
244
        MetadataModelInputTypeMismatch {
×
NEW
245
            model_tensor_type: output_tensor_type,
×
NEW
246
            model_raster_type: output_raster_type,
×
NEW
247
            metadata_type: metadata_output_type
×
NEW
248
        }
×
249
    );
250

251
    Ok(())
1✔
252
}
1✔
253

254
pub fn check_onnx_model_matches_metadata(
1✔
255
    session: &Session,
1✔
256
    model_metadata: &MlModelMetadata,
1✔
257
) -> Result<(), MachineLearningError> {
1✔
258
    check_onnx_model_input_matches_metadata(
1✔
259
        session,
1✔
260
        model_metadata.input_shape,
1✔
261
        model_metadata.input_type,
1✔
262
    )?;
1✔
263
    check_onnx_model_output_matches_metadata(
1✔
264
        session,
1✔
265
        model_metadata.output_shape,
1✔
266
        model_metadata.output_type,
1✔
267
    )
1✔
268
}
1✔
269

270
pub fn check_model_input_features(
3✔
271
    model_metadata: &MlModelMetadata,
3✔
272
    tiling_shape: GridShape2D,
3✔
273
    num_bands: u32,
3✔
274
) -> Result<(), MachineLearningError> {
3✔
275
    let used_in_shape = if model_metadata.input_is_single_pixel() {
3✔
276
        MlTensorShape3D::new_single_pixel_bands(num_bands)
2✔
277
    } else {
278
        MlTensorShape3D::new_y_x_bands(
1✔
279
            tiling_shape.axis_size_y() as u32,
1✔
280
            tiling_shape.axis_size_x() as u32,
1✔
281
            num_bands,
1✔
282
        )
1✔
283
    };
284

285
    // check that number of input bands fits number of model features
286
    ensure!(
3✔
287
        model_metadata.input_shape == used_in_shape,
3✔
NEW
288
        InvalidInputTensorShape {
×
NEW
289
            input_shape: used_in_shape,
×
NEW
290
            model_shape: model_metadata.input_shape
×
NEW
291
        }
×
292
    );
293

294
    Ok(())
3✔
295
}
3✔
296

297
// can't implement `TryFrom` here because `RasterDataType` is in operators crate
298
pub(crate) fn try_raster_datatype_from_tensor_element_type(
2✔
299
    value: ort::tensor::TensorElementType,
2✔
300
) -> Result<RasterDataType, MachineLearningError> {
2✔
301
    match value {
2✔
302
        ort::tensor::TensorElementType::Float32 => Ok(RasterDataType::F32),
1✔
303
        ort::tensor::TensorElementType::Uint8 | ort::tensor::TensorElementType::Bool => {
NEW
304
            Ok(RasterDataType::U8)
×
305
        }
NEW
306
        ort::tensor::TensorElementType::Int8 => Ok(RasterDataType::I8),
×
NEW
307
        ort::tensor::TensorElementType::Uint16 => Ok(RasterDataType::U16),
×
NEW
308
        ort::tensor::TensorElementType::Int16 => Ok(RasterDataType::I16),
×
NEW
309
        ort::tensor::TensorElementType::Int32 => Ok(RasterDataType::I32),
×
310
        ort::tensor::TensorElementType::Int64 => Ok(RasterDataType::I64),
1✔
NEW
311
        ort::tensor::TensorElementType::Float64 => Ok(RasterDataType::F64),
×
NEW
312
        ort::tensor::TensorElementType::Uint32 => Ok(RasterDataType::U32),
×
NEW
313
        ort::tensor::TensorElementType::Uint64 => Ok(RasterDataType::U64),
×
NEW
314
        _ => Err(MachineLearningError::UnsupportedTensorElementType {
×
NEW
315
            element_type: value,
×
NEW
316
        }),
×
317
    }
318
}
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc