• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 16446125411

22 Jul 2025 01:12PM UTC coverage: 76.94%. Remained the same
16446125411

push

github

web-flow
feat: specify ml model nodata handling (#236)

* adapt to updated ml model metadata from backend

* ruff ruff

* don't use alias field names

2806 of 3647 relevant lines covered (76.94%)

0.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.22
geoengine/ml.py
1
"""
2
Util functions for machine learning
3
"""
4

5
from __future__ import annotations
1✔
6

7
import tempfile
1✔
8
from dataclasses import dataclass
1✔
9
from pathlib import Path
1✔
10

11
import geoengine_openapi_client
1✔
12
from geoengine_openapi_client.models import MlModel, MlModelMetadata, MlTensorShape3D, RasterDataType
1✔
13
from onnx import ModelProto, TensorProto, TypeProto
1✔
14
from onnx.helper import tensor_dtype_to_string
1✔
15

16
from geoengine.auth import get_session
1✔
17
from geoengine.error import InputException
1✔
18
from geoengine.resource_identifier import MlModelName, UploadId
1✔
19

20

21
@dataclass()
class MlModelConfig:
    """Everything needed to upload and register an ONNX model as a Geo Engine ml model"""

    # name under which the model is registered in the backend
    name: str
    # file name the serialized ONNX model is uploaded as
    file_name: str
    # model metadata (input/output data types and tensor shapes)
    metadata: MlModelMetadata
    # human-readable name of the registered model
    display_name: str = "My Ml Model"
    # human-readable description of the registered model
    description: str = "My Ml Model Description"
30

31

32
def register_ml_model(
    onnx_model: ModelProto, model_config: MlModelConfig, upload_timeout: int = 3600, register_timeout: int = 60
) -> MlModelName:
    """Uploads an onnx file and registers it as an ml model

    The model is first validated against ``model_config.metadata`` and the backend's shape constraints.
    It is then serialized into a temporary directory as ``model_config.file_name``, uploaded, and finally
    registered with the backend.

    Args:
        onnx_model: the ONNX model to upload
        model_config: name, file name, metadata and descriptions of the model to register
        upload_timeout: timeout of the upload request (passed to the client as ``_request_timeout``)
        register_timeout: timeout of the registration request (passed to the client as ``_request_timeout``)

    Returns:
        The name of the registered model, as returned by the backend.

    Raises:
        InputException: if the model does not match its metadata, violates the backend constraints,
            or ``model_config.file_name`` is not a plain file name.
    """

    validate_model_config(
        onnx_model,
        input_type=model_config.metadata.input_type,
        output_type=model_config.metadata.output_type,
        input_shape=model_config.metadata.input_shape,
        out_shape=model_config.metadata.output_shape,
    )
    check_backend_constraints(model_config.metadata.input_shape, model_config.metadata.output_shape)

    file_name = model_config.file_name
    # The file name is joined onto a temporary directory below. With directory components it would point
    # into a non-existent sub-directory. An absolute path would make `Path(temp_dir) / file_name` discard
    # the temporary directory and write the model somewhere else on the local file system.
    if file_name in ("", ".", "..") or Path(file_name).name != file_name:
        raise InputException(f"The model file name must be a plain file name without directories. Got `{file_name}`!")

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = Path(temp_dir) / file_name
            file_path.write_bytes(onnx_model.SerializeToString())

            uploads_api = geoengine_openapi_client.UploadsApi(api_client)
            response = uploads_api.upload_handler([str(file_path)], _request_timeout=upload_timeout)

        # the temporary file is no longer needed once the upload response has been received
        upload_id = UploadId.from_response(response)

        ml_api = geoengine_openapi_client.MLApi(api_client)

        model = MlModel(
            name=model_config.name,
            file_name=file_name,
            upload=str(upload_id),
            metadata=model_config.metadata,
            display_name=model_config.display_name,
            description=model_config.description,
        )
        res_name = ml_api.add_ml_model(model, _request_timeout=register_timeout)
        return MlModelName.from_response(res_name)
72

73

74
def model_dim_to_tensorshape(model_dims):
    """Transform an ONNX dimension into a MlTensorShape3D

    ``model_dims`` is the dimension list of an ONNX tensor shape (each entry has a ``dim_value``).
    A leading dimension whose value is 0, -1 or 1 is treated as a batch dimension and ignored.
    Supported layouts (``B`` = such a batch-like leading dimension):

    - 1D: ``[bands]``; a value of 0 or -1 keeps the default of one band
    - 2D: ``[B, bands]`` or ``[y, x]``
    - 3D: ``[B, y, x]`` or ``[y, x, bands]``
    - 4D: ``[B, y, x, bands]``

    Components not set by the layout default to 1.

    Raises:
        InputException: for any other number of dimensions or layout.
    """

    # leading dimension values that are interpreted as a batch dimension
    batch_like = (None, -1, 0, 1)
    dims = [dim.dim_value for dim in model_dims]
    rank = len(dims)

    mts = MlTensorShape3D(x=1, y=1, bands=1)
    if rank == 1 and dims[0] in (-1, 0):
        pass  # in this case, the model will produce as many outs as inputs
    elif rank == 1 and dims[0] > 0:
        mts.bands = dims[0]
    elif rank == 2:
        if dims[0] in batch_like:
            mts.bands = dims[1]
        else:
            mts.y, mts.x = dims[0], dims[1]
    elif rank == 3:
        if dims[0] in batch_like:
            mts.y, mts.x = dims[1], dims[2]
        else:
            mts.y, mts.x, mts.bands = dims[0], dims[1], dims[2]
    elif rank == 4 and dims[0] in batch_like:
        mts.y, mts.x, mts.bands = dims[1], dims[2], dims[3]
    else:
        raise InputException(
            "Unsupported tensor shape. Only 1D, 2D, 3D and 4D tensors (the latter with a leading batch dimension) "
            f"are supported. Got model dim {model_dims}"
        )
    return mts
103

104

105
def check_backend_constraints(input_shape: MlTensorShape3D, output_shape: MlTensorShape3D, ge_tile_size=(512, 512)):
    """Checks that the shapes match the constraints of the backend

    For both shapes, ``x`` must be 1 or ``ge_tile_size[0]``, ``y`` must be 1 or ``ge_tile_size[1]``,
    and there must be at least one band.

    Args:
        input_shape: shape of the model input
        output_shape: shape of the model output
        ge_tile_size: backend tile size, compared as ``(x, y)``

    Raises:
        InputException: if the input or the output shape violates the constraints.
    """

    def is_supported(shape: MlTensorShape3D) -> bool:
        return shape.x in (1, ge_tile_size[0]) and shape.y in (1, ge_tile_size[1]) and shape.bands > 0

    if not is_supported(input_shape):
        raise InputException(f"Backend currently supports single pixel and full tile shaped input! Got {input_shape}!")

    if not is_supported(output_shape):
        # report the offending output shape (previously the input shape was reported here)
        raise InputException(
            f"Backend currently supports single pixel and full tile shaped Output! Got {output_shape}!"
        )
115

116

117
# pylint: disable=too-many-branches,too-many-statements
def validate_model_config(
    onnx_model: ModelProto,
    *,
    input_type: RasterDataType,
    output_type: RasterDataType,
    input_shape: MlTensorShape3D,
    out_shape: MlTensorShape3D,
):
    """Validates the model config. Raises an exception if the model config is invalid

    The ONNX model must have exactly one input and at least one output. The input and the first output
    must be tensors with the element types given by ``RASTER_TYPE_TO_ONNX_TYPE`` for ``input_type`` and
    ``output_type``. Their shapes, converted by ``model_dim_to_tensorshape``, must equal ``input_shape``
    and ``out_shape``.

    Raises:
        InputException: if any of the checks above fails.
    """

    def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: str):
        # `tensor_type` is one variant of the `TypeProto.value` oneof. Reading it always yields a message
        # object, even when another variant (sequence, map, ...) is set, so a truthiness test does not tell
        # whether it is present. Check field presence explicitly.
        if not data_type.HasField("tensor_type"):
            raise InputException(f"Only tensor {prefix} types are supported")
        elem_type = data_type.tensor_type.elem_type
        expected_tensor_type = RASTER_TYPE_TO_ONNX_TYPE[expected_type]
        if elem_type != expected_tensor_type:
            elem_type_str = tensor_dtype_to_string(elem_type)
            expected_type_str = tensor_dtype_to_string(expected_tensor_type)
            raise InputException(
                f"Model {prefix} type `{elem_type_str}` does not match the expected type `{expected_type_str}`"
            )

    def check_shape(data_type: TypeProto, expected_shape: MlTensorShape3D, label: str):
        actual_shape = model_dim_to_tensorshape(data_type.tensor_type.shape.dim)
        if actual_shape != expected_shape:
            raise InputException(f"{label} shape {actual_shape} and metadata {expected_shape} not equal!")

    model_inputs = onnx_model.graph.input
    model_outputs = onnx_model.graph.output

    if len(model_inputs) != 1:
        raise InputException("Models with multiple inputs are not supported")
    check_data_type(model_inputs[0].type, input_type, "input")
    check_shape(model_inputs[0].type, input_shape, "Input")

    if not model_outputs:
        raise InputException("Models with no outputs are not supported")
    # only the first output is validated; further outputs are not checked here
    check_data_type(model_outputs[0].type, output_type, "output")
    check_shape(model_outputs[0].type, out_shape, "Output")
161

162

163
# Maps each raster data type to the ONNX tensor element type a model's input/output must have for it.
RASTER_TYPE_TO_ONNX_TYPE: dict[RasterDataType, int] = {
    # floating point
    RasterDataType.F32: TensorProto.FLOAT,
    RasterDataType.F64: TensorProto.DOUBLE,
    # unsigned integers
    RasterDataType.U8: TensorProto.UINT8,
    RasterDataType.U16: TensorProto.UINT16,
    RasterDataType.U32: TensorProto.UINT32,
    RasterDataType.U64: TensorProto.UINT64,
    # signed integers
    RasterDataType.I8: TensorProto.INT8,
    RasterDataType.I16: TensorProto.INT16,
    RasterDataType.I32: TensorProto.INT32,
    RasterDataType.I64: TensorProto.INT64,
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc