• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine / 13195982647

07 Feb 2025 08:25AM UTC coverage: 90.031% (-0.04%) from 90.073%
13195982647

Pull #1000

github

web-flow
Merge 2410e0a4e into 22fe2f6bb
Pull Request #1000: add tensor shape to ml model

109 of 194 new or added lines in 10 files covered. (56.19%)

3 existing lines in 3 files now uncovered.

126035 of 139990 relevant lines covered (90.03%)

57499.02 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.86
/operators/src/machine_learning/metadata_from_file.rs
1
use super::MachineLearningError;
2
use crate::machine_learning::error::{MultipleInputsNotSupported, Ort};
3
use geoengine_datatypes::{
4
    machine_learning::{MlModelMetadata, TensorShape3D},
5
    raster::RasterDataType,
6
};
7
use snafu::{ensure, ResultExt};
8
use std::path::Path;
9

10
pub fn load_model_metadata(path: &Path) -> Result<MlModelMetadata, MachineLearningError> {
2✔
11
    // TODO: proper error if model file cannot be found
12
    let session = ort::session::Session::builder()
2✔
13
        .context(Ort)?
2✔
14
        .commit_from_file(path)
2✔
15
        .context(Ort)?;
2✔
16

17
    // Onnx model may have multiple inputs, but we only support one input (with multiple features/bands)
18
    ensure!(
2✔
19
        session.inputs.len() == 1,
2✔
20
        MultipleInputsNotSupported {
×
21
            num_inputs: session.inputs.len()
×
22
        }
×
23
    );
24

25
    // Onnx model input type must be a Tensor in order to accept a 2d ndarray as input
26
    let ort::value::ValueType::Tensor {
27
        ty: input_tensor_element_type,
2✔
28
        dimensions: input_dimensions,
2✔
29
        dimension_symbols: _dimension_symbols,
2✔
30
    } = &session.inputs[0].input_type
2✔
31
    else {
32
        return Err(MachineLearningError::InvalidInputType {
×
33
            input_type: session.inputs[0].input_type.clone(),
×
34
        });
×
35
    };
36

37
    // Input dimensions must be [-1, b] to accept a table of (arbitrarily many) single pixel features (rows) with `b` bands (columns)
38
    let input_shape = try_dimensions_to_tensor_shape(input_dimensions)?;
2✔
39

40
    // Onnx model must output one prediction per pixel as
41
    // (1) a Tensor with a single dimension of unknown size (dim = [-1]), or
42
    // (2) a Tensor with two dimensions, the first of unknown size and the second of size 1 (dim = [-1, 1])
43
    let ort::value::ValueType::Tensor {
44
        ty: output_tensor_element_type,
2✔
45
        dimensions: output_dimensions,
2✔
46
        dimension_symbols: _,
47
    } = &session.outputs[0].output_type
2✔
48
    else {
UNCOV
49
        return Err(MachineLearningError::InvalidOutputType {
×
50
            output_type: session.outputs[0].output_type.clone(),
×
51
        });
×
52
    };
53

54
    let output_shape = try_dimensions_to_tensor_shape(output_dimensions)?;
2✔
55

56
    Ok(MlModelMetadata {
57
        file_path: path.to_owned(),
2✔
58
        input_type: try_raster_datatype_from_tensor_element_type(*input_tensor_element_type)?,
2✔
59
        input_shape,
2✔
60
        output_shape,
2✔
61
        output_type: try_raster_datatype_from_tensor_element_type(*output_tensor_element_type)?,
2✔
62
    })
63
}
2✔
64

65
fn try_dimensions_to_tensor_shape(
4✔
66
    dimensions: &[i64],
4✔
67
) -> Result<TensorShape3D, MachineLearningError> {
4✔
68
    if dimensions.len() == 1 && dimensions[0] == -1 {
4✔
69
        Ok(TensorShape3D::new_y_x_bands(1, 1, 1))
1✔
70
    } else if dimensions.len() == 2 && dimensions[0] == -1 && dimensions[1] > 0 {
3✔
71
        Ok(TensorShape3D::new_y_x_bands(1, 1, dimensions[1] as u32))
3✔
NEW
72
    } else if dimensions.len() == 4
×
NEW
73
        && dimensions[0] == -1
×
NEW
74
        && dimensions[1] > 0
×
NEW
75
        && dimensions[2] > 0
×
NEW
76
        && dimensions[3] > 0
×
77
    {
NEW
78
        Ok(TensorShape3D::new_y_x_bands(
×
NEW
79
            dimensions[1] as u32, // TODO: figure out how the axis in the dimensions are ordered!
×
NEW
80
            dimensions[2] as u32,
×
NEW
81
            dimensions[3] as u32, // In this case we could also accept attributes at first position, however we need to figure out how we would handle this...
×
NEW
82
        ))
×
83
    } else {
NEW
84
        Err(MachineLearningError::InvalidDimensions {
×
NEW
85
            dimensions: dimensions.to_vec(),
×
NEW
86
        })
×
87
    }
88
}
4✔
89

90
// can't implement `TryFrom` here because `RasterDataType` is in operators crate
91
fn try_raster_datatype_from_tensor_element_type(
4✔
92
    value: ort::tensor::TensorElementType,
4✔
93
) -> Result<RasterDataType, MachineLearningError> {
4✔
94
    match value {
4✔
95
        ort::tensor::TensorElementType::Float32 => Ok(RasterDataType::F32),
3✔
96
        ort::tensor::TensorElementType::Uint8 | ort::tensor::TensorElementType::Bool => {
97
            Ok(RasterDataType::U8)
×
98
        }
99
        ort::tensor::TensorElementType::Int8 => Ok(RasterDataType::I8),
×
100
        ort::tensor::TensorElementType::Uint16 => Ok(RasterDataType::U16),
×
101
        ort::tensor::TensorElementType::Int16 => Ok(RasterDataType::I16),
×
102
        ort::tensor::TensorElementType::Int32 => Ok(RasterDataType::I32),
×
103
        ort::tensor::TensorElementType::Int64 => Ok(RasterDataType::I64),
1✔
104
        ort::tensor::TensorElementType::Float64 => Ok(RasterDataType::F64),
×
105
        ort::tensor::TensorElementType::Uint32 => Ok(RasterDataType::U32),
×
106
        ort::tensor::TensorElementType::Uint64 => Ok(RasterDataType::U64),
×
107
        _ => Err(MachineLearningError::UnsupportedTensorElementType {
×
108
            element_type: value,
×
109
        }),
×
110
    }
111
}
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc