• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 16367912334

18 Jul 2025 10:06AM UTC coverage: 76.934% (+0.1%) from 76.806%
16367912334

push

github

web-flow
ci: use Ruff as new formatter and linter (#233)

* wip

* pycodestyle

* update dependencies

* skl2onnx

* use ruff

* apply formatter

* apply lint auto fixes

* manually apply lints

* change check

* ruff ci from branch

2805 of 3646 relevant lines covered (76.93%)

0.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.22
geoengine/ml.py
1
"""
2
Util functions for machine learning
3
"""
4

5
from __future__ import annotations
1✔
6

7
import tempfile
1✔
8
from dataclasses import dataclass
1✔
9
from pathlib import Path
1✔
10

11
import geoengine_openapi_client
1✔
12
import geoengine_openapi_client.models
1✔
13
from geoengine_openapi_client.models import MlModel, MlModelMetadata, MlTensorShape3D, RasterDataType
1✔
14
from onnx import ModelProto, TensorProto, TypeProto
1✔
15
from onnx.helper import tensor_dtype_to_string
1✔
16

17
from geoengine.auth import get_session
1✔
18
from geoengine.error import InputException
1✔
19
from geoengine.resource_identifier import MlModelName, UploadId
1✔
20

21

22
@dataclass
class MlModelConfig:
    """Bundle of everything needed to register an ONNX model as an ml model.

    Consumed by ``register_ml_model``, which checks ``metadata`` against the ONNX
    model before uploading and registering it.
    """

    # Model name, sent as ``MlModel.name`` when registering.
    name: str
    # Input/output data types and shapes, plus the file name used for the upload.
    metadata: MlModelMetadata
    # Human-readable label shown for the model.
    display_name: str = "My Ml Model"
    # Free-text description of the model.
    description: str = "My Ml Model Description"
1✔
30

31

32
def register_ml_model(
    onnx_model: ModelProto, model_config: MlModelConfig, upload_timeout: int = 3600, register_timeout: int = 60
) -> MlModelName:
    """Upload an ONNX model and register it as an ml model.

    The model is first validated against ``model_config.metadata`` (data types and
    shapes) and against the backend's shape constraints. It is then serialized,
    uploaded and registered using the current session.

    Args:
        onnx_model: The ONNX model to upload.
        model_config: Name, metadata, display name and description of the model.
        upload_timeout: Request timeout passed to the upload call.
        register_timeout: Request timeout passed to the registration call.

    Returns:
        The name under which the model was registered.

    Raises:
        InputException: If the model does not match its metadata or the backend constraints.
    """
    metadata = model_config.metadata

    validate_model_config(
        onnx_model,
        input_type=metadata.input_type,
        output_type=metadata.output_type,
        input_shape=metadata.input_shape,
        out_shape=metadata.output_shape,
    )
    check_backend_constraints(metadata.input_shape, metadata.output_shape)

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        # The upload endpoint takes file paths, so the serialized model is staged on disk
        # in a temporary directory that is removed right after the upload.
        with tempfile.TemporaryDirectory() as staging_dir:
            model_path = Path(staging_dir) / metadata.file_name
            model_path.write_bytes(onnx_model.SerializeToString())

            uploader = geoengine_openapi_client.UploadsApi(api_client)
            upload_response = uploader.upload_handler([str(model_path)], _request_timeout=upload_timeout)

        upload_id = UploadId.from_response(upload_response)

        ml_api = geoengine_openapi_client.MLApi(api_client)

        registration = MlModel(
            name=model_config.name,
            upload=str(upload_id),
            metadata=metadata,
            display_name=model_config.display_name,
            description=model_config.description,
        )
        registered_name = ml_api.add_ml_model(registration, _request_timeout=register_timeout)
        return MlModelName.from_response(registered_name)
1✔
71

72

73
def model_dim_to_tensorshape(model_dims):
    """Transform ONNX tensor dimensions into a MlTensorShape3D.

    ``model_dims`` is the ``dim`` list of an ONNX tensor shape; only each entry's
    ``dim_value`` is read. A leading dimension is treated as a batch dimension when
    its value is -1, 0 or 1 (presumably 0 also covers dynamic/symbolic sizes, which
    read as 0 through ``dim_value`` — TODO confirm).

    Supported layouts (any axis not set by the layout stays 1):

    - ``[n]``: ``bands=n`` for ``n > 0``; for ``n`` in (-1, 0) the 1x1x1 shape is kept.
    - ``[batch, bands]``, otherwise ``[y, x]``.
    - ``[batch, y, x]``, otherwise ``[y, x, bands]``.
    - ``[batch, y, x, bands]``.

    Raises:
        InputException: For any other rank or layout.
    """
    batch_like = (None, -1, 0, 1)
    values = [dim.dim_value for dim in model_dims]
    rank = len(values)

    mts = MlTensorShape3D(x=1, y=1, bands=1)
    if rank == 1 and values[0] in (-1, 0):
        pass  # in this case, the model will produce as many outs as inputs
    elif rank == 1 and values[0] > 0:
        mts.bands = values[0]
    elif rank == 2:
        if values[0] in batch_like:
            mts.bands = values[1]
        else:
            mts.y, mts.x = values
    elif rank == 3:
        if values[0] in batch_like:
            mts.y, mts.x = values[1], values[2]
        else:
            mts.y, mts.x, mts.bands = values
    elif rank == 4 and values[0] in batch_like:
        mts.y, mts.x, mts.bands = values[1], values[2], values[3]
    else:
        # This helper handles both model inputs and outputs, so the message must not say "input".
        raise InputException(
            f"Only 1D to 3D tensors and 4D tensors with a leading batch dimension are supported. "
            f"Got model dim {model_dims}"
        )
    return mts
1✔
102

103

104
def check_backend_constraints(input_shape: MlTensorShape3D, output_shape: MlTensorShape3D, ge_tile_size=(512, 512)):
    """Checks that the shapes match the constraints of the backend.

    For both shapes, ``x`` must be 1 or ``ge_tile_size[0]``, ``y`` must be 1 or
    ``ge_tile_size[1]``, and ``bands`` must be positive.

    Args:
        input_shape: Shape of the model input.
        output_shape: Shape of the model output.
        ge_tile_size: Tile size as ``(x, y)`` that the spatial axes may span.

    Raises:
        InputException: If the input or the output shape violates the constraints.
    """

    def fits_backend(shape: MlTensorShape3D) -> bool:
        # Each spatial axis is either a single pixel or a full tile; at least one band.
        return shape.x in (1, ge_tile_size[0]) and shape.y in (1, ge_tile_size[1]) and shape.bands > 0

    if not fits_backend(input_shape):
        raise InputException(f"Backend currently supports single pixel and full tile shaped input! Got {input_shape}!")

    if not fits_backend(output_shape):
        # Report the offending output shape (the input shape was printed here before).
        raise InputException(
            f"Backend currently supports single pixel and full tile shaped Output! Got {output_shape}!"
        )
×
114

115

116
# pylint: disable=too-many-branches,too-many-statements
def validate_model_config(
    onnx_model: ModelProto,
    *,
    input_type: RasterDataType,
    output_type: RasterDataType,
    input_shape: MlTensorShape3D,
    out_shape: MlTensorShape3D,
):
    """Validates the model config. Raises an exception if the model config is invalid.

    The ONNX graph must have exactly one input and at least one output. The input
    and the first output must be tensors whose element types correspond to
    ``input_type``/``output_type`` (via ``RASTER_TYPE_TO_ONNX_TYPE``). Their
    dimensions, converted with ``model_dim_to_tensorshape``, must equal
    ``input_shape``/``out_shape``. Further outputs are not checked.

    Args:
        onnx_model: The ONNX model to check.
        input_type: Expected raster data type of the model input.
        output_type: Expected raster data type of the first model output.
        input_shape: Expected input shape from the model metadata.
        out_shape: Expected output shape from the model metadata.

    Raises:
        InputException: If any of the checks above fails.
    """

    def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: str):
        # ``tensor_type`` is a member of the ``TypeProto.value`` oneof. Testing the sub-message
        # for truthiness does not tell whether it is set, so ask protobuf explicitly.
        if not data_type.HasField("tensor_type"):
            raise InputException(f"Only tensor {prefix} types are supported")
        elem_type = data_type.tensor_type.elem_type
        expected_tensor_type = RASTER_TYPE_TO_ONNX_TYPE[expected_type]
        if elem_type != expected_tensor_type:
            elem_type_str = tensor_dtype_to_string(elem_type)
            expected_type_str = tensor_dtype_to_string(expected_tensor_type)
            raise InputException(
                f"Model {prefix} type `{elem_type_str}` does not match the expected type `{expected_type_str}`"
            )

    model_inputs = onnx_model.graph.input
    model_outputs = onnx_model.graph.output

    if len(model_inputs) == 0:
        raise InputException("Models with no inputs are not supported")
    if len(model_inputs) != 1:
        raise InputException("Models with multiple inputs are not supported")
    check_data_type(model_inputs[0].type, input_type, "input")

    dim = model_inputs[0].type.tensor_type.shape.dim

    in_ts3d = model_dim_to_tensorshape(dim)
    if in_ts3d != input_shape:
        raise InputException(f"Input shape {in_ts3d} and metadata {input_shape} not equal!")

    if len(model_outputs) < 1:
        raise InputException("Models with no outputs are not supported")
    check_data_type(model_outputs[0].type, output_type, "output")

    dim = model_outputs[0].type.tensor_type.shape.dim
    out_ts3d = model_dim_to_tensorshape(dim)
    if out_ts3d != out_shape:
        raise InputException(f"Output shape {out_ts3d} and metadata {out_shape} not equal!")
×
160

161

162
# Raster data type -> ONNX ``TensorProto`` element type, used when checking
# model input/output types against the model metadata.
RASTER_TYPE_TO_ONNX_TYPE: dict[RasterDataType, int] = {
    # floating point
    RasterDataType.F32: TensorProto.FLOAT,
    RasterDataType.F64: TensorProto.DOUBLE,
    # unsigned integers
    RasterDataType.U8: TensorProto.UINT8,
    RasterDataType.U16: TensorProto.UINT16,
    RasterDataType.U32: TensorProto.UINT32,
    RasterDataType.U64: TensorProto.UINT64,
    # signed integers
    RasterDataType.I8: TensorProto.INT8,
    RasterDataType.I16: TensorProto.INT16,
    RasterDataType.I32: TensorProto.INT32,
    RasterDataType.I64: TensorProto.INT64,
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc