• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 16367912334

18 Jul 2025 10:06AM UTC coverage: 76.934% (+0.1%) from 76.806%
16367912334

push

github

web-flow
ci: use Ruff as new formatter and linter (#233)

* wip

* pycodestyle

* update dependencies

* skl2onnx

* use ruff

* apply formatter

* apply lint auto fixes

* manually apply lints

* change check

* ruff ci from branch

2805 of 3646 relevant lines covered (76.93%)

0.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.07
geoengine/workflow.py
1
"""
2
A workflow representation and methods on workflows
3
"""
4
# pylint: disable=too-many-lines
5
# TODO: split into multiple files
6

7
from __future__ import annotations
1✔
8

9
import asyncio
1✔
10
import json
1✔
11
from collections import defaultdict
1✔
12
from collections.abc import AsyncIterator
1✔
13
from io import BytesIO
1✔
14
from logging import debug
1✔
15
from os import PathLike
1✔
16
from typing import Any, TypedDict, cast
1✔
17
from uuid import UUID
1✔
18

19
import geoengine_openapi_client
1✔
20
import geopandas as gpd
1✔
21
import numpy as np
1✔
22
import pandas as pd
1✔
23
import pyarrow as pa
1✔
24
import rasterio.io
1✔
25
import requests as req
1✔
26
import rioxarray
1✔
27
import websockets
1✔
28
import xarray as xr
1✔
29
from owslib.util import Authentication, ResponseWrapper
1✔
30
from owslib.wcs import WebCoverageService
1✔
31
from PIL import Image
1✔
32
from vega import VegaLite
1✔
33

34
from geoengine import api, backports
1✔
35
from geoengine.auth import get_session
1✔
36
from geoengine.error import (
1✔
37
    GeoEngineException,
38
    InputException,
39
    MethodNotCalledOnPlotException,
40
    MethodNotCalledOnRasterException,
41
    MethodNotCalledOnVectorException,
42
    OGCXMLError,
43
)
44
from geoengine.raster import RasterTile2D
1✔
45
from geoengine.tasks import Task, TaskId
1✔
46
from geoengine.types import (
1✔
47
    ClassificationMeasurement,
48
    ProvenanceEntry,
49
    QueryRectangle,
50
    RasterColorizer,
51
    ResultDescriptor,
52
    VectorResultDescriptor,
53
)
54
from geoengine.workflow_builder.operators import Operator as WorkflowBuilderOperator
1✔
55

56
# TODO: Define as recursive type when supported in mypy: https://github.com/python/mypy/issues/731
57
JsonType = dict[str, Any] | list[Any] | int | str | float | bool | type[None]
1✔
58

59

60
class Axis(TypedDict):
1✔
61
    title: str
1✔
62

63

64
class Bin(TypedDict):
1✔
65
    binned: bool
1✔
66
    step: float
1✔
67

68

69
class Field(TypedDict):
1✔
70
    field: str
1✔
71

72

73
class DatasetIds(TypedDict):
1✔
74
    upload: UUID
1✔
75
    dataset: UUID
1✔
76

77

78
class Values(TypedDict):
1✔
79
    binStart: float
1✔
80
    binEnd: float
1✔
81
    Frequency: int
1✔
82

83

84
class X(TypedDict):
1✔
85
    field: Field
1✔
86
    bin: Bin
1✔
87
    axis: Axis
1✔
88

89

90
class X2(TypedDict):
1✔
91
    field: Field
1✔
92

93

94
class Y(TypedDict):
1✔
95
    field: Field
1✔
96
    type: str
1✔
97

98

99
class Encoding(TypedDict):
1✔
100
    x: X
1✔
101
    x2: X2
1✔
102
    y: Y
1✔
103

104

105
VegaSpec = TypedDict("VegaSpec", {"$schema": str, "data": list[Values], "mark": str, "encoding": Encoding})
1✔
106

107

108
class WorkflowId:
    """
    A typed wrapper around the UUID of a registered workflow
    """

    __workflow_id: UUID

    def __init__(self, workflow_id: UUID) -> None:
        self.__workflow_id = workflow_id

    @classmethod
    def from_response(cls, response: geoengine_openapi_client.IdResponse) -> WorkflowId:
        """
        Create a `WorkflowId` from the id field of an http response
        """
        return cls(UUID(response.id))

    def __str__(self) -> str:
        return str(self.__workflow_id)

    def __repr__(self) -> str:
        return str(self)
130

131

132
class RasterStreamProcessing:
    """
    Helper routines for decoding and combining raster stream tiles
    """

    @classmethod
    def read_arrow_ipc(cls, arrow_ipc: bytes) -> pa.RecordBatch:
        """Decode the single record batch contained in an Arrow IPC file"""

        ipc_reader = pa.ipc.open_file(arrow_ipc)
        # the backend guarantees exactly one record batch per message
        return ipc_reader.get_record_batch(0)

    @classmethod
    def process_bytes(cls, tile_bytes: bytes | None) -> RasterTile2D | None:
        """Turn raw Arrow IPC bytes into a `RasterTile2D`, passing `None` through"""

        if tile_bytes is None:
            return None

        # decode the received data into a tile
        batch = cls.read_arrow_ipc(tile_bytes)
        return RasterTile2D.from_ge_record_batch(batch)

    @classmethod
    def merge_tiles(cls, tiles: list[xr.DataArray]) -> xr.DataArray | None:
        """Merge a list of tiles into a single multi-band xarray (`None` for empty input)"""

        if not tiles:
            return None

        # group the tiles by band
        per_band: dict[int, list[xr.DataArray]] = defaultdict(list)
        for t in tiles:
            # 'band' is assumed to be a single-valued coordinate on each tile
            per_band[t.band.item()].append(t)

        # combine the tiles spatially within each band
        merged_bands = []
        for band_tiles in per_band.values():
            spatial = xr.combine_by_coords(band_tiles)
            # `combine_by_coords` always returns a `DataArray` for single variable input arrays;
            # the assertion verifies this for mypy
            assert isinstance(spatial, xr.DataArray)
            merged_bands.append(spatial)

        # stack the per-band arrays into one array with a band dimension
        return xr.concat(merged_bands, dim="band")
185

186

187
class Workflow:
    """
    Holds a workflow id and allows querying data
    """

    # id of the registered workflow
    __workflow_id: WorkflowId
    # cached metadata of the workflow result, fetched once on construction
    __result_descriptor: ResultDescriptor

    def __init__(self, workflow_id: WorkflowId) -> None:
        self.__workflow_id = workflow_id
        # fetch eagerly so later calls can check the result type without a round trip
        self.__result_descriptor = self.__query_result_descriptor()
198

199
    def __str__(self) -> str:
        """Return the workflow id as a string."""
        return str(self.__workflow_id)

    def __repr__(self) -> str:
        """Return the repr of the underlying workflow id."""
        return repr(self.__workflow_id)
204

205
    def __query_result_descriptor(self, timeout: int = 60) -> ResultDescriptor:
        """
        Fetch the metadata of the workflow result from the server.

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(client)
            metadata = workflows_api.get_workflow_metadata_handler(str(self.__workflow_id), _request_timeout=timeout)

        debug(metadata)

        return ResultDescriptor.from_response(metadata)
219

220
    def get_result_descriptor(self) -> ResultDescriptor:
        """
        Return the cached metadata of the workflow result
        """
        return self.__result_descriptor
226

227
    def workflow_definition(self, timeout: int = 60) -> geoengine_openapi_client.Workflow:
        """
        Load the workflow definition for this workflow from the server.

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(client)
            definition = workflows_api.load_workflow_handler(str(self.__workflow_id), _request_timeout=timeout)

        return definition
237

238
    def get_dataframe(
        self, bbox: QueryRectangle, timeout: int = 3600, resolve_classifications: bool = False
    ) -> gpd.GeoDataFrame:
        """
        Query a workflow via WFS and return the result as a GeoPandas `GeoDataFrame`.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        resolve_classifications : Replace classification ids with their class names

        Raises
        ------
        MethodNotCalledOnVectorException : if the workflow result is not vector data
        """

        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wfs_api = geoengine_openapi_client.OGCWFSApi(api_client)
            response = wfs_api.wfs_feature_handler(
                workflow=str(self.__workflow_id),
                service=geoengine_openapi_client.WfsService(geoengine_openapi_client.WfsService.WFS),
                request=geoengine_openapi_client.GetFeatureRequest(
                    geoengine_openapi_client.GetFeatureRequest.GETFEATURE
                ),
                type_names=str(self.__workflow_id),
                bbox=bbox.bbox_str,
                version=geoengine_openapi_client.WfsVersion(geoengine_openapi_client.WfsVersion.ENUM_2_DOT_0_DOT_0),
                time=bbox.time_str,
                srs_name=bbox.srs,
                query_resolution=str(bbox.spatial_resolution),
                _request_timeout=timeout,
            )

        def geo_json_with_time_to_geopandas(geo_json):
            """
            GeoJSON has no standard for time, so the `when` field is parsed
            separately and attached to the data frame as columns `start` and `end`.
            """

            frame = gpd.GeoDataFrame.from_features(geo_json)
            frame = frame.set_crs(bbox.srs, allow_override=True)

            starts = [feature["when"]["start"] for feature in geo_json["features"]]
            ends = [feature["when"]["end"] for feature in geo_json["features"]]

            # TODO: find a good way to infer BoT/EoT

            frame["start"] = gpd.pd.to_datetime(starts, errors="coerce")
            frame["end"] = gpd.pd.to_datetime(ends, errors="coerce")

            return frame

        def transform_classifications(data: gpd.GeoDataFrame):
            """Map classification ids to their class names for all classification columns."""
            result_descriptor: VectorResultDescriptor = self.__result_descriptor  # type: ignore
            for column, info in result_descriptor.columns.items():
                if isinstance(info.measurement, ClassificationMeasurement):
                    measurement: ClassificationMeasurement = info.measurement
                    classes = measurement.classes
                    data[column] = data[column].apply(lambda x, classes=classes: classes[x])  # pylint: disable=cell-var-from-loop

            return data

        result = geo_json_with_time_to_geopandas(response.to_dict())

        if resolve_classifications:
            result = transform_classifications(result)

        return result
303

304
    def wms_get_map_as_image(self, bbox: QueryRectangle, raster_colorizer: RasterColorizer) -> Image.Image:
        """
        Request the workflow result via WMS and return it as a PIL image.

        Parameters
        ----------
        bbox : A bounding box for the query
        raster_colorizer : The colorizer used to style the raster

        Raises
        ------
        MethodNotCalledOnRasterException : if the workflow result is not a raster
        OGCXMLError : if the server answers with an OGC XML error document
        """

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # image dimensions in pixels derive from the bounds and the spatial resolution
        width = int((bbox.spatial_bounds.xmax - bbox.spatial_bounds.xmin) / bbox.spatial_resolution.x_resolution)
        height = int((bbox.spatial_bounds.ymax - bbox.spatial_bounds.ymin) / bbox.spatial_resolution.y_resolution)

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wms_api = geoengine_openapi_client.OGCWMSApi(api_client)
            response = wms_api.wms_map_handler(
                workflow=str(self),
                version=geoengine_openapi_client.WmsVersion(geoengine_openapi_client.WmsVersion.ENUM_1_DOT_3_DOT_0),
                service=geoengine_openapi_client.WmsService(geoengine_openapi_client.WmsService.WMS),
                request=geoengine_openapi_client.GetMapRequest(geoengine_openapi_client.GetMapRequest.GETMAP),
                width=width,
                height=height,
                bbox=bbox.bbox_ogc_str,
                format=geoengine_openapi_client.GetMapFormat(geoengine_openapi_client.GetMapFormat.IMAGE_SLASH_PNG),
                layers=str(self),
                styles="custom:" + raster_colorizer.to_api_dict().to_json(),
                crs=bbox.srs,
                time=bbox.time_str,
            )

        # the server may return an XML error document instead of image bytes
        if OGCXMLError.is_ogc_error(response):
            raise OGCXMLError(response)

        return Image.open(BytesIO(response))
335

336
    def plot_json(self, bbox: QueryRectangle, timeout: int = 3600) -> geoengine_openapi_client.WrappedPlotOutput:
        """
        Query a workflow and return the plot chart result as `WrappedPlotOutput`.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds

        Raises
        ------
        MethodNotCalledOnPlotException : if the workflow result is not a plot
        """

        if not self.__result_descriptor.is_plot_result():
            raise MethodNotCalledOnPlotException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            plots_api = geoengine_openapi_client.PlotsApi(api_client)
            return plots_api.get_plot_handler(
                bbox.bbox_str,
                bbox.time_str,
                str(bbox.spatial_resolution),
                str(self.__workflow_id),
                bbox.srs,
                _request_timeout=timeout,
            )
356

357
    def plot_chart(self, bbox: QueryRectangle, timeout: int = 3600) -> VegaLite:
        """
        Query a workflow and return the plot chart result as a Vega-Lite chart.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        """

        wrapped_plot = self.plot_json(bbox, timeout)
        spec: VegaSpec = json.loads(wrapped_plot.data["vegaString"])

        return VegaLite(spec)
366

367
    def __request_wcs(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
    ) -> ResponseWrapper:
        """
        Query a workflow via WCS and return the coverage response.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.

        Raises
        ------
        MethodNotCalledOnRasterException : if the workflow result is not a raster
        """

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # TODO: properly build CRS string for bbox
        crs = f"urn:ogc:def:crs:{bbox.srs.replace(':', '::')}"

        wcs_url = f"{session.server_url}/wcs/{self.__workflow_id}"
        wcs = WebCoverageService(
            wcs_url,
            version="1.1.1",
            auth=Authentication(auth_delegate=session.requests_bearer_auth()),
        )

        [resx, resy] = bbox.resolution_ogc

        extra_args = {}

        if force_no_data_value is not None:
            extra_args["nodatavalue"] = str(float(force_no_data_value))

        return wcs.getCoverage(
            identifier=f"{self.__workflow_id}",
            bbox=bbox.bbox_ogc,
            time=[bbox.time_str],
            format=file_format,
            crs=crs,
            resx=resx,
            resy=resy,
            timeout=timeout,
            **extra_args,
        )
419

420
    def __get_wcs_tiff_as_memory_file(
        self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None
    ) -> rasterio.io.MemoryFile:
        """
        Query a workflow and return the raster result as a memory-mapped GeoTiff.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        tiff_bytes = self.__request_wcs(bbox, timeout, "image/tiff", force_no_data_value).read()

        # errors are raised via `raise_on_error` in `getCoverage` / `openUrl`,
        # so the response can be wrapped directly
        return rasterio.io.MemoryFile(tiff_bytes)
441

442
    def get_array(self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None) -> np.ndarray:
        """
        Query a workflow and return the raster result as a numpy array.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        with (
            self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value) as memfile,
            memfile.open() as dataset,
        ):
            # read the first (and only) band of the GeoTiff
            return dataset.read(1)
461

462
    def get_xarray(self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None) -> xr.DataArray:
        """
        Query a workflow and return the raster result as a georeferenced xarray.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        with (
            self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value) as memfile,
            memfile.open() as dataset,
        ):
            data_array = rioxarray.open_rasterio(dataset)

            # helping mypy with inference
            assert isinstance(data_array, xr.DataArray)

            # persist the georeferencing information as attributes on the array
            rio: xr.DataArray = data_array.rio
            rio.update_attrs(
                {
                    "crs": rio.crs,
                    "res": rio.resolution(),
                    "transform": rio.transform(),
                },
                inplace=True,
            )

            # TODO: add time information to dataset
            return data_array.load()
495

496
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def download_raster(
        self,
        bbox: QueryRectangle,
        file_path: str,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
    ) -> None:
        """
        Query a workflow and save the raster result to a file on disk.

        Parameters
        ----------
        bbox : A bounding box for the query
        file_path : The path of the file to save the raster to
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        coverage = self.__request_wcs(bbox, timeout, file_format, force_no_data_value)

        with open(file_path, "wb") as file:
            file.write(coverage.read())
522

523
    def get_provenance(self, timeout: int = 60) -> list[ProvenanceEntry]:
        """
        Query the provenance information of the workflow.

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            entries = workflows_api.get_workflow_provenance_handler(str(self.__workflow_id), _request_timeout=timeout)

        return [ProvenanceEntry.from_response(entry) for entry in entries]
535

536
    def metadata_zip(self, path: PathLike | BytesIO, timeout: int = 60) -> None:
        """
        Query workflow metadata and citations and store them as a zip file at `path`.

        Parameters
        ----------
        path : A file path or an in-memory buffer to write the zip file to
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            zip_bytes = workflows_api.get_workflow_all_metadata_zip_handler(
                str(self.__workflow_id), _request_timeout=timeout
            )

        if isinstance(path, BytesIO):
            path.write(zip_bytes)
        else:
            with open(path, "wb") as file:
                file.write(zip_bytes)
554

555
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def save_as_dataset(
        self,
        query_rectangle: geoengine_openapi_client.RasterQueryRectangle,
        name: str | None,
        display_name: str,
        description: str = "",
        timeout: int = 3600,
    ) -> Task:
        """
        Initialize a task that stores the workflow result as a new dataset.

        Parameters
        ----------
        query_rectangle : The spatio-temporal region to materialize
        name : Optional internal name of the new dataset
        display_name : Human-readable name of the new dataset
        description : Description of the new dataset
        timeout : HTTP request timeout in seconds

        Raises
        ------
        MethodNotCalledOnRasterException : if the workflow result is not a raster
        """

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            response = workflows_api.dataset_from_workflow_handler(
                str(self.__workflow_id),
                geoengine_openapi_client.RasterDatasetFromWorkflow(
                    name=name, display_name=display_name, description=description, query=query_rectangle
                ),
                _request_timeout=timeout,
            )

        return Task(TaskId.from_response(response))
583

584
    async def raster_stream(
        self,
        query_rectangle: QueryRectangle,
        open_timeout: int = 60,
        bands: list[int] | None = None,  # TODO: move into query rectangle?
    ) -> AsyncIterator[RasterTile2D]:
        """
        Stream the workflow result as a series of `RasterTile2D`s
        (transformable to numpy and xarray).

        Parameters
        ----------
        query_rectangle : The spatio-temporal region to query
        open_timeout : Timeout in seconds for opening the websocket connection
        bands : The raster bands to query; defaults to the first band

        Raises
        ------
        InputException : if no bands are given or the websocket url is invalid
        MethodNotCalledOnRasterException : if the workflow result is not a raster
        GeoEngineException : if the server reports an error mid-stream
        """

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException("At least one band must be specified")

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # build the websocket url including all query parameters
        url = (
            req.Request(
                "GET",
                url=f"{session.server_url}/workflow/{self.__workflow_id}/rasterStream",
                params={
                    "resultType": "arrow",
                    "spatialBounds": query_rectangle.bbox_str,
                    "timeInterval": query_rectangle.time_str,
                    "spatialResolution": str(query_rectangle.spatial_resolution),
                    "attributes": ",".join(map(str, bands)),
                },
            )
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            extra_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,
        ) as websocket:
            tile_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    # already request the next tile to overlap I/O with processing
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # fetch the next message while decoding the previous one in a thread
                (tile_bytes, tile) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, tile_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(RasterStreamProcessing.process_bytes, tile_bytes),
                )

                if tile is not None:
                    yield tile

            # decode and emit the final buffered tile
            tile = RasterStreamProcessing.process_bytes(tile_bytes)

            if tile is not None:
                yield tile
667

668
    async def raster_stream_into_xarray(
        self,
        query_rectangle: QueryRectangle,
        clip_to_query_rectangle: bool = False,
        open_timeout: int = 60,
        bands: list[int] | None = None,  # TODO: move into query rectangle?
    ) -> xr.DataArray:
        """
        Stream the workflow result into memory and combine it into a single xarray.

        NOTE: You can run out of memory if the query rectangle is too large.

        Parameters
        ----------
        query_rectangle : The spatio-temporal region to query
        clip_to_query_rectangle : Clip all tiles to the query rectangle's spatial bounds
        open_timeout : Timeout in seconds for opening the websocket connection
        bands : The raster bands to query; defaults to the first band

        Raises
        ------
        InputException : if no bands are given
        """

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException("At least one band must be specified")

        tile_stream = self.raster_stream(query_rectangle, open_timeout=open_timeout, bands=bands)

        timestep_xarrays: list[xr.DataArray] = []

        spatial_clip_bounds = query_rectangle.spatial_bounds if clip_to_query_rectangle else None

        async def read_tiles(
            remainder_tile: RasterTile2D | None,
        ) -> tuple[list[xr.DataArray], RasterTile2D | None]:
            # collect all tiles belonging to one time step; the first tile of the
            # following time step is handed back as the remainder
            last_timestep: np.datetime64 | None = None
            tiles = []

            if remainder_tile is not None:
                last_timestep = remainder_tile.time_start_ms
                tiles.append(remainder_tile.to_xarray(clip_with_bounds=spatial_clip_bounds))

            async for tile in tile_stream:
                timestep: np.datetime64 = tile.time_start_ms
                if last_timestep is None:
                    last_timestep = timestep
                elif last_timestep != timestep:
                    # a new time step begins, so return the tile as remainder
                    return tiles, tile

                tiles.append(tile.to_xarray(clip_with_bounds=spatial_clip_bounds))

            # this seems to be the last time step, so just return tiles
            return tiles, None

        (tiles, remainder_tile) = await read_tiles(None)

        while len(tiles):
            # read the next time step while merging the current one in a thread
            ((new_tiles, new_remainder_tile), new_timestep_xarray) = await asyncio.gather(
                read_tiles(remainder_tile),
                backports.to_thread(RasterStreamProcessing.merge_tiles, tiles),
                # asyncio.to_thread(merge_tiles, tiles), # TODO: use this when min Python version is 3.9
            )

            tiles = new_tiles
            remainder_tile = new_remainder_tile

            if new_timestep_xarray is not None:
                timestep_xarrays.append(new_timestep_xarray)

        # finally, concatenate all time steps along a new time dimension
        output: xr.DataArray = cast(
            xr.DataArray,
            # await asyncio.to_thread( # TODO: use this when min Python version is 3.9
            await backports.to_thread(
                xr.concat,
                # TODO: This is a typings error, since the method accepts also a `xr.DataArray` and returns one
                cast(list[xr.Dataset], timestep_xarrays),
                dim="time",
            ),
        )

        return output
744

745
    async def vector_stream(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> AsyncIterator[gpd.GeoDataFrame]:
        """
        Stream the workflow result as a series of `GeoDataFrame`s.

        Chunks are fetched over a websocket as Arrow IPC files and converted
        to `GeoDataFrame`s. Fetching the next chunk and converting the previous
        one run concurrently to overlap network and CPU work.

        Args:
            query_rectangle: spatial bounds, time interval, and resolution of the query.
            time_start_column: output column name for the start of each feature's validity.
            time_end_column: output column name for the end of each feature's validity.
            open_timeout: timeout (seconds) for opening the websocket connection.

        Yields:
            One `GeoDataFrame` per chunk received from the server.

        Raises:
            MethodNotCalledOnVectorException: if the workflow result is not a vector result.
            InputException: if the websocket url could not be constructed.
            GeoEngineException: if the server sends a textual error message over the websocket.
        """

        def read_arrow_ipc(arrow_ipc: bytes) -> pa.RecordBatch:
            # Deserialize one Arrow IPC file message into its single record batch.
            reader = pa.ipc.open_file(arrow_ipc)
            # We know from the backend that there is only one record batch
            record_batch = reader.get_record_batch(0)
            return record_batch

        def create_geo_data_frame(
            record_batch: pa.RecordBatch, time_start_column: str, time_end_column: str
        ) -> gpd.GeoDataFrame:
            # The CRS is transported in the Arrow schema metadata by the backend.
            metadata = record_batch.schema.metadata
            spatial_reference = metadata[b"spatialReference"].decode("utf-8")

            data_frame = record_batch.to_pandas()

            # Geometries arrive as WKT strings in a dedicated column.
            geometry = gpd.GeoSeries.from_wkt(data_frame[api.GEOMETRY_COLUMN_NAME])
            del data_frame[api.GEOMETRY_COLUMN_NAME]  # delete the duplicated column

            geo_data_frame = gpd.GeoDataFrame(
                data_frame,
                geometry=geometry,
                crs=spatial_reference,
            )

            # split time column (each entry is a [start, end] pair) into two columns
            geo_data_frame[[time_start_column, time_end_column]] = geo_data_frame[api.TIME_COLUMN_NAME].tolist()
            del geo_data_frame[api.TIME_COLUMN_NAME]  # delete the duplicated column

            # parse time columns from epoch milliseconds into tz-aware datetimes
            for time_column in [time_start_column, time_end_column]:
                geo_data_frame[time_column] = pd.to_datetime(
                    geo_data_frame[time_column],
                    utc=True,
                    unit="ms",
                    # TODO: solve time conversion problem from Geo Engine to Python for large (+/-) time instances
                    errors="coerce",
                )

            return geo_data_frame

        def process_bytes(batch_bytes: bytes | None) -> gpd.GeoDataFrame | None:
            # `None` marks "no previous chunk" (first loop iteration or closed stream).
            if batch_bytes is None:
                return None

            # process the received data
            record_batch = read_arrow_ipc(batch_bytes)
            tile = create_geo_data_frame(
                record_batch,
                time_start_column=time_start_column,
                time_end_column=time_end_column,
            )

            return tile

        # This streaming endpoint only works for vector results
        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        # Build the query url via a prepared (but never sent) request to get proper
        # parameter encoding; the scheme is swapped to ws(s) below.
        url = (
            req.Request(
                "GET",
                url=f"{session.server_url}/workflow/{self.__workflow_id}/vectorStream",
                params={
                    "resultType": "arrow",
                    "spatialBounds": query_rectangle.bbox_str,
                    "timeInterval": query_rectangle.time_str,
                    "spatialResolution": str(query_rectangle.spatial_resolution),
                },
            )
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        # NOTE(review): newer websockets releases renamed `extra_headers` to
        # `additional_headers` for the asyncio client — confirm against the pinned version.
        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            extra_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,  # allow arbitrary large messages, since it is capped by the server's chunk size
        ) as websocket:
            # Holds the not-yet-processed bytes of the previously received chunk.
            batch_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # Pipeline: fetch the next chunk while converting the previous one
                # in a worker thread so network and CPU work overlap.
                (batch_bytes, batch) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, batch_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(process_bytes, batch_bytes),
                )

                if batch is not None:
                    yield batch

            # process the last tile (its bytes were read but not yet converted)
            batch = process_bytes(batch_bytes)

            if batch is not None:
                yield batch
    async def vector_stream_into_geopandas(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> gpd.GeoDataFrame | None:
        """
        Stream the workflow result into memory and output a single geo data frame.

        Reading the next chunk from the stream and merging the previous chunk into
        the accumulated data frame happen concurrently (merge runs in a worker
        thread), so download and concatenation overlap.

        Args:
            query_rectangle: spatial bounds, time interval, and resolution of the query.
            time_start_column: output column name for the start of each feature's validity.
            time_end_column: output column name for the end of each feature's validity.
            open_timeout: timeout (seconds) for opening the websocket connection.

        Returns:
            The concatenated `GeoDataFrame`, or `None` if the stream yielded no chunks.

        NOTE: You can run out of memory if the query rectangle is too large.
        """

        chunk_stream = self.vector_stream(
            query_rectangle,
            time_start_column=time_start_column,
            time_end_column=time_end_column,
            open_timeout=open_timeout,
        )

        # Accumulated result so far; stays `None` until the first chunk is merged.
        data_frame: gpd.GeoDataFrame | None = None
        chunk: gpd.GeoDataFrame | None = None

        async def read_dataframe() -> gpd.GeoDataFrame | None:
            # Translate end-of-stream into `None` so it can flow through `gather`.
            try:
                return await chunk_stream.__anext__()
            except StopAsyncIteration:
                return None

        def merge_dataframes(df_a: gpd.GeoDataFrame | None, df_b: gpd.GeoDataFrame | None) -> gpd.GeoDataFrame | None:
            # `None` acts as the neutral element of the merge.
            if df_a is None:
                return df_b

            if df_b is None:
                return df_a

            return pd.concat([df_a, df_b], ignore_index=True)

        while True:
            # Fetch the next chunk while merging the previous one off-thread.
            (chunk, data_frame) = await asyncio.gather(
                read_dataframe(),
                backports.to_thread(merge_dataframes, data_frame, chunk),
                # TODO: use this when min Python version is 3.9
                # asyncio.to_thread(merge_dataframes, data_frame, chunk),
            )

            # we can stop when the chunk stream is exhausted
            if chunk is None:
                break

        return data_frame
    def __replace_http_with_ws(self, url: str) -> str:
        """
        Rewrite the scheme of `url` from HTTP(S) to the matching websocket scheme.

        The websockets library requires urls starting with `ws://`; secure
        (`https`) urls map to `wss://`, plain ones to `ws://`.
        """

        scheme, remainder = url.split("://", maxsplit=1)

        if "s" in scheme.lower():
            new_scheme = "wss://"
        else:
            new_scheme = "ws://"

        return f"{new_scheme}{remainder}"
def register_workflow(workflow: dict[str, Any] | WorkflowBuilderOperator, timeout: int = 60) -> Workflow:
    """
    Register a workflow in Geo Engine and receive a `WorkflowId`.

    Accepts either a raw workflow definition (as a dict) or a builder operator,
    which is first converted to its dict representation.
    """

    # Normalize builder operators to their plain dict form.
    workflow_dict = workflow.to_workflow_dict() if isinstance(workflow, WorkflowBuilderOperator) else workflow

    parsed_workflow = geoengine_openapi_client.Workflow.from_dict(workflow_dict)

    if parsed_workflow is None:
        raise InputException("Invalid workflow definition")

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as client:
        response = geoengine_openapi_client.WorkflowsApi(client).register_workflow_handler(
            parsed_workflow, _request_timeout=timeout
        )

    return Workflow(WorkflowId.from_response(response))
def workflow_by_id(workflow_id: UUID) -> Workflow:
    """
    Wrap an existing workflow id into a `Workflow` object.

    Note: the id is not validated against the backend.
    """

    # TODO: check that workflow exists

    wrapped_id = WorkflowId(workflow_id)
    return Workflow(wrapped_id)
def get_quota(user_id: UUID | None = None, timeout: int = 60) -> geoengine_openapi_client.Quota:
    """
    Fetch a quota from the Geo Engine.

    Without a `user_id`, the current session user's quota is returned. Passing a
    `user_id` queries another user's quota, which only admins are allowed to do.
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as client:
        quota_api = geoengine_openapi_client.UserApi(client)

        if user_id is not None:
            return quota_api.get_user_quota_handler(str(user_id), _request_timeout=timeout)

        return quota_api.quota_handler(_request_timeout=timeout)
def update_quota(user_id: UUID, new_available_quota: int, timeout: int = 60) -> None:
    """
    Set the available quota of a user. Only admins can perform this operation.
    """

    session = get_session()

    # Build the request payload up front for readability.
    quota_update = geoengine_openapi_client.UpdateQuota(available=new_available_quota)

    with geoengine_openapi_client.ApiClient(session.configuration) as client:
        geoengine_openapi_client.UserApi(client).update_user_quota_handler(
            str(user_id),
            quota_update,
            _request_timeout=timeout,
        )
def data_usage(offset: int = 0, limit: int = 10) -> pd.DataFrame:
    """
    Get data usage records from the Geo Engine as a `DataFrame`.

    The return annotation previously claimed `list[geoengine_openapi_client.DataUsage]`,
    but the function builds and returns a `pd.DataFrame` (consistent with
    `data_usage_summary`); the annotation is corrected here.

    Args:
        offset: number of records to skip (for paging).
        limit: maximum number of records to return.

    Returns:
        A `DataFrame` with one row per usage record; a `timestamp` column, if
        present, is parsed into timezone-aware (UTC) datetimes.
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)
        response = user_api.data_usage_handler(offset=offset, limit=limit)

        # create dataframe from response
        usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
        df = pd.DataFrame(usage_dicts)
        if "timestamp" in df.columns:
            df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)

    return df
def data_usage_summary(
    granularity: geoengine_openapi_client.UsageSummaryGranularity,
    dataset: str | None = None,
    offset: int = 0,
    limit: int = 10,
) -> pd.DataFrame:
    """
    Get an aggregated data usage summary as a `DataFrame`.

    Args:
        granularity: temporal granularity of the aggregation.
        dataset: restrict the summary to a single dataset, or `None` for all datasets.
        offset: number of records to skip (for paging).
        limit: maximum number of records to return.
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as client:
        summaries = geoengine_openapi_client.UserApi(client).data_usage_summary_handler(
            dataset=dataset, granularity=granularity, offset=offset, limit=limit
        )

        # turn the response models into a tabular representation
        records = [summary.model_dump(by_alias=True) for summary in summaries]
        usage_frame = pd.DataFrame(records)
        if "timestamp" in usage_frame.columns:
            usage_frame["timestamp"] = pd.to_datetime(usage_frame["timestamp"], utc=True)

    return usage_frame
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc