• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 15072011800

16 May 2025 03:29PM UTC coverage: 76.662% (-0.008%) from 76.67%
15072011800

Pull #222

github

web-flow
Merge b416386ac into 89c260aaf
Pull Request #222: build: update dependencies as of 2025-05-07

2756 of 3595 relevant lines covered (76.66%)

0.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.48
geoengine/workflow.py
1
'''
2
A workflow representation and methods on workflows
3
'''
4
# pylint: disable=too-many-lines
5
# TODO: split into multiple files
6

7
from __future__ import annotations
1✔
8

9
import asyncio
1✔
10
from collections import defaultdict
1✔
11
import json
1✔
12
from io import BytesIO
1✔
13
from logging import debug
1✔
14
from os import PathLike
1✔
15
from typing import Any, AsyncIterator, Dict, List, Optional, Union, Type, cast, TypedDict
1✔
16
from uuid import UUID
1✔
17

18
import geopandas as gpd
1✔
19
import pandas as pd
1✔
20
import numpy as np
1✔
21
import rasterio.io
1✔
22
import requests as req
1✔
23
import rioxarray
1✔
24
from PIL import Image
1✔
25
from owslib.util import Authentication, ResponseWrapper
1✔
26
from owslib.wcs import WebCoverageService
1✔
27
from vega import VegaLite
1✔
28
import websockets
1✔
29
import xarray as xr
1✔
30
import pyarrow as pa
1✔
31

32
import geoengine_openapi_client
1✔
33
from geoengine import api
1✔
34
from geoengine.auth import get_session
1✔
35
from geoengine.error import GeoEngineException, InputException, MethodNotCalledOnPlotException, \
1✔
36
    MethodNotCalledOnRasterException, MethodNotCalledOnVectorException, OGCXMLError
37
from geoengine import backports
1✔
38
from geoengine.types import ProvenanceEntry, QueryRectangle, RasterColorizer, ResultDescriptor, \
1✔
39
    VectorResultDescriptor, ClassificationMeasurement
40
from geoengine.tasks import Task, TaskId
1✔
41
from geoengine.workflow_builder.operators import Operator as WorkflowBuilderOperator
1✔
42
from geoengine.raster import RasterTile2D
1✔
43

44

45
# TODO: Define as recursive type when supported in mypy: https://github.com/python/mypy/issues/731
46
JsonType = Union[Dict[str, Any], List[Any], int, str, float, bool, Type[None]]
1✔
47

48
Axis = TypedDict('Axis', {'title': str})
1✔
49
Bin = TypedDict('Bin', {'binned': bool, 'step': float})
1✔
50
Field = TypedDict('Field', {'field': str})
1✔
51
DatasetIds = TypedDict('DatasetIds', {'upload': UUID, 'dataset': UUID})
1✔
52
Values = TypedDict('Values', {'binStart': float, 'binEnd': float, 'Frequency': int})
1✔
53
X = TypedDict('X', {'field': Field, 'bin': Bin, 'axis': Axis})
1✔
54
X2 = TypedDict('X2', {'field': Field})
1✔
55
Y = TypedDict('Y', {'field': Field, 'type': str})
1✔
56
Encoding = TypedDict('Encoding', {'x': X, 'x2': X2, 'y': Y})
1✔
57
VegaSpec = TypedDict('VegaSpec', {'$schema': str, 'data': List[Values], 'mark': str, 'encoding': Encoding})
1✔
58

59

60
class WorkflowId:
    '''
    Immutable wrapper around the UUID that identifies a workflow.
    '''

    __workflow_id: UUID

    def __init__(self, workflow_id: UUID) -> None:
        self.__workflow_id = workflow_id

    @classmethod
    def from_response(cls, response: geoengine_openapi_client.IdResponse) -> WorkflowId:
        '''
        Build a `WorkflowId` from the id field of an http response
        '''
        parsed_id = UUID(response.id)
        return WorkflowId(parsed_id)

    def __str__(self) -> str:
        return f'{self.__workflow_id}'

    def __repr__(self) -> str:
        return str(self)
82

83

84
class RasterStreamProcessing:
    '''
    Utility methods for decoding and assembling raster stream data
    '''

    @classmethod
    def read_arrow_ipc(cls, arrow_ipc: bytes) -> pa.RecordBatch:
        '''Decode a single record batch from an Arrow IPC byte buffer'''

        reader = pa.ipc.open_file(arrow_ipc)
        # the backend always sends exactly one record batch per message
        return reader.get_record_batch(0)

    @classmethod
    def process_bytes(cls, tile_bytes: Optional[bytes]) -> Optional[RasterTile2D]:
        '''Decode a raster tile from raw Arrow IPC bytes; `None` input yields `None`'''

        if tile_bytes is None:
            return None

        # decode the Arrow payload and wrap it in a tile object
        batch = RasterStreamProcessing.read_arrow_ipc(tile_bytes)
        return RasterTile2D.from_ge_record_batch(batch)

    @classmethod
    def merge_tiles(cls, tiles: List[xr.DataArray]) -> Optional[xr.DataArray]:
        '''Stitch a list of tiles into a single xarray; returns `None` for no tiles'''

        if not tiles:
            return None

        # bucket the tiles by their band coordinate
        per_band: Dict[int, List[xr.DataArray]] = defaultdict(list)
        for tile in tiles:
            # assuming 'band' is a coordinate with a single value
            per_band[tile.band.item()].append(tile)

        # spatially combine the tiles of each band
        merged_bands = []
        for band_tiles in per_band.values():
            spatial = xr.combine_by_coords(band_tiles)
            # `combine_by_coords` always returns a `DataArray` for single variable
            # input arrays; the assertion documents this for mypy
            assert isinstance(spatial, xr.DataArray)
            merged_bands.append(spatial)

        # stack the per-band arrays into one array with a band dimension
        return xr.concat(merged_bands, dim='band')
137

138

139
class Workflow:
1✔
140
    '''
141
    Holds a workflow id and allows querying data
142
    '''
143

144
    __workflow_id: WorkflowId
1✔
145
    __result_descriptor: ResultDescriptor
1✔
146

147
    def __init__(self, workflow_id: WorkflowId) -> None:
        '''Wrap `workflow_id` and eagerly fetch the result metadata from the backend.'''
        self.__workflow_id = workflow_id
        # fetch once up front so later calls can check the result type
        # (raster/vector/plot) without another round trip
        self.__result_descriptor = self.__query_result_descriptor()
150

151
    def __str__(self) -> str:
1✔
152
        return str(self.__workflow_id)
1✔
153

154
    def __repr__(self) -> str:
1✔
155
        return repr(self.__workflow_id)
1✔
156

157
    def __query_result_descriptor(self, timeout: int = 60) -> ResultDescriptor:
        '''
        Fetch the metadata describing this workflow's result from the backend

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        '''

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            response = workflows_api.get_workflow_metadata_handler(
                str(self.__workflow_id),
                _request_timeout=timeout
            )

        debug(response)

        return ResultDescriptor.from_response(response)
171

172
    def get_result_descriptor(self) -> ResultDescriptor:
1✔
173
        '''
174
        Return the metadata of the workflow result
175
        '''
176

177
        return self.__result_descriptor
1✔
178

179
    def workflow_definition(self, timeout: int = 60) -> geoengine_openapi_client.Workflow:
        '''
        Load and return the operator-graph definition of this workflow

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        '''

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wf_api = geoengine_openapi_client.WorkflowsApi(api_client)
            definition = wf_api.load_workflow_handler(str(self.__workflow_id), _request_timeout=timeout)

        return definition
189

190
    def get_dataframe(
            self,
            bbox: QueryRectangle,
            timeout: int = 3600,
            resolve_classifications: bool = False
    ) -> gpd.GeoDataFrame:
        '''
        Query a workflow and return the WFS result as a GeoPandas `GeoDataFrame`

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        resolve_classifications : If True, replace class indices in classification
            columns with their class names

        Raises
        ------
        MethodNotCalledOnVectorException
            If the workflow does not produce a vector result
        '''

        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wfs_api = geoengine_openapi_client.OGCWFSApi(api_client)
            response = wfs_api.wfs_feature_handler(
                workflow=str(self.__workflow_id),
                service=geoengine_openapi_client.WfsService(geoengine_openapi_client.WfsService.WFS),
                request=geoengine_openapi_client.GetFeatureRequest(
                    geoengine_openapi_client.GetFeatureRequest.GETFEATURE
                ),
                type_names=str(self.__workflow_id),
                bbox=bbox.bbox_str,
                version=geoengine_openapi_client.WfsVersion(geoengine_openapi_client.WfsVersion.ENUM_2_DOT_0_DOT_0),
                time=bbox.time_str,
                srs_name=bbox.srs,
                query_resolution=str(bbox.spatial_resolution),
                _request_timeout=timeout
            )

        def geo_json_with_time_to_geopandas(geo_json):
            '''
            GeoJson has no standard for time, so we parse the when field
            separately and attach it to the data frame as columns `start`
            and `end`.
            '''

            data = gpd.GeoDataFrame.from_features(geo_json)
            data = data.set_crs(bbox.srs, allow_override=True)

            start = [f['when']['start'] for f in geo_json['features']]
            end = [f['when']['end'] for f in geo_json['features']]

            # TODO: find a good way to infer BoT/EoT

            # 'coerce' turns unparsable/out-of-range timestamps into NaT instead of raising
            data['start'] = gpd.pd.to_datetime(start, errors='coerce')
            data['end'] = gpd.pd.to_datetime(end, errors='coerce')

            return data

        def transform_classifications(data: gpd.GeoDataFrame):
            '''Replace class indices with class names for all classification columns.'''
            result_descriptor: VectorResultDescriptor = self.__result_descriptor  # type: ignore
            for (column, info) in result_descriptor.columns.items():
                if isinstance(info.measurement, ClassificationMeasurement):
                    measurement: ClassificationMeasurement = info.measurement
                    classes = measurement.classes
                    # `classes` is rebound each iteration, but the lambda is applied
                    # immediately within the same iteration, so late binding is harmless
                    data[column] = data[column].apply(lambda x: classes[x])  # pylint: disable=cell-var-from-loop

            return data

        result = geo_json_with_time_to_geopandas(response.to_dict())

        if resolve_classifications:
            result = transform_classifications(result)

        return result
258

259
    def wms_get_map_as_image(self, bbox: QueryRectangle, raster_colorizer: RasterColorizer) -> Image.Image:
        '''
        Issue a WMS GetMap request and return the rendered result as a PIL Image

        Parameters
        ----------
        bbox : A bounding box for the query
        raster_colorizer : The colorizer used for rendering the raster

        Raises
        ------
        MethodNotCalledOnRasterException
            If the workflow does not produce a raster result
        OGCXMLError
            If the server answers with an OGC XML error document
        '''

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # derive the pixel dimensions of the image from the bounds and resolution
        bounds = bbox.spatial_bounds
        resolution = bbox.spatial_resolution
        width = int((bounds.xmax - bounds.xmin) / resolution.x_resolution)
        height = int((bounds.ymax - bounds.ymin) / resolution.y_resolution)

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            wms_api = geoengine_openapi_client.OGCWMSApi(api_client)
            response = wms_api.wms_map_handler(
                workflow=str(self),
                version=geoengine_openapi_client.WmsVersion(geoengine_openapi_client.WmsVersion.ENUM_1_DOT_3_DOT_0),
                service=geoengine_openapi_client.WmsService(geoengine_openapi_client.WmsService.WMS),
                request=geoengine_openapi_client.GetMapRequest(geoengine_openapi_client.GetMapRequest.GETMAP),
                width=width,
                height=height,
                bbox=bbox.bbox_ogc_str,
                format=geoengine_openapi_client.GetMapFormat(geoengine_openapi_client.GetMapFormat.IMAGE_SLASH_PNG),
                layers=str(self),
                styles='custom:' + raster_colorizer.to_api_dict().to_json(),
                crs=bbox.srs,
                time=bbox.time_str
            )

        # the server reports failures as OGC XML documents in the body
        if OGCXMLError.is_ogc_error(response):
            raise OGCXMLError(response)

        return Image.open(BytesIO(response))
288

289
    def plot_json(self, bbox: QueryRectangle, timeout: int = 3600) -> geoengine_openapi_client.WrappedPlotOutput:
        '''
        Query a workflow and return the plot chart result as WrappedPlotOutput

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds

        Raises
        ------
        MethodNotCalledOnPlotException
            If the workflow does not produce a plot result
        '''

        if not self.__result_descriptor.is_plot_result():
            raise MethodNotCalledOnPlotException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            plots_api = geoengine_openapi_client.PlotsApi(api_client)
            plot_output = plots_api.get_plot_handler(
                bbox.bbox_str,
                bbox.time_str,
                str(bbox.spatial_resolution),
                str(self.__workflow_id),
                bbox.srs,
                _request_timeout=timeout
            )

        return plot_output
309

310
    def plot_chart(self, bbox: QueryRectangle, timeout: int = 3600) -> VegaLite:
        '''
        Query a workflow and return the plot result as a renderable Vega-Lite chart

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        '''

        plot_output = self.plot_json(bbox, timeout)
        # the Vega specification is embedded as a JSON string in the response
        spec: VegaSpec = json.loads(plot_output.data['vegaString'])
        return VegaLite(spec)
319

320
    def __request_wcs(
        self,
        bbox: QueryRectangle,
        timeout: int = 3600,
        file_format: str = 'image/tiff',
        force_no_data_value: Optional[float] = None
    ) -> ResponseWrapper:
        '''
        Query a workflow and return the coverage

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.

        Raises
        ------
        MethodNotCalledOnRasterException
            If the workflow does not produce a raster result
        '''

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # TODO: properly build CRS string for bbox
        # e.g. 'EPSG:4326' becomes 'urn:ogc:def:crs:EPSG::4326'
        crs = f'urn:ogc:def:crs:{bbox.srs.replace(":", "::")}'

        wcs_url = f'{session.server_url}/wcs/{self.__workflow_id}'
        wcs = WebCoverageService(
            wcs_url,
            version='1.1.1',
            auth=Authentication(auth_delegate=session.requests_bearer_auth()),
        )

        [resx, resy] = bbox.resolution_ogc

        # optional parameters are passed through as extra kwargs
        kwargs = {}

        if force_no_data_value is not None:
            kwargs["nodatavalue"] = str(float(force_no_data_value))

        return wcs.getCoverage(
            identifier=f'{self.__workflow_id}',
            bbox=bbox.bbox_ogc,
            time=[bbox.time_str],
            format=file_format,
            crs=crs,
            resx=resx,
            resy=resy,
            timeout=timeout,
            **kwargs
        )
372

373
    def __get_wcs_tiff_as_memory_file(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        force_no_data_value: Optional[float] = None
    ) -> rasterio.io.MemoryFile:
        '''
        Query a workflow via WCS and wrap the GeoTiff response in a memory mapped file

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        # errors are already raised via `raise_on_error` inside `getCoverage` / `openUrl`,
        # so the body can be consumed directly
        tiff_bytes = self.__request_wcs(bbox, timeout, 'image/tiff', force_no_data_value).read()

        return rasterio.io.MemoryFile(tiff_bytes)
397

398
    def get_array(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        force_no_data_value: Optional[float] = None
    ) -> np.ndarray:
        '''
        Query a workflow and return the raster result as a numpy array

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        memfile = self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value)

        # read band 1 while the in-memory GeoTiff is still open
        with memfile, memfile.open() as dataset:
            return dataset.read(1)
423

424
    def get_xarray(
        self,
        bbox: QueryRectangle,
        timeout: int = 3600,
        force_no_data_value: Optional[float] = None
    ) -> xr.DataArray:
        '''
        Query a workflow and return the raster result as a georeferenced xarray

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        with self.__get_wcs_tiff_as_memory_file(
            bbox,
            timeout,
            force_no_data_value
        ) as memfile, memfile.open() as dataset:
            data_array = rioxarray.open_rasterio(dataset)

            # helping mypy with inference
            assert isinstance(data_array, xr.DataArray)

            # copy the georeferencing information into the array attributes
            # so it stays available on the returned array
            rio: xr.DataArray = data_array.rio
            rio.update_attrs({
                'crs': rio.crs,
                'res': rio.resolution(),
                'transform': rio.transform(),
            }, inplace=True)

            # TODO: add time information to dataset
            # load eagerly: the backing memory file is closed when the `with` block exits
            return data_array.load()
460

461
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def download_raster(
        self,
        bbox: QueryRectangle,
        file_path: str,
        timeout=3600,
        file_format: str = 'image/tiff',
        force_no_data_value: Optional[float] = None
    ) -> None:
        '''
        Query a workflow and write the raster result to a file on disk

        Parameters
        ----------
        bbox : A bounding box for the query
        file_path : The path to the file to save the raster to
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        coverage = self.__request_wcs(bbox, timeout, file_format, force_no_data_value)

        with open(file_path, 'wb') as out_file:
            out_file.write(coverage.read())
487

488
    def get_provenance(self, timeout: int = 60) -> List[ProvenanceEntry]:
        '''
        Fetch the provenance (citation) information of this workflow

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        '''

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            entries = workflows_api.get_workflow_provenance_handler(
                str(self.__workflow_id),
                _request_timeout=timeout
            )

        return [ProvenanceEntry.from_response(entry) for entry in entries]
500

501
    def metadata_zip(self, path: Union[PathLike, BytesIO], timeout: int = 60) -> None:
        '''
        Query workflow metadata and citations and store them as a zip file at `path`

        Parameters
        ----------
        path : Either a filesystem path or a `BytesIO` buffer to write the zip into
        timeout : HTTP request timeout in seconds
        '''

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            zip_bytes = workflows_api.get_workflow_all_metadata_zip_handler(
                str(self.__workflow_id),
                _request_timeout=timeout
            )

        # write into the given buffer, or else create a file on disk
        if isinstance(path, BytesIO):
            path.write(zip_bytes)
            return

        with open(path, 'wb') as file:
            file.write(zip_bytes)
520

521
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def save_as_dataset(
            self,
            query_rectangle: geoengine_openapi_client.RasterQueryRectangle,
            name: Optional[str],
            display_name: str,
            description: str = '',
            timeout: int = 3600) -> Task:
        '''
        Init task to store the workflow result as a layer

        Parameters
        ----------
        query_rectangle : The spatio-temporal extent of the data to materialize
        display_name : Human-readable name of the new dataset
        description : Human-readable description of the new dataset
        timeout : HTTP request timeout in seconds

        Returns
        -------
        A `Task` handle for the asynchronous dataset-creation job

        Raises
        ------
        MethodNotCalledOnRasterException
            If the workflow does not produce a raster result
        '''

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
            workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
            response = workflows_api.dataset_from_workflow_handler(
                str(self.__workflow_id),
                geoengine_openapi_client.RasterDatasetFromWorkflow(
                    name=name,
                    display_name=display_name,
                    description=description,
                    query=query_rectangle
                ),
                _request_timeout=timeout
            )

        return Task(TaskId.from_response(response))
551

552
    async def raster_stream(
        self,
        query_rectangle: QueryRectangle,
        open_timeout: int = 60,
        bands: Optional[List[int]] = None  # TODO: move into query rectangle?
    ) -> AsyncIterator[RasterTile2D]:
        '''Stream the workflow result as series of RasterTile2D (transformable to numpy and xarray)

        Parameters
        ----------
        query_rectangle : The spatio-temporal extent to query
        open_timeout : Timeout in seconds for opening the websocket connection
        bands : Band indices to stream; defaults to `[0]`
        '''

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException('At least one band must be specified')

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # build the query url via a prepared (unsent) request to get proper encoding
        url = req.Request(
            'GET',
            url=f'{session.server_url}/workflow/{self.__workflow_id}/rasterStream',
            params={
                'resultType': 'arrow',
                'spatialBounds': query_rectangle.bbox_str,
                'timeInterval': query_rectangle.time_str,
                'spatialResolution': str(query_rectangle.spatial_resolution),
                'attributes': ','.join(map(str, bands))
            },
        ).prepare().url

        if url is None:
            raise InputException('Invalid websocket url')

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            extra_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,  # allow arbitrarily large tile messages
        ) as websocket:

            # bytes of the previously received (not yet decoded) tile
            tile_bytes: Optional[bytes] = None

            while websocket.state == websockets.protocol.State.OPEN:
                async def read_new_bytes() -> Optional[bytes]:
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: Union[str, bytes] = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({'error': data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # pipeline: fetch the next message while decoding the previous
                # tile's bytes in a worker thread
                (tile_bytes, tile) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, tile_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(RasterStreamProcessing.process_bytes, tile_bytes),
                )

                if tile is not None:
                    yield tile

            # process the last tile
            tile = RasterStreamProcessing.process_bytes(tile_bytes)

            if tile is not None:
                yield tile
631

632
    async def raster_stream_into_xarray(
        self,
        query_rectangle: QueryRectangle,
        clip_to_query_rectangle: bool = False,
        open_timeout: int = 60,
        bands: Optional[List[int]] = None  # TODO: move into query rectangle?
    ) -> xr.DataArray:
        '''
        Stream the workflow result into memory and output a single xarray.

        NOTE: You can run out of memory if the query rectangle is too large.

        Parameters
        ----------
        query_rectangle : The spatio-temporal extent to query
        clip_to_query_rectangle : If True, clip each tile to the query's spatial bounds
        open_timeout : Timeout in seconds for opening the websocket connection
        bands : Band indices to stream; defaults to `[0]`
        '''

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException('At least one band must be specified')

        tile_stream = self.raster_stream(
            query_rectangle,
            open_timeout=open_timeout,
            bands=bands
        )

        # one merged array per time step, concatenated along 'time' at the end
        timestep_xarrays: List[xr.DataArray] = []

        spatial_clip_bounds = query_rectangle.spatial_bounds if clip_to_query_rectangle else None

        async def read_tiles(
            remainder_tile: Optional[RasterTile2D]
        ) -> tuple[List[xr.DataArray], Optional[RasterTile2D]]:
            '''Collect all tiles of one time step; returns them plus the first
            tile of the next time step (which had to be read to detect the change).'''
            last_timestep: Optional[np.datetime64] = None
            tiles = []

            # the carried-over tile from the previous call starts this time step
            if remainder_tile is not None:
                last_timestep = remainder_tile.time_start_ms
                xr_tile = remainder_tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            async for tile in tile_stream:
                timestep: np.datetime64 = tile.time_start_ms
                if last_timestep is None:
                    last_timestep = timestep
                elif last_timestep != timestep:
                    # time step changed: hand the new tile back as remainder
                    return tiles, tile

                xr_tile = tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            # this seems to be the last time step, so just return tiles
            return tiles, None

        (tiles, remainder_tile) = await read_tiles(None)

        while len(tiles):
            # pipeline: read the next time step's tiles while merging the
            # current time step's tiles in a worker thread
            ((new_tiles, new_remainder_tile), new_timestep_xarray) = await asyncio.gather(
                read_tiles(remainder_tile),
                backports.to_thread(RasterStreamProcessing.merge_tiles, tiles)
                # asyncio.to_thread(merge_tiles, tiles), # TODO: use this when min Python version is 3.9
            )

            tiles = new_tiles
            remainder_tile = new_remainder_tile

            if new_timestep_xarray is not None:
                timestep_xarrays.append(new_timestep_xarray)

        output: xr.DataArray = cast(
            xr.DataArray,
            # await asyncio.to_thread( # TODO: use this when min Python version is 3.9
            await backports.to_thread(
                xr.concat,
                # TODO: This is a typings error, since the method accepts also a `xr.DataArray` and returns one
                cast(List[xr.Dataset], timestep_xarrays),
                dim='time'
            )
        )

        return output
712

713
    async def vector_stream(
1✔
714
            self,
715
            query_rectangle: QueryRectangle,
716
            time_start_column: str = 'time_start',
717
            time_end_column: str = 'time_end',
718
            open_timeout: int = 60) -> AsyncIterator[gpd.GeoDataFrame]:
719
        '''Stream the workflow result as series of `GeoDataFrame`s'''
720

721
        def read_arrow_ipc(arrow_ipc: bytes) -> pa.RecordBatch:
1✔
722
            reader = pa.ipc.open_file(arrow_ipc)
1✔
723
            # We know from the backend that there is only one record batch
724
            record_batch = reader.get_record_batch(0)
1✔
725
            return record_batch
1✔
726

727
        def create_geo_data_frame(record_batch: pa.RecordBatch,
                                  time_start_column: str,
                                  time_end_column: str) -> gpd.GeoDataFrame:
            '''
            Convert an Arrow record batch into a `GeoDataFrame`.

            The geometry column (WKT) and the combined time column are replaced by a
            proper GeoSeries and two datetime columns, respectively. The CRS is taken
            from the batch's schema metadata (key `spatialReference`).
            '''
            metadata = record_batch.schema.metadata
            spatial_reference = metadata[b'spatialReference'].decode('utf-8')

            data_frame = record_batch.to_pandas()

            # parse the WKT geometry column, then drop it from the plain data frame
            geometry = gpd.GeoSeries.from_wkt(data_frame[api.GEOMETRY_COLUMN_NAME])
            del data_frame[api.GEOMETRY_COLUMN_NAME]  # delete the duplicated column

            geo_data_frame = gpd.GeoDataFrame(
                data_frame,
                geometry=geometry,
                crs=spatial_reference,
            )

            # split time column: each entry is a (start, end) pair
            geo_data_frame[[time_start_column, time_end_column]] = geo_data_frame[api.TIME_COLUMN_NAME].tolist()
            del geo_data_frame[api.TIME_COLUMN_NAME]  # delete the duplicated column

            # parse time columns (milliseconds since epoch, UTC)
            for time_column in [time_start_column, time_end_column]:
                geo_data_frame[time_column] = pd.to_datetime(
                    geo_data_frame[time_column],
                    utc=True,
                    unit='ms',
                    # TODO: solve time conversion problem from Geo Engine to Python for large (+/-) time instances
                    errors='coerce',
                )

            return geo_data_frame
759

760
        def process_bytes(batch_bytes: Optional[bytes]) -> Optional[gpd.GeoDataFrame]:
            '''
            Turn one websocket message (Arrow IPC bytes) into a `GeoDataFrame`.

            Returns `None` when there is no pending message (e.g. before the first
            read or after the stream has ended).
            '''
            if batch_bytes is None:
                return None

            # process the received data
            record_batch = read_arrow_ipc(batch_bytes)
            # column names come from the enclosing `vector_stream` call
            tile = create_geo_data_frame(
                record_batch,
                time_start_column=time_start_column,
                time_end_column=time_end_column,
            )

            return tile
773

774
        # Currently, it only works for raster results
775
        if not self.__result_descriptor.is_vector_result():
1✔
776
            raise MethodNotCalledOnVectorException()
×
777

778
        session = get_session()
1✔
779

780
        url = req.Request(
1✔
781
            'GET',
782
            url=f'{session.server_url}/workflow/{self.__workflow_id}/vectorStream',
783
            params={
784
                'resultType': 'arrow',
785
                'spatialBounds': query_rectangle.bbox_str,
786
                'timeInterval': query_rectangle.time_str,
787
                'spatialResolution': str(query_rectangle.spatial_resolution),
788
            },
789
        ).prepare().url
790

791
        if url is None:
1✔
792
            raise InputException('Invalid websocket url')
×
793

794
        async with websockets.asyncio.client.connect(
1✔
795
            uri=self.__replace_http_with_ws(url),
796
            extra_headers=session.auth_header,
797
            open_timeout=open_timeout,
798
            max_size=None,  # allow arbitrary large messages, since it is capped by the server's chunk size
799
        ) as websocket:
800

801
            batch_bytes: Optional[bytes] = None
1✔
802

803
            while websocket.state == websockets.protocol.State.OPEN:
1✔
804
                async def read_new_bytes() -> Optional[bytes]:
1✔
805
                    # already send the next request to speed up the process
806
                    try:
1✔
807
                        await websocket.send("NEXT")
1✔
808
                    except websockets.exceptions.ConnectionClosed:
×
809
                        # the websocket connection is already closed, we cannot read anymore
810
                        return None
×
811

812
                    try:
1✔
813
                        data: Union[str, bytes] = await websocket.recv()
1✔
814

815
                        if isinstance(data, str):
1✔
816
                            # the server sent an error message
817
                            raise GeoEngineException({'error': data})
×
818

819
                        return data
1✔
820
                    except websockets.exceptions.ConnectionClosedOK:
×
821
                        # the websocket connection closed gracefully, so we stop reading
822
                        return None
×
823

824
                (batch_bytes, batch) = await asyncio.gather(
1✔
825
                    read_new_bytes(),
826
                    # asyncio.to_thread(process_bytes, batch_bytes), # TODO: use this when min Python version is 3.9
827
                    backports.to_thread(process_bytes, batch_bytes),
828
                )
829

830
                if batch is not None:
1✔
831
                    yield batch
1✔
832

833
            # process the last tile
834
            batch = process_bytes(batch_bytes)
1✔
835

836
            if batch is not None:
1✔
837
                yield batch
1✔
838

839
    async def vector_stream_into_geopandas(
            self,
            query_rectangle: QueryRectangle,
            time_start_column: str = 'time_start',
            time_end_column: str = 'time_end',
            open_timeout: int = 60) -> Optional[gpd.GeoDataFrame]:
        '''
        Stream the workflow result into memory and output a single geo data frame.

        Chunks are fetched via `vector_stream` and concatenated; fetching the next
        chunk is overlapped with merging the previous one. Returns `None` when the
        stream produces no chunks at all.

        NOTE: You can run out of memory if the query rectangle is too large.
        '''

        chunk_stream = self.vector_stream(
            query_rectangle,
            time_start_column=time_start_column,
            time_end_column=time_end_column,
            open_timeout=open_timeout,
        )

        data_frame: Optional[gpd.GeoDataFrame] = None
        chunk: Optional[gpd.GeoDataFrame] = None

        async def read_dataframe() -> Optional[gpd.GeoDataFrame]:
            # `None` signals that the stream is exhausted
            try:
                return await chunk_stream.__anext__()
            except StopAsyncIteration:
                return None

        def merge_dataframes(
            df_a: Optional[gpd.GeoDataFrame],
            df_b: Optional[gpd.GeoDataFrame]
        ) -> Optional[gpd.GeoDataFrame]:
            # concatenate, treating `None` as the neutral element
            if df_a is None:
                return df_b

            if df_b is None:
                return df_a

            return pd.concat([df_a, df_b], ignore_index=True)

        while True:
            # concurrently: read the NEXT chunk while merging the PREVIOUS one
            (chunk, data_frame) = await asyncio.gather(
                read_dataframe(),
                backports.to_thread(merge_dataframes, data_frame, chunk),
                # TODO: use this when min Python version is 3.9
                # asyncio.to_thread(merge_dataframes, data_frame, chunk),
            )

            # we can stop when the chunk stream is exhausted
            # (the final chunk was already merged by the gather above)
            if chunk is None:
                break

        return data_frame
892

893
    def __replace_http_with_ws(self, url: str) -> str:
        '''
        Rewrite the scheme of `url` from HTTP(S) to the websocket equivalent.

        The websockets library requires `ws://` URLs; secure HTTP maps to `wss://`.
        '''
        scheme, remainder = url.split('://', maxsplit=1)

        is_secure = 's' in scheme.lower()

        return ('wss://' if is_secure else 'ws://') + remainder
906

907

908
def register_workflow(workflow: Union[Dict[str, Any], WorkflowBuilderOperator], timeout: int = 60) -> Workflow:
    '''
    Register a workflow in Geo Engine and receive a `WorkflowId`.

    Accepts either a raw workflow dictionary or a builder operator, which is
    first converted into its dictionary representation.

    Raises `InputException` if the definition cannot be parsed into a workflow.
    '''

    workflow_dict = workflow.to_workflow_dict() if isinstance(workflow, WorkflowBuilderOperator) else workflow

    workflow_model = geoengine_openapi_client.Workflow.from_dict(workflow_dict)
    if workflow_model is None:
        raise InputException("Invalid workflow definition")

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        response = geoengine_openapi_client.WorkflowsApi(api_client).register_workflow_handler(
            workflow_model,
            _request_timeout=timeout,
        )

    return Workflow(WorkflowId.from_response(response))
928

929

930
def workflow_by_id(workflow_id: UUID) -> Workflow:
    '''
    Wrap an existing workflow id into a `Workflow` object.

    No server round-trip is performed, so the id is not validated.
    '''

    # TODO: check that workflow exists

    wrapped_id = WorkflowId(workflow_id)
    return Workflow(wrapped_id)
938

939

940
def get_quota(user_id: Optional[UUID] = None, timeout: int = 60) -> geoengine_openapi_client.Quota:
    '''
    Gets a user's quota. Only admins can get other users' quota.

    With `user_id=None` the quota of the current session's user is returned;
    otherwise the quota of the given user is requested (admin only).
    '''

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)

        if user_id is not None:
            return user_api.get_user_quota_handler(str(user_id), _request_timeout=timeout)

        return user_api.quota_handler(_request_timeout=timeout)
954

955

956
def update_quota(user_id: UUID, new_available_quota: int, timeout: int = 60) -> None:
    '''
    Update a user's quota. Only admins can perform this operation.

    Sets the available quota of `user_id` to `new_available_quota`.
    '''

    session = get_session()

    quota_update = geoengine_openapi_client.UpdateQuota(available=new_available_quota)

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        geoengine_openapi_client.UserApi(api_client).update_user_quota_handler(
            str(user_id),
            quota_update,
            _request_timeout=timeout,
        )
972

973

974
def data_usage(offset: int = 0, limit: int = 10) -> pd.DataFrame:
    '''
    Get data usage entries as a `DataFrame`.

    Parameters
    ----------
    offset : pagination offset into the usage log
    limit : maximum number of entries to fetch

    Returns a `pandas.DataFrame` with one row per usage entry; the
    `timestamp` column (if present) is parsed into timezone-aware datetimes.

    NOTE: the return annotation previously claimed
    `List[geoengine_openapi_client.DataUsage]`, but the function has always
    returned a `DataFrame` — the annotation is now consistent with
    `data_usage_summary`.
    '''

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)
        response = user_api.data_usage_handler(offset=offset, limit=limit)

        # create dataframe from response
        usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
        df = pd.DataFrame(usage_dicts)
        if 'timestamp' in df.columns:
            df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)

    return df
992

993

994
def data_usage_summary(granularity: geoengine_openapi_client.UsageSummaryGranularity,
                       dataset: Optional[str] = None,
                       offset: int = 0, limit: int = 10) -> pd.DataFrame:
    '''
    Get a data usage summary as a `DataFrame`.

    Aggregates usage at the given `granularity`, optionally restricted to a
    single `dataset`, and parses the `timestamp` column (if present) into
    timezone-aware datetimes.
    '''

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        user_api = geoengine_openapi_client.UserApi(api_client)
        summaries = user_api.data_usage_summary_handler(
            dataset=dataset,
            granularity=granularity,
            offset=offset,
            limit=limit,
        )

        # flatten the response models into plain dicts for pandas
        records = [summary.model_dump(by_alias=True) for summary in summaries]
        frame = pd.DataFrame(records)
        if 'timestamp' in frame.columns:
            frame['timestamp'] = pd.to_datetime(frame['timestamp'], utc=True)

    return frame
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc