• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 19325873846

13 Nov 2025 08:53AM UTC coverage: 79.74% (+2.8%) from 76.961%
19325873846

push

github

web-flow
build: adapt to openapi client update (#246)

* build: adapt to openapi client update

* use published openapi version

2944 of 3692 relevant lines covered (79.74%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.01
geoengine/workflow.py
1
"""
2
A workflow representation and methods on workflows
3
"""
4
# pylint: disable=too-many-lines
5
# TODO: split into multiple files
6

7
from __future__ import annotations
1✔
8

9
import asyncio
1✔
10
import json
1✔
11
from collections import defaultdict
1✔
12
from collections.abc import AsyncIterator
1✔
13
from io import BytesIO
1✔
14
from logging import debug
1✔
15
from os import PathLike
1✔
16
from typing import Any, TypedDict, cast
1✔
17
from uuid import UUID
1✔
18

19
import geoengine_openapi_client
1✔
20
import geopandas as gpd
1✔
21
import numpy as np
1✔
22
import pandas as pd
1✔
23
import pyarrow as pa
1✔
24
import rasterio.io
1✔
25
import requests as req
1✔
26
import rioxarray
1✔
27
import websockets
1✔
28
import websockets.asyncio.client
1✔
29
import xarray as xr
1✔
30
from owslib.util import Authentication, ResponseWrapper
1✔
31
from owslib.wcs import WebCoverageService
1✔
32
from PIL import Image
1✔
33
from vega import VegaLite
1✔
34

35
from geoengine import api, backports
1✔
36
from geoengine.auth import get_session
1✔
37
from geoengine.error import (
1✔
38
    GeoEngineException,
39
    InputException,
40
    MethodNotCalledOnPlotException,
41
    MethodNotCalledOnRasterException,
42
    MethodNotCalledOnVectorException,
43
    OGCXMLError,
44
)
45
from geoengine.raster import RasterTile2D
1✔
46
from geoengine.tasks import Task, TaskId
1✔
47
from geoengine.types import (
1✔
48
    ClassificationMeasurement,
49
    ProvenanceEntry,
50
    QueryRectangle,
51
    RasterColorizer,
52
    ResultDescriptor,
53
    VectorResultDescriptor,
54
)
55
from geoengine.workflow_builder.operators import Operator as WorkflowBuilderOperator
1✔
56

57
# TODO: Define as recursive type when supported in mypy: https://github.com/python/mypy/issues/731
58
JsonType = dict[str, Any] | list[Any] | int | str | float | bool | type[None]
1✔
59

60

61
class Axis(TypedDict):
    """Axis portion of a Vega-Lite encoding channel (carries the axis title)."""

    title: str
1✔
63

64

65
class Bin(TypedDict):
    """Binning settings of a Vega-Lite encoding: whether values are binned and the bin step width."""

    binned: bool
    step: float
1✔
68

69

70
class Field(TypedDict):
    """Reference to a data field by name, as used inside Vega-Lite encoding channels."""

    field: str
1✔
72

73

74
class DatasetIds(TypedDict):
    """Pair of UUIDs — an upload id and a dataset id (as named by the keys)."""

    upload: UUID
    dataset: UUID
1✔
77

78

79
class Values(TypedDict):
    """One data record of a `VegaSpec`: bin boundaries and a frequency count.

    NOTE(review): key names mirror the Geo Engine plot output; presumably one
    histogram bin per record — confirm against the server's vega string.
    """

    binStart: float
    binEnd: float
    Frequency: int
1✔
83

84

85
class X(TypedDict):
    """Vega-Lite x-channel encoding: a data `field`, its `bin` settings, and an `axis`."""

    field: Field
    bin: Bin
    axis: Axis
1✔
89

90

91
class X2(TypedDict):
    """Vega-Lite secondary x-channel (`x2`) encoding, referencing a data field."""

    field: Field
1✔
93

94

95
class Y(TypedDict):
    """Vega-Lite y-channel encoding: a data `field` and its encoding `type`."""

    field: Field
    type: str
1✔
98

99

100
class Encoding(TypedDict):
    """Vega-Lite encoding block combining the `x`, `x2`, and `y` channels."""

    x: X
    x2: X2
    y: Y
1✔
104

105

106
# Top-level Vega-Lite specification as parsed from the Geo Engine plot output
# ("vegaString"). Functional TypedDict syntax is required because "$schema" is
# not a valid Python identifier.
VegaSpec = TypedDict("VegaSpec", {"$schema": str, "data": list[Values], "mark": str, "encoding": Encoding})
1✔
107

108

109
class WorkflowId:
    """Wrapper type around the UUID that identifies a workflow."""

    __workflow_id: UUID

    def __init__(self, workflow_id: UUID) -> None:
        self.__workflow_id = workflow_id

    @classmethod
    def from_response(cls, response: geoengine_openapi_client.IdResponse) -> WorkflowId:
        """Create a `WorkflowId` from an http id response."""
        return cls(response.id)

    def __str__(self) -> str:
        return str(self.__workflow_id)

    def __repr__(self) -> str:
        return self.__str__()

    def to_dict(self) -> UUID:
        # NOTE: despite the name, this returns the raw UUID (used as the
        # path/body value in API calls)
        return self.__workflow_id
1✔
134

135

136
class RasterStreamProcessing:
    """
    Helper class to process raster stream data
    """

    @classmethod
    def read_arrow_ipc(cls, arrow_ipc: bytes) -> pa.RecordBatch:
        """Decode the single record batch contained in an Arrow IPC file."""

        # the backend guarantees exactly one record batch per IPC file
        return pa.ipc.open_file(arrow_ipc).get_record_batch(0)

    @classmethod
    def process_bytes(cls, tile_bytes: bytes | None) -> RasterTile2D | None:
        """Decode a raster tile from an Arrow IPC byte buffer (or pass through `None`)."""

        if tile_bytes is None:
            return None

        batch = cls.read_arrow_ipc(tile_bytes)
        return RasterTile2D.from_ge_record_batch(batch)

    @classmethod
    def merge_tiles(cls, tiles: list[xr.DataArray]) -> xr.DataArray | None:
        """Combine a list of spatial tiles into one multi-band xarray (or `None` for no tiles)."""

        if not tiles:
            return None

        # bucket the tiles by their band coordinate
        per_band: dict[int, list[xr.DataArray]] = defaultdict(list)
        for tile in tiles:
            # assuming 'band' is a coordinate with a single value
            per_band[tile.band.item()].append(tile)

        # merge the tiles of each band into one spatial array
        band_arrays = []
        for band_tiles in per_band.values():
            merged = xr.combine_by_coords(band_tiles)
            # `combine_by_coords` returns a `DataArray` for single-variable
            # inputs; the assertion makes that explicit for mypy
            assert isinstance(merged, xr.DataArray)
            band_arrays.append(merged)

        # stack the per-band arrays along the band dimension
        return xr.concat(band_arrays, dim="band")
1✔
189

190

191
class Workflow:
1✔
192
    """
193
    Holds a workflow id and allows querying data
194
    """
195

196
    __workflow_id: WorkflowId
1✔
197
    __result_descriptor: ResultDescriptor
1✔
198

199
    def __init__(self, workflow_id: WorkflowId) -> None:
        """Wrap the given workflow id and eagerly fetch the result metadata from the server."""
        self.__workflow_id = workflow_id
        # queried once here; `get_result_descriptor` returns this cached value
        self.__result_descriptor = self.__query_result_descriptor()
1✔
202

203
    def __str__(self) -> str:
        """Render the workflow as its id string."""
        return self.__workflow_id.__str__()
1✔
205

206
    def __repr__(self) -> str:
        """Debug representation — delegates to the workflow id."""
        return self.__workflow_id.__repr__()
1✔
208

209
    def __query_result_descriptor(self, timeout: int = 60) -> ResultDescriptor:
        """
        Fetch the metadata of the workflow result from the server

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            response = geoengine_openapi_client.WorkflowsApi(api_client).get_workflow_metadata_handler(
                self.__workflow_id.to_dict(), _request_timeout=timeout
            )

        debug(response)

        return ResultDescriptor.from_response(response)
1✔
225

226
    def get_result_descriptor(self) -> ResultDescriptor:
        """
        Return the metadata of the workflow result (queried once at construction time)
        """
        return self.__result_descriptor
1✔
232

233
    def workflow_definition(self, timeout: int = 60) -> geoengine_openapi_client.Workflow:
        """Load and return the workflow definition for this workflow.

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            definition = geoengine_openapi_client.WorkflowsApi(api_client).load_workflow_handler(
                self.__workflow_id.to_dict(), _request_timeout=timeout
            )

        return definition
1✔
243

244
    def get_dataframe(
        self, bbox: QueryRectangle, timeout: int = 3600, resolve_classifications: bool = False
    ) -> gpd.GeoDataFrame:
        """
        Query a workflow and return the WFS result as a GeoPandas `GeoDataFrame`

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        resolve_classifications : If True, replace classification column values with their class names

        Raises
        ------
        MethodNotCalledOnVectorException : if the workflow result is not a vector result
        """

        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wfs_api = geoengine_openapi_client.OGCWFSApi(api_client)
            response = wfs_api.wfs_feature_handler(
                workflow=self.__workflow_id.to_dict(),
                service=geoengine_openapi_client.WfsService(geoengine_openapi_client.WfsService.WFS),
                request=geoengine_openapi_client.GetFeatureRequest(
                    geoengine_openapi_client.GetFeatureRequest.GETFEATURE
                ),
                # the workflow id doubles as the WFS layer name
                type_names=str(self.__workflow_id),
                bbox=bbox.bbox_str,
                version=geoengine_openapi_client.WfsVersion(geoengine_openapi_client.WfsVersion.ENUM_2_DOT_0_DOT_0),
                time=bbox.time_str,
                srs_name=bbox.srs,
                query_resolution=str(bbox.spatial_resolution),
                _request_timeout=timeout,
            )

        def geo_json_with_time_to_geopandas(geo_json):
            """
            GeoJson has no standard for time, so we parse the when field
            separately and attach it to the data frame as columns `start`
            and `end`.
            """

            data = gpd.GeoDataFrame.from_features(geo_json)
            data = data.set_crs(bbox.srs, allow_override=True)

            start = [f["when"]["start"] for f in geo_json["features"]]
            end = [f["when"]["end"] for f in geo_json["features"]]

            # TODO: find a good way to infer BoT/EoT

            # unparsable timestamps become NaT because of `errors="coerce"`
            data["start"] = gpd.pd.to_datetime(start, errors="coerce")
            data["end"] = gpd.pd.to_datetime(end, errors="coerce")

            return data

        def transform_classifications(data: gpd.GeoDataFrame):
            # Replace numeric class codes with their class names for every
            # column that carries a `ClassificationMeasurement`.
            result_descriptor: VectorResultDescriptor = self.__result_descriptor  # type: ignore
            for column, info in result_descriptor.columns.items():
                if isinstance(info.measurement, ClassificationMeasurement):
                    measurement: ClassificationMeasurement = info.measurement
                    classes = measurement.classes
                    # NOTE(review): a value missing from `classes` raises KeyError — confirm intended
                    data[column] = data[column].apply(lambda x, classes=classes: classes[x])  # pylint: disable=cell-var-from-loop

            return data

        result = geo_json_with_time_to_geopandas(response.to_dict())

        if resolve_classifications:
            result = transform_classifications(result)

        return result
1✔
309

310
    def wms_get_map_as_image(self, bbox: QueryRectangle, raster_colorizer: RasterColorizer) -> Image.Image:
        """Return the result of a WMS request as a PIL Image

        Parameters
        ----------
        bbox : A bounding box for the query
        raster_colorizer : Colorizer for styling; serialized into a `custom:` styles string

        Raises
        ------
        MethodNotCalledOnRasterException : if the workflow result is not a raster result
        OGCXMLError : if the server answered with an OGC XML error document instead of an image
        """

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wms_api = geoengine_openapi_client.OGCWMSApi(api_client)
            response = wms_api.wms_map_handler(
                workflow=self.__workflow_id.to_dict(),
                version=geoengine_openapi_client.WmsVersion(geoengine_openapi_client.WmsVersion.ENUM_1_DOT_3_DOT_0),
                service=geoengine_openapi_client.WmsService(geoengine_openapi_client.WmsService.WMS),
                request=geoengine_openapi_client.GetMapRequest(geoengine_openapi_client.GetMapRequest.GETMAP),
                # image size in pixels, derived from the bbox extent and the spatial resolution
                width=int((bbox.spatial_bounds.xmax - bbox.spatial_bounds.xmin) / bbox.spatial_resolution.x_resolution),
                height=int(
                    (bbox.spatial_bounds.ymax - bbox.spatial_bounds.ymin) / bbox.spatial_resolution.y_resolution
                ),  # pylint: disable=line-too-long
                bbox=bbox.bbox_ogc_str,
                format=geoengine_openapi_client.GetMapFormat(geoengine_openapi_client.GetMapFormat.IMAGE_SLASH_PNG),
                layers=str(self),
                styles="custom:" + raster_colorizer.to_api_dict().to_json(),
                crs=bbox.srs,
                time=bbox.time_str,
            )

        # the server may respond 200 OK with an OGC XML error payload instead of a PNG
        if OGCXMLError.is_ogc_error(response):
            raise OGCXMLError(response)

        return Image.open(BytesIO(response))
1✔
341

342
    def plot_json(self, bbox: QueryRectangle, timeout: int = 3600) -> geoengine_openapi_client.WrappedPlotOutput:
        """
        Query a (plot) workflow and return the raw plot output

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds

        Raises
        ------
        MethodNotCalledOnPlotException : if the workflow result is not a plot result
        """

        if not self.__result_descriptor.is_plot_result():
            raise MethodNotCalledOnPlotException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            plots = geoengine_openapi_client.PlotsApi(api_client)
            return plots.get_plot_handler(
                bbox.bbox_str,
                bbox.time_str,
                str(bbox.spatial_resolution),
                self.__workflow_id.to_dict(),
                bbox.srs,
                _request_timeout=timeout,
            )
362

363
    def plot_chart(self, bbox: QueryRectangle, timeout: int = 3600) -> VegaLite:
        """
        Query a workflow and wrap the plot result in a `VegaLite` chart

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        """

        plot_output = self.plot_json(bbox, timeout)
        # the server delivers the Vega-Lite spec as a JSON string
        spec: VegaSpec = json.loads(plot_output.data["vegaString"])
        return VegaLite(spec)
1✔
372

373
    def __request_wcs(
        self,
        bbox: QueryRectangle,
        timeout: int = 3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
    ) -> ResponseWrapper:
        """
        Query a workflow and return the coverage

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.

        Raises
        ------
        MethodNotCalledOnRasterException : if the workflow result is not a raster result
        """

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # TODO: properly build CRS string for bbox
        # e.g. "EPSG:4326" -> "urn:ogc:def:crs:EPSG::4326"
        crs = f"urn:ogc:def:crs:{bbox.srs.replace(':', '::')}"

        wcs_url = f"{session.server_url}/wcs/{self.__workflow_id}"
        wcs = WebCoverageService(
            wcs_url,
            version="1.1.1",
            auth=Authentication(auth_delegate=session.requests_bearer_auth()),
        )

        [resx, resy] = bbox.resolution_ogc

        kwargs = {}

        # only pass `nodatavalue` when explicitly requested; otherwise masked
        # rasters are produced (see docstring)
        if force_no_data_value is not None:
            kwargs["nodatavalue"] = str(float(force_no_data_value))

        return wcs.getCoverage(
            identifier=f"{self.__workflow_id}",
            bbox=bbox.bbox_ogc,
            time=[bbox.time_str],
            format=file_format,
            crs=crs,
            resx=resx,
            resy=resy,
            timeout=timeout,
            **kwargs,
        )
425

426
    def __get_wcs_tiff_as_memory_file(
        self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None
    ) -> rasterio.io.MemoryFile:
        """
        Query a workflow and return the raster result as a memory mapped GeoTiff

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        tiff_bytes = self.__request_wcs(bbox, timeout, "image/tiff", force_no_data_value).read()

        # errors were already raised via `raise_on_error` in `getCoverage` / `openUrl`
        return rasterio.io.MemoryFile(tiff_bytes)
1✔
447

448
    def get_array(self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None) -> np.ndarray:
        """
        Query a workflow and return the raster result as a numpy array

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        with (
            self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value) as memfile,
            memfile.open() as dataset,
        ):
            # read band 1 only
            return dataset.read(1)
1✔
467

468
    def get_xarray(self, bbox: QueryRectangle, timeout=3600, force_no_data_value: float | None = None) -> xr.DataArray:
        """
        Query a workflow and return the raster result as a georeferenced xarray

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        with (
            self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value) as memfile,
            memfile.open() as dataset,
        ):
            data_array = rioxarray.open_rasterio(dataset)

            # helping mypy with inference
            assert isinstance(data_array, xr.DataArray)

            # NOTE(review): `.rio` is the rioxarray accessor, so the
            # `xr.DataArray` annotation below looks wrong — confirm and fix typing
            rio: xr.DataArray = data_array.rio
            # persist the georeferencing as plain attrs on the array
            rio.update_attrs(
                {
                    "crs": rio.crs,
                    "res": rio.resolution(),
                    "transform": rio.transform(),
                },
                inplace=True,
            )

            # TODO: add time information to dataset
            # load eagerly so the data survives closing the memory file
            return data_array.load()
1✔
501

502
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def download_raster(
        self,
        bbox: QueryRectangle,
        file_path: str,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
    ) -> None:
        """
        Query a workflow and save the raster result as a file on disk

        Parameters
        ----------
        bbox : A bounding box for the query
        file_path : The path to the file to save the raster to
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        coverage = self.__request_wcs(bbox, timeout, file_format, force_no_data_value)

        with open(file_path, "wb") as out_file:
            out_file.write(coverage.read())
1✔
528

529
    def get_provenance(self, timeout: int = 60) -> list[ProvenanceEntry]:
        """
        Fetch the provenance information of this workflow

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            entries = geoengine_openapi_client.WorkflowsApi(api_client).get_workflow_provenance_handler(
                self.__workflow_id.to_dict(), _request_timeout=timeout
            )

        return [ProvenanceEntry.from_response(entry) for entry in entries]
1✔
543

544
    def metadata_zip(self, path: PathLike | BytesIO, timeout: int = 60) -> None:
        """
        Query workflow metadata and citations and store them as a zip file at `path`

        Parameters
        ----------
        path : destination — either a filesystem path or an in-memory `BytesIO` buffer
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            zip_bytes = geoengine_openapi_client.WorkflowsApi(api_client).get_workflow_all_metadata_zip_handler(
                self.__workflow_id.to_dict(), _request_timeout=timeout
            )

        if isinstance(path, BytesIO):
            # write into the provided in-memory buffer
            path.write(zip_bytes)
            return

        with open(path, "wb") as file:
            file.write(zip_bytes)
×
562

563
    # pylint: disable=too-many-positional-arguments
    def save_as_dataset(
        self,
        query_rectangle: geoengine_openapi_client.RasterQueryRectangle,
        name: str | None,
        display_name: str,
        description: str = "",
        timeout: int = 3600,
    ) -> Task:
        """Init task to store the workflow result as a layer

        Parameters
        ----------
        query_rectangle : spatio-temporal bounds and resolution of the raster to store
        name : internal dataset name; may be None (presumably the server picks one — confirm)
        display_name : human-readable dataset name
        description : dataset description text
        timeout : HTTP request timeout in seconds

        Raises
        ------
        MethodNotCalledOnRasterException : if the workflow result is not a raster result
        """

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            response = workflows_api.dataset_from_workflow_handler(
                self.__workflow_id.to_dict(),
                geoengine_openapi_client.RasterDatasetFromWorkflow(
                    name=name, display_name=display_name, description=description, query=query_rectangle
                ),
                _request_timeout=timeout,
            )

        # the server answers with a task id; wrap it for status polling
        return Task(TaskId.from_response(response))
1✔
591

592
    async def raster_stream(
        self,
        query_rectangle: QueryRectangle,
        open_timeout: int = 60,
        bands: list[int] | None = None,  # TODO: move into query rectangle?
    ) -> AsyncIterator[RasterTile2D]:
        """Stream the workflow result as series of RasterTile2D (transformable to numpy and xarray)

        Parameters
        ----------
        query_rectangle : spatio-temporal bounds and resolution of the query
        open_timeout : timeout in seconds for opening the websocket connection
        bands : band indices to query; defaults to `[0]`

        Raises
        ------
        InputException : if `bands` is empty or no websocket url could be built
        MethodNotCalledOnRasterException : if the workflow result is not a raster result
        GeoEngineException : if the server sends an error message over the websocket
        """

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException("At least one band must be specified")

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # build the query url (with encoded params) via a prepared, never-sent HTTP request
        url = (
            req.Request(
                "GET",
                url=f"{session.server_url}/workflow/{self.__workflow_id}/rasterStream",
                params={
                    "resultType": "arrow",
                    "spatialBounds": query_rectangle.bbox_str,
                    "timeInterval": query_rectangle.time_str,
                    "spatialResolution": str(query_rectangle.spatial_resolution),
                    "attributes": ",".join(map(str, bands)),
                },
            )
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            additional_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,  # tiles can exceed the default websocket message size
        ) as websocket:
            # bytes of the previously received tile, decoded one iteration later
            tile_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # pipeline: fetch the next tile while decoding the previous one in a worker thread
                (tile_bytes, tile) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, tile_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(RasterStreamProcessing.process_bytes, tile_bytes),
                )

                if tile is not None:
                    yield tile

            # process the last tile
            tile = RasterStreamProcessing.process_bytes(tile_bytes)

            if tile is not None:
                yield tile
1✔
675

676
    async def raster_stream_into_xarray(
        self,
        query_rectangle: QueryRectangle,
        clip_to_query_rectangle: bool = False,
        open_timeout: int = 60,
        bands: list[int] | None = None,  # TODO: move into query rectangle?
    ) -> xr.DataArray:
        """
        Stream the workflow result into memory and output a single xarray.

        NOTE: You can run out of memory if the query rectangle is too large.

        Parameters
        ----------
        query_rectangle : spatio-temporal bounds and resolution of the query
        clip_to_query_rectangle : if True, clip each tile to the query's spatial bounds
        open_timeout : timeout in seconds for opening the websocket connection
        bands : band indices to query; defaults to `[0]`

        Raises
        ------
        InputException : if `bands` is empty
        """

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException("At least one band must be specified")

        tile_stream = self.raster_stream(query_rectangle, open_timeout=open_timeout, bands=bands)

        # one merged array per time step, concatenated along "time" at the end
        timestep_xarrays: list[xr.DataArray] = []

        spatial_clip_bounds = query_rectangle.spatial_bounds if clip_to_query_rectangle else None

        async def read_tiles(
            remainder_tile: RasterTile2D | None,
        ) -> tuple[list[xr.DataArray], RasterTile2D | None]:
            # Collect all tiles that share one time step. A tile belonging to
            # the following time step is handed back as the new remainder.
            last_timestep: np.datetime64 | None = None
            tiles = []

            if remainder_tile is not None:
                last_timestep = remainder_tile.time_start_ms
                xr_tile = remainder_tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            async for tile in tile_stream:
                timestep: np.datetime64 = tile.time_start_ms
                if last_timestep is None:
                    last_timestep = timestep
                elif last_timestep != timestep:
                    # first tile of the next time step -> return it as remainder
                    return tiles, tile

                xr_tile = tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            # this seems to be the last time step, so just return tiles
            return tiles, None

        (tiles, remainder_tile) = await read_tiles(None)

        while len(tiles):
            # pipeline: read the next time step while merging the current one in a worker thread
            ((new_tiles, new_remainder_tile), new_timestep_xarray) = await asyncio.gather(
                read_tiles(remainder_tile),
                backports.to_thread(RasterStreamProcessing.merge_tiles, tiles),
                # asyncio.to_thread(merge_tiles, tiles), # TODO: use this when min Python version is 3.9
            )

            tiles = new_tiles
            remainder_tile = new_remainder_tile

            if new_timestep_xarray is not None:
                timestep_xarrays.append(new_timestep_xarray)

        output: xr.DataArray = cast(
            xr.DataArray,
            # await asyncio.to_thread( # TODO: use this when min Python version is 3.9
            await backports.to_thread(
                xr.concat,
                # TODO: This is a typings error, since the method accepts also a `xr.DataArray` and returns one
                cast(list[xr.Dataset], timestep_xarrays),
                dim="time",
            ),
        )

        return output
1✔
752

753
    async def vector_stream(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> AsyncIterator[gpd.GeoDataFrame]:
        """
        Stream the workflow result as a series of `GeoDataFrame`s.

        Each websocket message from the `/vectorStream` endpoint is an Arrow IPC
        file containing a single record batch; each batch is converted into one
        `GeoDataFrame` and yielded. Network reads and batch conversion are
        overlapped: while the next message is awaited, the previous one is
        converted in a worker thread.

        :param query_rectangle: spatio-temporal bounds and resolution of the query
        :param time_start_column: name of the output column for the start of each feature's time interval
        :param time_end_column: name of the output column for the end of each feature's time interval
        :param open_timeout: timeout (in seconds) for opening the websocket connection

        :raises MethodNotCalledOnVectorException: if the workflow does not produce a vector result
        :raises GeoEngineException: if the server sends an error message over the websocket
        """

        def read_arrow_ipc(arrow_ipc: bytes) -> pa.RecordBatch:
            # Parse one Arrow IPC file from raw bytes.
            reader = pa.ipc.open_file(arrow_ipc)
            # We know from the backend that there is only one record batch
            record_batch = reader.get_record_batch(0)
            return record_batch

        def create_geo_data_frame(
            record_batch: pa.RecordBatch, time_start_column: str, time_end_column: str
        ) -> gpd.GeoDataFrame:
            # The backend stores the CRS in the schema metadata under `spatialReference`.
            metadata = record_batch.schema.metadata
            spatial_reference = metadata[b"spatialReference"].decode("utf-8")

            data_frame = record_batch.to_pandas()

            # Geometries arrive as WKT strings in the well-known geometry column.
            geometry = gpd.GeoSeries.from_wkt(data_frame[api.GEOMETRY_COLUMN_NAME])
            # delete the duplicated column
            del data_frame[api.GEOMETRY_COLUMN_NAME]

            geo_data_frame = gpd.GeoDataFrame(
                data_frame,
                geometry=geometry,
                crs=spatial_reference,
            )

            # split time column (each entry is a [start, end] pair) into two columns
            geo_data_frame[[time_start_column, time_end_column]] = geo_data_frame[api.TIME_COLUMN_NAME].tolist()
            # delete the duplicated column
            del geo_data_frame[api.TIME_COLUMN_NAME]

            # parse time columns from epoch milliseconds into UTC datetimes
            for time_column in [time_start_column, time_end_column]:
                geo_data_frame[time_column] = pd.to_datetime(
                    geo_data_frame[time_column],
                    utc=True,
                    unit="ms",
                    # TODO: solve time conversion problem from Geo Engine to Python for large (+/-) time instances
                    errors="coerce",
                )

            return geo_data_frame

        def process_bytes(batch_bytes: bytes | None) -> gpd.GeoDataFrame | None:
            # Convert one raw websocket message into a `GeoDataFrame`;
            # `None` in, `None` out (nothing received yet / stream ended).
            if batch_bytes is None:
                return None

            # process the received data
            record_batch = read_arrow_ipc(batch_bytes)
            tile = create_geo_data_frame(
                record_batch,
                time_start_column=time_start_column,
                time_end_column=time_end_column,
            )

            return tile

        # This streaming endpoint only works for vector results
        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        # Build the HTTP query string via `requests` and reuse it for the websocket url.
        url = (
            req.Request(
                "GET",
                url=f"{session.server_url}/workflow/{self.__workflow_id}/vectorStream",
                params={
                    "resultType": "arrow",
                    "spatialBounds": query_rectangle.bbox_str,
                    "timeInterval": query_rectangle.time_str,
                    "spatialResolution": str(query_rectangle.spatial_resolution),
                },
            )
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            extra_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,  # allow arbitrary large messages, since it is capped by the server's chunk size
        ) as websocket:
            # Holds the most recently received (not yet converted) message.
            batch_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # Overlap the network read with converting the previous message
                # in a worker thread.
                (batch_bytes, batch) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, batch_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(process_bytes, batch_bytes),
                )

                if batch is not None:
                    yield batch

            # process the last tile (the loop exits with one unconverted message pending)
            batch = process_bytes(batch_bytes)

            if batch is not None:
                yield batch
885

886
    async def vector_stream_into_geopandas(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> gpd.GeoDataFrame | None:
        """
        Stream the workflow result into memory and output a single geo data frame.

        Reading the next chunk from the stream is overlapped with merging the
        previous chunk into the accumulated data frame (in a worker thread).

        Returns `None` if the stream produces no chunks at all.

        :param query_rectangle: spatio-temporal bounds and resolution of the query
        :param time_start_column: name of the output column for the start of each feature's time interval
        :param time_end_column: name of the output column for the end of each feature's time interval
        :param open_timeout: timeout (in seconds) for opening the websocket connection

        NOTE: You can run out of memory if the query rectangle is too large.
        """

        chunk_stream = self.vector_stream(
            query_rectangle,
            time_start_column=time_start_column,
            time_end_column=time_end_column,
            open_timeout=open_timeout,
        )

        # accumulated result; stays `None` until the first chunk is merged in
        data_frame: gpd.GeoDataFrame | None = None
        chunk: gpd.GeoDataFrame | None = None

        async def read_dataframe() -> gpd.GeoDataFrame | None:
            # Pull the next chunk; map stream exhaustion to `None`.
            try:
                return await chunk_stream.__anext__()
            except StopAsyncIteration:
                return None

        def merge_dataframes(df_a: gpd.GeoDataFrame | None, df_b: gpd.GeoDataFrame | None) -> gpd.GeoDataFrame | None:
            # Concatenate two partial results; `None` acts as the neutral element.
            if df_a is None:
                return df_b

            if df_b is None:
                return df_a

            return pd.concat([df_a, df_b], ignore_index=True)

        while True:
            # Fetch the next chunk while merging the previous one off-thread.
            (chunk, data_frame) = await asyncio.gather(
                read_dataframe(),
                backports.to_thread(merge_dataframes, data_frame, chunk),
                # TODO: use this when min Python version is 3.9
                # asyncio.to_thread(merge_dataframes, data_frame, chunk),
            )

            # we can stop when the chunk stream is exhausted
            if chunk is None:
                break

        return data_frame
937

938
    def __replace_http_with_ws(self, url: str) -> str:
        """
        Rewrite an `http(s)://` url into its websocket counterpart.

        The websockets library requires the url to start with `ws://`,
        or `wss://` when the original scheme was secure (HTTPS).
        """

        scheme, remainder = url.split("://", maxsplit=1)

        secure = "s" in scheme.lower()

        return ("wss://" if secure else "ws://") + remainder
951

952

953
def register_workflow(workflow: dict[str, Any] | WorkflowBuilderOperator, timeout: int = 60) -> Workflow:
    """
    Register a workflow in Geo Engine and receive a `WorkflowId`

    :param workflow: the workflow definition, either as a dict or a builder operator
    :param timeout: request timeout in seconds

    :raises InputException: if the workflow definition cannot be parsed
    """

    # builder operators are first serialized into the plain dict representation
    workflow_dict = workflow.to_workflow_dict() if isinstance(workflow, WorkflowBuilderOperator) else workflow

    workflow_model = geoengine_openapi_client.Workflow.from_dict(workflow_dict)

    if workflow_model is None:
        raise InputException("Invalid workflow definition")

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as client:
        wf_api = geoengine_openapi_client.WorkflowsApi(client)
        response = wf_api.register_workflow_handler(workflow_model, _request_timeout=timeout)

    return Workflow(WorkflowId.from_response(response))
973

974

975
def workflow_by_id(workflow_id: UUID) -> Workflow:
    """
    Create a workflow object from a workflow id
    """

    # TODO: check that workflow exists
    wrapped_id = WorkflowId(workflow_id)

    return Workflow(wrapped_id)
983

984

985
def get_quota(user_id: UUID | None = None, timeout: int = 60) -> geoengine_openapi_client.Quota:
    """
    Gets a user's quota. Only admins can get other users' quota.

    :param user_id: the user whose quota to fetch; `None` means the current session's user
    :param timeout: request timeout in seconds
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as client:
        user_api = geoengine_openapi_client.UserApi(client)

        if user_id is not None:
            # admin-only endpoint for other users' quota
            return user_api.get_user_quota_handler(user_id, _request_timeout=timeout)

        return user_api.quota_handler(_request_timeout=timeout)
999

1000

1001
def update_quota(user_id: UUID, new_available_quota: int, timeout: int = 60) -> None:
    """
    Update a user's quota. Only admins can perform this operation.

    :param user_id: the user whose quota to change
    :param new_available_quota: the new amount of available quota
    :param timeout: request timeout in seconds
    """

    session = get_session()

    quota_update = geoengine_openapi_client.UpdateQuota(available=new_available_quota)

    with geoengine_openapi_client.ApiClient(session.configuration) as client:
        user_api = geoengine_openapi_client.UserApi(client)
        user_api.update_user_quota_handler(user_id, quota_update, _request_timeout=timeout)
1013

1014

1015
def data_usage(offset: int = 0, limit: int = 10) -> pd.DataFrame:
    """
    Get data usage records as a `pandas.DataFrame`.

    Each usage record returned by the API becomes one row; a `timestamp`
    column, if present, is parsed into timezone-aware (UTC) datetimes.

    NOTE: the previous return annotation (`list[geoengine_openapi_client.DataUsage]`)
    was wrong — the function has always returned a `DataFrame`, matching
    `data_usage_summary`.

    :param offset: number of records to skip (for pagination)
    :param limit: maximum number of records to return
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)
        response = user_api.data_usage_handler(offset=offset, limit=limit)

        # create dataframe from response
        usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
        df = pd.DataFrame(usage_dicts)
        if "timestamp" in df.columns:
            df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)

    return df
1033

1034

1035
def data_usage_summary(
    granularity: geoengine_openapi_client.UsageSummaryGranularity,
    dataset: str | None = None,
    offset: int = 0,
    limit: int = 10,
) -> pd.DataFrame:
    """
    Get data usage summary

    :param granularity: temporal granularity of the summary
    :param dataset: optionally restrict the summary to a single dataset
    :param offset: number of records to skip (for pagination)
    :param limit: maximum number of records to return
    """

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as client:
        user_api = geoengine_openapi_client.UserApi(client)
        summaries = user_api.data_usage_summary_handler(
            dataset=dataset, granularity=granularity, offset=offset, limit=limit
        )

        # turn the API models into plain dicts and build a dataframe from them
        records = [entry.model_dump(by_alias=True) for entry in summaries]
        frame = pd.DataFrame(records)

        # parse timestamps into timezone-aware (UTC) datetimes when present
        if "timestamp" in frame.columns:
            frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)

    return frame
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc