• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 16373368314

18 Jul 2025 02:02PM UTC coverage: 76.94% (+0.006%) from 76.934%
16373368314

push

github

web-flow
fix: adapt to new websockets API (#234)

* wip

* pycodestyle

* update dependencies

* skl2onnx

* use ruff

* apply formatter

* apply lint auto fixes

* manually apply lints

* change check

* ruff ci from branch

* fix websockets API

* revert pyproject

* fix doc comment

* set ci branch to main again

2806 of 3647 relevant lines covered (76.94%)

0.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.1
geoengine/workflow.py
1
"""
2
A workflow representation and methods on workflows
3
"""
4
# pylint: disable=too-many-lines
5
# TODO: split into multiple files
6

7
from __future__ import annotations
1✔
8

9
import asyncio
1✔
10
import json
1✔
11
from collections import defaultdict
1✔
12
from collections.abc import AsyncIterator
1✔
13
from io import BytesIO
1✔
14
from logging import debug
1✔
15
from os import PathLike
1✔
16
from typing import Any, TypedDict, cast
1✔
17
from uuid import UUID
1✔
18

19
import geoengine_openapi_client
1✔
20
import geopandas as gpd
1✔
21
import numpy as np
1✔
22
import pandas as pd
1✔
23
import pyarrow as pa
1✔
24
import rasterio.io
1✔
25
import requests as req
1✔
26
import rioxarray
1✔
27
import websockets
1✔
28
import websockets.asyncio.client
1✔
29
import xarray as xr
1✔
30
from owslib.util import Authentication, ResponseWrapper
1✔
31
from owslib.wcs import WebCoverageService
1✔
32
from PIL import Image
1✔
33
from vega import VegaLite
1✔
34

35
from geoengine import api, backports
1✔
36
from geoengine.auth import get_session
1✔
37
from geoengine.error import (
1✔
38
    GeoEngineException,
39
    InputException,
40
    MethodNotCalledOnPlotException,
41
    MethodNotCalledOnRasterException,
42
    MethodNotCalledOnVectorException,
43
    OGCXMLError,
44
)
45
from geoengine.raster import RasterTile2D
1✔
46
from geoengine.tasks import Task, TaskId
1✔
47
from geoengine.types import (
1✔
48
    ClassificationMeasurement,
49
    ProvenanceEntry,
50
    QueryRectangle,
51
    RasterColorizer,
52
    ResultDescriptor,
53
    VectorResultDescriptor,
54
)
55
from geoengine.workflow_builder.operators import Operator as WorkflowBuilderOperator
1✔
56

57
# TODO: Define as recursive type when supported in mypy: https://github.com/python/mypy/issues/731
# Type alias for a (non-recursive) JSON value: object, array, or scalar.
JsonType = dict[str, Any] | list[Any] | int | str | float | bool | type[None]
59

60

61
class Axis(TypedDict):
    """Axis portion of a Vega-Lite encoding (see `VegaSpec`)."""

    title: str
63

64

65
class Bin(TypedDict):
    """Binning settings of a Vega-Lite encoding (see `VegaSpec`)."""

    binned: bool
    step: float
68

69

70
class Field(TypedDict):
    """Field reference inside a Vega-Lite encoding (see `VegaSpec`)."""

    field: str
72

73

74
class DatasetIds(TypedDict):
    """Pair of ids referring to an upload and the dataset created from it."""

    upload: UUID
    dataset: UUID
77

78

79
class Values(TypedDict):
    """One data record of a Vega-Lite plot: bin boundaries plus a frequency count."""

    binStart: float
    binEnd: float
    Frequency: int
83

84

85
class X(TypedDict):
    """x-channel of a Vega-Lite encoding (see `VegaSpec`)."""

    field: Field
    bin: Bin
    axis: Axis
89

90

91
class X2(TypedDict):
    """x2-channel (secondary x position) of a Vega-Lite encoding (see `VegaSpec`)."""

    field: Field
93

94

95
class Y(TypedDict):
    """y-channel of a Vega-Lite encoding (see `VegaSpec`)."""

    field: Field
    type: str
98

99

100
class Encoding(TypedDict):
    """Channel encodings of a Vega-Lite chart (see `VegaSpec`)."""

    x: X
    x2: X2
    y: Y
104

105

106
# Functional TypedDict syntax is required here because "$schema" is not a valid Python identifier.
VegaSpec = TypedDict("VegaSpec", {"$schema": str, "data": list[Values], "mark": str, "encoding": Encoding})
107

108

109
class WorkflowId:
    """
    A thin wrapper around the UUID that identifies a workflow
    """

    __workflow_id: UUID

    def __init__(self, workflow_id: UUID) -> None:
        self.__workflow_id = workflow_id

    @classmethod
    def from_response(cls, response: geoengine_openapi_client.IdResponse) -> WorkflowId:
        """
        Create a `WorkflowId` from an http response
        """
        return cls(UUID(response.id))

    def __str__(self) -> str:
        return f"{self.__workflow_id}"

    def __repr__(self) -> str:
        return str(self)
131

132

133
class RasterStreamProcessing:
    """
    Helper class to process raster stream data
    """

    @classmethod
    def read_arrow_ipc(cls, arrow_ipc: bytes) -> pa.RecordBatch:
        """Read an Arrow IPC file from a byte array"""

        # The backend sends exactly one record batch per IPC file.
        return pa.ipc.open_file(arrow_ipc).get_record_batch(0)

    @classmethod
    def process_bytes(cls, tile_bytes: bytes | None) -> RasterTile2D | None:
        """Process a tile from a byte array"""

        if tile_bytes is None:
            return None

        # decode the received Arrow data into a tile
        batch = RasterStreamProcessing.read_arrow_ipc(tile_bytes)
        return RasterTile2D.from_ge_record_batch(batch)

    @classmethod
    def merge_tiles(cls, tiles: list[xr.DataArray]) -> xr.DataArray | None:
        """Merge a list of tiles into a single xarray"""

        if not tiles:
            return None

        # group the tiles by their band coordinate (a single value per tile)
        per_band: dict[int, list[xr.DataArray]] = defaultdict(list)
        for candidate in tiles:
            per_band[candidate.band.item()].append(candidate)

        # spatially combine the tiles of each band
        merged_per_band = []
        for band_tiles in per_band.values():
            spatial = xr.combine_by_coords(band_tiles)
            # `combine_by_coords` always returns a `DataArray` for single variable
            # input arrays; the assertion narrows the type for mypy
            assert isinstance(spatial, xr.DataArray)
            merged_per_band.append(spatial)

        # stack the per-band arrays along the band dimension
        return xr.concat(merged_per_band, dim="band")
186

187

188
class Workflow:
    """
    Holds a workflow id and allows querying data
    """

    # id of the workflow registered at the Geo Engine instance
    __workflow_id: WorkflowId
    # result metadata, fetched once from the server in `__init__` and cached
    __result_descriptor: ResultDescriptor
195

196
    def __init__(self, workflow_id: WorkflowId) -> None:
        """Store the workflow id and eagerly fetch the result metadata from the server."""
        self.__workflow_id = workflow_id
        # one network round-trip at construction time; cached for `get_result_descriptor`
        self.__result_descriptor = self.__query_result_descriptor()
199

200
    def __str__(self) -> str:
1✔
201
        return str(self.__workflow_id)
1✔
202

203
    def __repr__(self) -> str:
1✔
204
        return repr(self.__workflow_id)
1✔
205

206
    def __query_result_descriptor(self, timeout: int = 60) -> ResultDescriptor:
        """
        Query the metadata of the workflow result

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            response = workflows_api.get_workflow_metadata_handler(str(self.__workflow_id), _request_timeout=timeout)

        # log the raw response for troubleshooting
        debug(response)

        return ResultDescriptor.from_response(response)
220

221
    def get_result_descriptor(self) -> ResultDescriptor:
        """
        Return the metadata of the workflow result

        The descriptor is cached: it was fetched once when this `Workflow` was constructed,
        so no network request is performed here.
        """

        return self.__result_descriptor
227

228
    def workflow_definition(self, timeout: int = 60) -> geoengine_openapi_client.Workflow:
        """Return the workflow definition for this workflow

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            response = workflows_api.load_workflow_handler(str(self.__workflow_id), _request_timeout=timeout)

        return response
238

239
    def get_dataframe(
        self, bbox: QueryRectangle, timeout: int = 3600, resolve_classifications: bool = False
    ) -> gpd.GeoDataFrame:
        """
        Query a workflow and return the WFS result as a GeoPandas `GeoDataFrame`

        Parameters
        ----------
        bbox : A bounding box (incl. time and resolution) for the query
        timeout : HTTP request timeout in seconds
        resolve_classifications : If True, map class numbers of classification
            measurement columns to their class names

        Raises
        ------
        MethodNotCalledOnVectorException
            If this workflow does not produce a vector result.
        """

        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wfs_api = geoengine_openapi_client.OGCWFSApi(api_client)
            response = wfs_api.wfs_feature_handler(
                workflow=str(self.__workflow_id),
                service=geoengine_openapi_client.WfsService(geoengine_openapi_client.WfsService.WFS),
                request=geoengine_openapi_client.GetFeatureRequest(
                    geoengine_openapi_client.GetFeatureRequest.GETFEATURE
                ),
                type_names=str(self.__workflow_id),
                bbox=bbox.bbox_str,
                version=geoengine_openapi_client.WfsVersion(geoengine_openapi_client.WfsVersion.ENUM_2_DOT_0_DOT_0),
                time=bbox.time_str,
                srs_name=bbox.srs,
                query_resolution=str(bbox.spatial_resolution),
                _request_timeout=timeout,
            )

        def geo_json_with_time_to_geopandas(geo_json):
            """
            GeoJson has no standard for time, so we parse the when field
            separately and attach it to the data frame as columns `start`
            and `end`.
            """

            data = gpd.GeoDataFrame.from_features(geo_json)
            data = data.set_crs(bbox.srs, allow_override=True)

            start = [f["when"]["start"] for f in geo_json["features"]]
            end = [f["when"]["end"] for f in geo_json["features"]]

            # TODO: find a good way to infer BoT/EoT

            # unparsable instants become NaT instead of raising
            data["start"] = gpd.pd.to_datetime(start, errors="coerce")
            data["end"] = gpd.pd.to_datetime(end, errors="coerce")

            return data

        def transform_classifications(data: gpd.GeoDataFrame):
            # replace class numbers with class names for every classification column
            result_descriptor: VectorResultDescriptor = self.__result_descriptor  # type: ignore
            for column, info in result_descriptor.columns.items():
                if isinstance(info.measurement, ClassificationMeasurement):
                    measurement: ClassificationMeasurement = info.measurement
                    classes = measurement.classes
                    data[column] = data[column].apply(lambda x, classes=classes: classes[x])  # pylint: disable=cell-var-from-loop

            return data

        result = geo_json_with_time_to_geopandas(response.to_dict())

        if resolve_classifications:
            result = transform_classifications(result)

        return result
304

305
    def wms_get_map_as_image(self, bbox: QueryRectangle, raster_colorizer: RasterColorizer) -> Image.Image:
        """Return the result of a WMS request as a PIL Image

        Parameters
        ----------
        bbox : A bounding box (incl. time and resolution) for the query
        raster_colorizer : Colorizer applied server-side, passed as a `custom:` style

        Raises
        ------
        MethodNotCalledOnRasterException
            If this workflow does not produce a raster result.
        OGCXMLError
            If the server answers with an OGC XML error document instead of an image.
        """

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wms_api = geoengine_openapi_client.OGCWMSApi(api_client)
            response = wms_api.wms_map_handler(
                workflow=str(self),
                version=geoengine_openapi_client.WmsVersion(geoengine_openapi_client.WmsVersion.ENUM_1_DOT_3_DOT_0),
                service=geoengine_openapi_client.WmsService(geoengine_openapi_client.WmsService.WMS),
                request=geoengine_openapi_client.GetMapRequest(geoengine_openapi_client.GetMapRequest.GETMAP),
                # pixel dimensions derived from the spatial extent and the resolution
                width=int((bbox.spatial_bounds.xmax - bbox.spatial_bounds.xmin) / bbox.spatial_resolution.x_resolution),
                height=int(
                    (bbox.spatial_bounds.ymax - bbox.spatial_bounds.ymin) / bbox.spatial_resolution.y_resolution
                ),  # pylint: disable=line-too-long
                bbox=bbox.bbox_ogc_str,
                format=geoengine_openapi_client.GetMapFormat(geoengine_openapi_client.GetMapFormat.IMAGE_SLASH_PNG),
                layers=str(self),
                styles="custom:" + raster_colorizer.to_api_dict().to_json(),
                crs=bbox.srs,
                time=bbox.time_str,
            )

        # the server reports failures as an OGC XML document instead of image bytes
        if OGCXMLError.is_ogc_error(response):
            raise OGCXMLError(response)

        return Image.open(BytesIO(response))
336

337
    def plot_json(self, bbox: QueryRectangle, timeout: int = 3600) -> geoengine_openapi_client.WrappedPlotOutput:
        """
        Query a workflow and return the plot chart result as WrappedPlotOutput

        Parameters
        ----------
        bbox : A bounding box (incl. time and resolution) for the query
        timeout : HTTP request timeout in seconds

        Raises
        ------
        MethodNotCalledOnPlotException
            If this workflow does not produce a plot result.
        """

        if not self.__result_descriptor.is_plot_result():
            raise MethodNotCalledOnPlotException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            plots_api = geoengine_openapi_client.PlotsApi(api_client)
            return plots_api.get_plot_handler(
                bbox.bbox_str,
                bbox.time_str,
                str(bbox.spatial_resolution),
                str(self.__workflow_id),
                bbox.srs,
                _request_timeout=timeout,
            )
357

358
    def plot_chart(self, bbox: QueryRectangle, timeout: int = 3600) -> VegaLite:
        """
        Query a workflow and return the plot chart result as a vega plot

        Parameters
        ----------
        bbox : A bounding box (incl. time and resolution) for the query
        timeout : HTTP request timeout in seconds
        """

        plot_output = self.plot_json(bbox, timeout)
        # the plot payload carries the Vega-Lite spec as an embedded JSON string
        spec: VegaSpec = json.loads(plot_output.data["vegaString"])
        return VegaLite(spec)
367

368
    def __request_wcs(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
    ) -> ResponseWrapper:
        """
        Query a workflow and return the coverage

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.

        Raises
        ------
        MethodNotCalledOnRasterException
            If this workflow does not produce a raster result.
        """

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # TODO: properly build CRS string for bbox
        crs = f"urn:ogc:def:crs:{bbox.srs.replace(':', '::')}"

        wcs_url = f"{session.server_url}/wcs/{self.__workflow_id}"
        wcs = WebCoverageService(
            wcs_url,
            version="1.1.1",
            auth=Authentication(auth_delegate=session.requests_bearer_auth()),
        )

        [resx, resy] = bbox.resolution_ogc

        # extra keyword arguments for `getCoverage`, only filled when needed
        kwargs = {}

        if force_no_data_value is not None:
            kwargs["nodatavalue"] = str(float(force_no_data_value))

        return wcs.getCoverage(
            identifier=f"{self.__workflow_id}",
            bbox=bbox.bbox_ogc,
            time=[bbox.time_str],
            format=file_format,
            crs=crs,
            resx=resx,
            resy=resy,
            timeout=timeout,
            **kwargs,
        )
420

421
    def __get_wcs_tiff_as_memory_file(
        self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None
    ) -> rasterio.io.MemoryFile:
        """
        Query a workflow and return the raster result as a memory mapped GeoTiff

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        response = self.__request_wcs(bbox, timeout, "image/tiff", force_no_data_value).read()

        # response is checked via `raise_on_error` in `getCoverage` / `openUrl`

        # wrap the raw GeoTiff bytes without touching the file system
        memory_file = rasterio.io.MemoryFile(response)

        return memory_file
442

443
    def get_array(self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None) -> np.ndarray:
        """
        Query a workflow and return the raster result as a numpy array

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        memory_file = self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value)
        with memory_file as memfile:
            with memfile.open() as dataset:
                # read band 1 of the GeoTiff into a numpy array
                return dataset.read(1)
462

463
    def get_xarray(self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None) -> xr.DataArray:
        """
        Query a workflow and return the raster result as a georeferenced xarray

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        with (
            self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value) as memfile,
            memfile.open() as dataset,
        ):
            data_array = rioxarray.open_rasterio(dataset)

            # helping mypy with inference
            assert isinstance(data_array, xr.DataArray)

            # copy the georeferencing information into plain attributes of the array
            rio: xr.DataArray = data_array.rio
            rio.update_attrs(
                {
                    "crs": rio.crs,
                    "res": rio.resolution(),
                    "transform": rio.transform(),
                },
                inplace=True,
            )

            # TODO: add time information to dataset
            # `.load()` materializes the data before the memory file goes out of scope
            return data_array.load()
496

497
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def download_raster(
        self,
        bbox: QueryRectangle,
        file_path: str,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
    ) -> None:
        """
        Query a workflow and save the raster result as a file on disk

        Parameters
        ----------
        bbox : A bounding box for the query
        file_path : The path to the file to save the raster to
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        response = self.__request_wcs(bbox, timeout, file_format, force_no_data_value)

        # write the WCS response body to the target file
        with open(file_path, "wb") as file:
            file.write(response.read())
523

524
    def get_provenance(self, timeout: int = 60) -> list[ProvenanceEntry]:
        """
        Query the provenance of the workflow

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            api_instance = geoengine_openapi_client.WorkflowsApi(api_client)
            entries = api_instance.get_workflow_provenance_handler(str(self.__workflow_id), _request_timeout=timeout)

        # convert every raw entry into the typed representation
        return [ProvenanceEntry.from_response(entry) for entry in entries]
536

537
    def metadata_zip(self, path: PathLike | BytesIO, timeout: int = 60) -> None:
        """
        Query workflow metadata and citations and stores it as zip file to `path`

        Parameters
        ----------
        path : Destination: either a filesystem path or an in-memory `BytesIO` buffer
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            response = workflows_api.get_workflow_all_metadata_zip_handler(
                str(self.__workflow_id), _request_timeout=timeout
            )

        # write to the in-memory buffer directly, otherwise treat `path` as a file path
        if isinstance(path, BytesIO):
            path.write(response)
        else:
            with open(path, "wb") as file:
                file.write(response)
555

556
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def save_as_dataset(
        self,
        query_rectangle: geoengine_openapi_client.RasterQueryRectangle,
        name: str | None,
        display_name: str,
        description: str = "",
        timeout: int = 3600,
    ) -> Task:
        """Init task to store the workflow result as a layer

        Parameters
        ----------
        query_rectangle : The query rectangle the dataset is materialized for
        name : Optional internal name of the new dataset
        display_name : Human-readable name of the new dataset
        description : Human-readable description of the new dataset
        timeout : HTTP request timeout in seconds

        Returns
        -------
        A `Task` handle for the server-side materialization job.
        """

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            response = workflows_api.dataset_from_workflow_handler(
                str(self.__workflow_id),
                geoengine_openapi_client.RasterDatasetFromWorkflow(
                    name=name, display_name=display_name, description=description, query=query_rectangle
                ),
                _request_timeout=timeout,
            )

        return Task(TaskId.from_response(response))
584

585
    async def raster_stream(
        self,
        query_rectangle: QueryRectangle,
        open_timeout: int = 60,
        bands: list[int] | None = None,  # TODO: move into query rectangle?
    ) -> AsyncIterator[RasterTile2D]:
        """Stream the workflow result as series of RasterTile2D (transformable to numpy and xarray)

        Parameters
        ----------
        query_rectangle : A bounding box (incl. time and resolution) for the query
        open_timeout : Timeout in seconds for opening the websocket connection
        bands : Band indices to query; defaults to `[0]`
        """

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException("At least one band must be specified")

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # build the query URL via `requests` to get proper parameter encoding
        url = (
            req.Request(
                "GET",
                url=f"{session.server_url}/workflow/{self.__workflow_id}/rasterStream",
                params={
                    "resultType": "arrow",
                    "spatialBounds": query_rectangle.bbox_str,
                    "timeInterval": query_rectangle.time_str,
                    "spatialResolution": str(query_rectangle.spatial_resolution),
                    "attributes": ",".join(map(str, bands)),
                },
            )
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            additional_headers=session.auth_header,
            open_timeout=open_timeout,
            # tiles can be large, so do not limit the message size
            max_size=None,
        ) as websocket:
            # holds the most recently received, not yet decoded payload
            tile_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # pipeline: receive the next payload while decoding the previous one in a thread
                (tile_bytes, tile) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, tile_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(RasterStreamProcessing.process_bytes, tile_bytes),
                )

                if tile is not None:
                    yield tile

            # process the last tile
            tile = RasterStreamProcessing.process_bytes(tile_bytes)

            if tile is not None:
                yield tile
668

669
    async def raster_stream_into_xarray(
        self,
        query_rectangle: QueryRectangle,
        clip_to_query_rectangle: bool = False,
        open_timeout: int = 60,
        bands: list[int] | None = None,  # TODO: move into query rectangle?
    ) -> xr.DataArray:
        """
        Stream the workflow result into memory and output a single xarray.

        NOTE: You can run out of memory if the query rectangle is too large.

        Parameters
        ----------
        query_rectangle : A bounding box (incl. time and resolution) for the query
        clip_to_query_rectangle : If True, clip each tile to the query's spatial bounds
        open_timeout : Timeout in seconds for opening the websocket connection
        bands : Band indices to query; defaults to `[0]`
        """

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException("At least one band must be specified")

        tile_stream = self.raster_stream(query_rectangle, open_timeout=open_timeout, bands=bands)

        # one merged array per time step, concatenated along `time` at the end
        timestep_xarrays: list[xr.DataArray] = []

        spatial_clip_bounds = query_rectangle.spatial_bounds if clip_to_query_rectangle else None

        async def read_tiles(
            remainder_tile: RasterTile2D | None,
        ) -> tuple[list[xr.DataArray], RasterTile2D | None]:
            # Collect all tiles that share one time step. When a tile with a new
            # time step arrives, it is returned as the remainder for the next call.
            last_timestep: np.datetime64 | None = None
            tiles = []

            if remainder_tile is not None:
                last_timestep = remainder_tile.time_start_ms
                xr_tile = remainder_tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            async for tile in tile_stream:
                timestep: np.datetime64 = tile.time_start_ms
                if last_timestep is None:
                    last_timestep = timestep
                elif last_timestep != timestep:
                    return tiles, tile

                xr_tile = tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            # this seems to be the last time step, so just return tiles
            return tiles, None

        (tiles, remainder_tile) = await read_tiles(None)

        while len(tiles):
            # pipeline: read the next time step while merging the current one in a thread
            ((new_tiles, new_remainder_tile), new_timestep_xarray) = await asyncio.gather(
                read_tiles(remainder_tile),
                backports.to_thread(RasterStreamProcessing.merge_tiles, tiles),
                # asyncio.to_thread(merge_tiles, tiles), # TODO: use this when min Python version is 3.9
            )

            tiles = new_tiles
            remainder_tile = new_remainder_tile

            if new_timestep_xarray is not None:
                timestep_xarrays.append(new_timestep_xarray)

        output: xr.DataArray = cast(
            xr.DataArray,
            # await asyncio.to_thread( # TODO: use this when min Python version is 3.9
            await backports.to_thread(
                xr.concat,
                # TODO: This is a typings error, since the method accepts also a `xr.DataArray` and returns one
                cast(list[xr.Dataset], timestep_xarrays),
                dim="time",
            ),
        )

        return output
745

746
    async def vector_stream(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> AsyncIterator[gpd.GeoDataFrame]:
        """
        Stream the workflow result as a series of `GeoDataFrame`s.

        Opens a websocket connection to the `vectorStream` endpoint and yields one
        `GeoDataFrame` per Arrow record batch sent by the server. While waiting for
        the next batch, the previously received batch is converted in a worker thread.

        Parameters
        ----------
        query_rectangle: spatio-temporal bounds and resolution of the query
        time_start_column: name of the output column for the start of the time interval
        time_end_column: name of the output column for the end of the time interval
        open_timeout: timeout (in seconds) for establishing the websocket connection

        Raises
        ------
        MethodNotCalledOnVectorException: if the workflow result is not vector data
        GeoEngineException: if the server sends an error message over the websocket
        InputException: if the websocket url cannot be constructed
        """

        def read_arrow_ipc(arrow_ipc: bytes) -> pa.RecordBatch:
            """Extract the single record batch from an Arrow IPC file payload."""
            reader = pa.ipc.open_file(arrow_ipc)
            # We know from the backend that there is only one record batch
            record_batch = reader.get_record_batch(0)
            return record_batch

        def create_geo_data_frame(
            record_batch: pa.RecordBatch, time_start_column: str, time_end_column: str
        ) -> gpd.GeoDataFrame:
            """Convert an Arrow record batch into a `GeoDataFrame` with parsed time columns."""
            metadata = record_batch.schema.metadata
            # the backend stores the CRS in the schema metadata
            spatial_reference = metadata[b"spatialReference"].decode("utf-8")

            data_frame = record_batch.to_pandas()

            geometry = gpd.GeoSeries.from_wkt(data_frame[api.GEOMETRY_COLUMN_NAME])
            del data_frame[api.GEOMETRY_COLUMN_NAME]  # delete the duplicated column

            geo_data_frame = gpd.GeoDataFrame(
                data_frame,
                geometry=geometry,
                crs=spatial_reference,
            )

            # split time column into separate start/end columns
            geo_data_frame[[time_start_column, time_end_column]] = geo_data_frame[api.TIME_COLUMN_NAME].tolist()
            del geo_data_frame[api.TIME_COLUMN_NAME]  # delete the duplicated column

            # parse time columns (milliseconds since epoch, UTC)
            for time_column in [time_start_column, time_end_column]:
                geo_data_frame[time_column] = pd.to_datetime(
                    geo_data_frame[time_column],
                    utc=True,
                    unit="ms",
                    # TODO: solve time conversion problem from Geo Engine to Python for large (+/-) time instances
                    errors="coerce",
                )

            return geo_data_frame

        def process_bytes(batch_bytes: bytes | None) -> gpd.GeoDataFrame | None:
            """Turn a raw websocket message into a `GeoDataFrame`; passes `None` through."""
            if batch_bytes is None:
                return None

            # process the received data
            record_batch = read_arrow_ipc(batch_bytes)
            tile = create_geo_data_frame(
                record_batch,
                time_start_column=time_start_column,
                time_end_column=time_end_column,
            )

            return tile

        # this streaming endpoint only works for vector results
        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        # build the query url via a prepared request to get proper parameter encoding
        url = (
            req.Request(
                "GET",
                url=f"{session.server_url}/workflow/{self.__workflow_id}/vectorStream",
                params={
                    "resultType": "arrow",
                    "spatialBounds": query_rectangle.bbox_str,
                    "timeInterval": query_rectangle.time_str,
                    "spatialResolution": str(query_rectangle.spatial_resolution),
                },
            )
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            # NOTE: the new `websockets.asyncio` API names this `additional_headers`;
            # the legacy client's `extra_headers` is not accepted here
            additional_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,  # allow arbitrary large messages, since it is capped by the server's chunk size
        ) as websocket:
            batch_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    """Request and await the next message; `None` when the connection closed."""
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # read the next message while converting the previous one in a worker thread
                (batch_bytes, batch) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, batch_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(process_bytes, batch_bytes),
                )

                if batch is not None:
                    yield batch

            # process the last tile
            batch = process_bytes(batch_bytes)

            if batch is not None:
                yield batch
877
    async def vector_stream_into_geopandas(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> gpd.GeoDataFrame:
        """
        Collect the streamed workflow result into a single `GeoDataFrame`.

        While the next chunk is being fetched from the stream, the previous one
        is concatenated onto the accumulated result in a worker thread.

        NOTE: You can run out of memory if the query rectangle is too large.
        NOTE(review): if the stream yields no chunks at all, this returns `None`
        despite the annotation — confirm whether an empty stream can occur.
        """

        stream = self.vector_stream(
            query_rectangle,
            time_start_column=time_start_column,
            time_end_column=time_end_column,
            open_timeout=open_timeout,
        )

        accumulated: gpd.GeoDataFrame | None = None
        pending: gpd.GeoDataFrame | None = None

        async def next_chunk() -> gpd.GeoDataFrame | None:
            """Fetch the next chunk; `None` signals that the stream is exhausted."""
            try:
                return await stream.__anext__()
            except StopAsyncIteration:
                return None

        def concat(left: gpd.GeoDataFrame | None, right: gpd.GeoDataFrame | None) -> gpd.GeoDataFrame | None:
            """Concatenate two optional frames, treating `None` as the neutral element."""
            if left is None:
                return right
            if right is None:
                return left
            return pd.concat([left, right], ignore_index=True)

        while True:
            # overlap I/O (fetching) with CPU work (concatenating) via a thread
            (pending, accumulated) = await asyncio.gather(
                next_chunk(),
                backports.to_thread(concat, accumulated, pending),
                # TODO: use `asyncio.to_thread` when the min Python version is 3.9
            )

            if pending is None:
                # stream exhausted; the last chunk was already merged above
                break

        return accumulated
929
    def __replace_http_with_ws(self, url: str) -> str:
        """
        Rewrite the url scheme from `http(s)` to `ws(s)`.

        The websockets library requires urls starting with `ws://`,
        or `wss://` when the server uses TLS.
        """

        scheme, remainder = url.split("://", maxsplit=1)

        if "s" in scheme.lower():
            return f"wss://{remainder}"

        return f"ws://{remainder}"

943

944
def register_workflow(workflow: dict[str, Any] | WorkflowBuilderOperator, timeout: int = 60) -> Workflow:
    """
    Register a workflow in Geo Engine and receive a `WorkflowId`

    Parameters
    ----------
    workflow: workflow definition, either as a dictionary or a builder operator
    timeout: request timeout in seconds

    Raises
    ------
    InputException: if the workflow definition cannot be parsed
    """

    workflow_dict = workflow.to_workflow_dict() if isinstance(workflow, WorkflowBuilderOperator) else workflow

    parsed_workflow = geoengine_openapi_client.Workflow.from_dict(workflow_dict)

    if parsed_workflow is None:
        raise InputException("Invalid workflow definition")

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        response = geoengine_openapi_client.WorkflowsApi(api_client).register_workflow_handler(
            parsed_workflow, _request_timeout=timeout
        )

    return Workflow(WorkflowId.from_response(response))
964

965

966
def workflow_by_id(workflow_id: UUID) -> Workflow:
    """
    Create a workflow object from a workflow id
    """

    # TODO: check that workflow exists

    wrapped_id = WorkflowId(workflow_id)
    return Workflow(wrapped_id)

975

976
def get_quota(user_id: UUID | None = None, timeout: int = 60) -> geoengine_openapi_client.Quota:
    """
    Gets a user's quota. Only admins can get other users' quota.

    Parameters
    ----------
    user_id: the user to query; `None` queries the current session's user
    timeout: request timeout in seconds
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)

        if user_id is not None:
            # admin-only endpoint for other users' quota
            return user_api.get_user_quota_handler(str(user_id), _request_timeout=timeout)

        return user_api.quota_handler(_request_timeout=timeout)

991

992
def update_quota(user_id: UUID, new_available_quota: int, timeout: int = 60) -> None:
    """
    Update a user's quota. Only admins can perform this operation.

    Parameters
    ----------
    user_id: the user whose quota is updated
    new_available_quota: the new amount of available quota
    timeout: request timeout in seconds
    """

    session = get_session()

    quota_update = geoengine_openapi_client.UpdateQuota(available=new_available_quota)

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)
        user_api.update_user_quota_handler(
            str(user_id),
            quota_update,
            _request_timeout=timeout,
        )
1005

1006
def data_usage(offset: int = 0, limit: int = 10) -> pd.DataFrame:
    """
    Get data usage as a data frame.

    Parameters
    ----------
    offset: number of usage records to skip
    limit: maximum number of usage records to return

    Returns
    -------
    A `DataFrame` with one row per usage record; the `timestamp` column
    (if present) is parsed into timezone-aware UTC datetimes.

    NOTE: the previous annotation claimed `list[geoengine_openapi_client.DataUsage]`,
    but the function has always returned a `pd.DataFrame` (consistent with
    `data_usage_summary`).
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)
        response = user_api.data_usage_handler(offset=offset, limit=limit)

        # create dataframe from response
        usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
        df = pd.DataFrame(usage_dicts)
        if "timestamp" in df.columns:
            df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)

    return df

1025

1026
def data_usage_summary(
    granularity: geoengine_openapi_client.UsageSummaryGranularity,
    dataset: str | None = None,
    offset: int = 0,
    limit: int = 10,
) -> pd.DataFrame:
    """
    Get data usage summary

    Parameters
    ----------
    granularity: temporal granularity of the summary
    dataset: restrict the summary to a single dataset; `None` summarizes all
    offset: number of records to skip
    limit: maximum number of records to return
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)
        summaries = user_api.data_usage_summary_handler(
            dataset=dataset, granularity=granularity, offset=offset, limit=limit
        )

        # build a data frame from the response models
        records = [summary.model_dump(by_alias=True) for summary in summaries]
        df = pd.DataFrame(records)
        if "timestamp" in df.columns:
            # parse timestamps into timezone-aware UTC datetimes
            df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)

    return df
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc