• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 14441260704

14 Apr 2025 08:40AM UTC coverage: 75.452% (-1.2%) from 76.67%
14441260704

Pull #221

github

web-flow
Merge a9db07509 into 89c260aaf
Pull Request #221: Pixel_based_queries_rewrite

2837 of 3760 relevant lines covered (75.45%)

0.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.34
geoengine/workflow.py
1
'''
2
A workflow representation and methods on workflows
3
'''
4
# pylint: disable=too-many-lines
5
# TODO: split into multiple files
6

7
from __future__ import annotations
1✔
8

9
import asyncio
1✔
10
from collections import defaultdict
1✔
11
import json
1✔
12
from io import BytesIO
1✔
13
from logging import debug
1✔
14
from os import PathLike
1✔
15
from typing import Any, AsyncIterator, Dict, List, Optional, Union, Type, cast, TypedDict
1✔
16
from uuid import UUID
1✔
17

18
import geoengine_openapi_client as geoc
1✔
19
import geopandas as gpd
1✔
20
import pandas as pd
1✔
21
import numpy as np
1✔
22
import rasterio.io
1✔
23
import requests as req
1✔
24
import rioxarray
1✔
25
from PIL import Image
1✔
26
from owslib.util import Authentication, ResponseWrapper
1✔
27
from owslib.wcs import WebCoverageService
1✔
28
from vega import VegaLite
1✔
29
import websockets
1✔
30
import websockets.client
1✔
31
import xarray as xr
1✔
32
import pyarrow as pa
1✔
33

34

35
from geoengine import api
1✔
36
from geoengine.auth import get_session
1✔
37
from geoengine.error import GeoEngineException, InputException, MethodNotCalledOnPlotException, \
1✔
38
    MethodNotCalledOnRasterException, MethodNotCalledOnVectorException, OGCXMLError
39
from geoengine import backports
1✔
40
from geoengine.types import ProvenanceEntry, QueryRectangle, QueryRectangleWithResolution, RasterColorizer, \
1✔
41
    ResultDescriptor, SpatialPartition2D, VectorResultDescriptor, ClassificationMeasurement
42
from geoengine.tasks import Task, TaskId
1✔
43
from geoengine.workflow_builder.operators import Operator as WorkflowBuilderOperator
1✔
44
from geoengine.raster import RasterTile2D
1✔
45

46

47
# TODO: Define as recursive type when supported in mypy: https://github.com/python/mypy/issues/731
# One level of JSON-compatible values (the TODO above tracks making it recursive).
JsonType = Union[Dict[str, Any], List[Any], int, str, float, bool, Type[None]]

# Typed fragments of the Vega-Lite plot spec that `Workflow.plot_chart` parses
# from the server's 'vegaString'. The functional `TypedDict` syntax is required
# because keys such as '$schema' are not valid Python identifiers.
Axis = TypedDict('Axis', {'title': str})
Bin = TypedDict('Bin', {'binned': bool, 'step': float})
Field = TypedDict('Field', {'field': str})
# Id pair — presumably an upload and the dataset created from it (verify at call sites).
DatasetIds = TypedDict('DatasetIds', {'upload': UUID, 'dataset': UUID})
Values = TypedDict('Values', {'binStart': float, 'binEnd': float, 'Frequency': int})
X = TypedDict('X', {'field': Field, 'bin': Bin, 'axis': Axis})
X2 = TypedDict('X2', {'field': Field})
Y = TypedDict('Y', {'field': Field, 'type': str})
Encoding = TypedDict('Encoding', {'x': X, 'x2': X2, 'y': Y})
VegaSpec = TypedDict('VegaSpec', {'$schema': str, 'data': List[Values], 'mark': str, 'encoding': Encoding})

61

62
class WorkflowId:
    '''
    Thin wrapper around the UUID that identifies a workflow on the server.
    '''

    __workflow_id: UUID

    def __init__(self, workflow_id: UUID) -> None:
        self.__workflow_id = workflow_id

    @classmethod
    def from_response(cls, response: geoc.IdResponse) -> WorkflowId:
        '''
        Build a `WorkflowId` from the `id` field of an http response
        '''
        return WorkflowId(UUID(response.id))

    def __str__(self) -> str:
        return f'{self.__workflow_id}'

    def __repr__(self) -> str:
        return self.__str__()
84

85

86
class RasterStreamProcessing:
    '''
    Static helpers for decoding and assembling raster stream data
    '''

    @classmethod
    def read_arrow_ipc(cls, arrow_ipc: bytes) -> pa.RecordBatch:
        '''Decode the single record batch contained in an Arrow IPC byte buffer'''

        # the backend guarantees exactly one record batch per IPC file
        return pa.ipc.open_file(arrow_ipc).get_record_batch(0)

    @classmethod
    def process_bytes(cls, tile_bytes: Optional[bytes]) -> Optional[RasterTile2D]:
        '''Decode a `RasterTile2D` from raw bytes; passes `None` through'''

        if tile_bytes is None:
            return None

        # decode the Arrow IPC payload and wrap it as a tile
        batch = RasterStreamProcessing.read_arrow_ipc(tile_bytes)
        return RasterTile2D.from_ge_record_batch(batch)

    @classmethod
    def merge_tiles(cls, tiles: List[xr.DataArray]) -> Optional[xr.DataArray]:
        '''Merge a list of spatial tiles into one array with a band dimension'''

        if not tiles:
            return None

        # bucket the tiles by their band coordinate
        # ('band' is assumed to be a coordinate holding a single value)
        per_band: Dict[int, List[xr.DataArray]] = defaultdict(list)
        for tile in tiles:
            per_band[tile.band.item()].append(tile)

        # stitch the tiles of each band into one spatial array
        merged_bands = []
        for band_tiles in per_band.values():
            merged = xr.combine_by_coords(band_tiles)
            # `combine_by_coords` always returns a `DataArray` for single
            # variable input arrays; the assertion narrows the type for mypy
            assert isinstance(merged, xr.DataArray)
            merged_bands.append(merged)

        # stack the per-band arrays along the band dimension
        return xr.concat(merged_bands, dim='band')
139

140

141
class Workflow:
1✔
142
    '''
143
    Holds a workflow id and allows querying data
144
    '''
145

146
    __workflow_id: WorkflowId
1✔
147
    __result_descriptor: ResultDescriptor
1✔
148

149
    def __init__(self, workflow_id: WorkflowId) -> None:
        # Eagerly fetch the result metadata once so later calls (e.g.
        # `get_result_descriptor` and the raster/vector/plot checks) can use
        # the cached descriptor without another server round trip.
        self.__workflow_id = workflow_id
        self.__result_descriptor = self.__query_result_descriptor()
152

153
    def __str__(self) -> str:
1✔
154
        return str(self.__workflow_id)
1✔
155

156
    def __repr__(self) -> str:
1✔
157
        return repr(self.__workflow_id)
1✔
158

159
    def __query_result_descriptor(self, timeout: int = 60) -> ResultDescriptor:
        '''
        Fetch the metadata of the workflow result from the server
        '''

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            response = geoc.WorkflowsApi(api_client).get_workflow_metadata_handler(
                str(self.__workflow_id),
                _request_timeout=timeout
            )

        debug(response)

        return ResultDescriptor.from_response(response)
173

174
    def get_result_descriptor(self) -> ResultDescriptor:
1✔
175
        '''
176
        Return the metadata of the workflow result
177
        '''
178

179
        return self.__result_descriptor
1✔
180

181
    def workflow_definition(self, timeout: int = 60) -> geoc.Workflow:
        '''Fetch the operator graph that defines this workflow'''

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            workflow = geoc.WorkflowsApi(api_client).load_workflow_handler(
                str(self.__workflow_id),
                _request_timeout=timeout
            )

        return workflow
191

192
    def get_dataframe(
            self,
            bbox: QueryRectangle,
            timeout: int = 3600,
            resolve_classifications: bool = False
    ) -> gpd.GeoDataFrame:
        '''
        Query a workflow and return the WFS result as a GeoPandas `GeoDataFrame`

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        resolve_classifications : If True, map classification column values to their class names

        Raises
        ------
        MethodNotCalledOnVectorException if the workflow result is not a vector result
        '''

        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        # only forward a resolution if the rectangle actually carries one
        qres = None
        if isinstance(bbox, QueryRectangleWithResolution):
            qres = str(bbox.spatial_resolution)

        with geoc.ApiClient(session.configuration) as api_client:
            wfs_api = geoc.OGCWFSApi(api_client)
            response = wfs_api.wfs_feature_handler(
                workflow=str(self.__workflow_id),
                service=geoc.WfsService(geoc.WfsService.WFS),
                request=geoc.GetFeatureRequest(
                    geoc.GetFeatureRequest.GETFEATURE
                ),
                type_names=str(self.__workflow_id),
                bbox=bbox.bbox_str,
                version=geoc.WfsVersion(geoc.WfsVersion.ENUM_2_DOT_0_DOT_0),
                time=bbox.time_str,
                srs_name=bbox.srs,
                query_resolution=qres,
                _request_timeout=timeout
            )

        def geo_json_with_time_to_geopandas(geo_json):
            '''
            GeoJson has no standard for time, so we parse the when field
            separately and attach it to the data frame as columns `start`
            and `end`.
            '''

            data = gpd.GeoDataFrame.from_features(geo_json)
            data = data.set_crs(bbox.srs, allow_override=True)

            # every feature carries a 'when' object with 'start'/'end'
            start = [f['when']['start'] for f in geo_json['features']]
            end = [f['when']['end'] for f in geo_json['features']]

            # TODO: find a good way to infer BoT/EoT

            # unparsable instants become NaT due to errors='coerce'
            data['start'] = gpd.pd.to_datetime(start, errors='coerce')
            data['end'] = gpd.pd.to_datetime(end, errors='coerce')

            return data

        def transform_classifications(data: gpd.GeoDataFrame):
            # Replace numeric class values with their class names for every
            # column whose measurement is a classification.
            result_descriptor: VectorResultDescriptor = self.__result_descriptor  # type: ignore
            for (column, info) in result_descriptor.columns.items():
                if isinstance(info.measurement, ClassificationMeasurement):
                    measurement: ClassificationMeasurement = info.measurement
                    classes = measurement.classes
                    # `apply` evaluates eagerly within this iteration, so the
                    # late-binding closure over `classes` is safe here
                    data[column] = data[column].apply(lambda x: classes[x])  # pylint: disable=cell-var-from-loop

            return data

        result = geo_json_with_time_to_geopandas(response.to_dict())

        if resolve_classifications:
            result = transform_classifications(result)

        return result
264

265
    def wms_get_map_as_image(
            self, bbox: QueryRectangleWithResolution, raster_colorizer: RasterColorizer
    ) -> Image.Image:
        '''Perform a WMS GetMap request and decode the response into a PIL Image'''

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # pixel dimensions follow from the spatial extent and the resolution
        spatial = bbox.spatial_bounds
        resolution = bbox.spatial_resolution
        width_px = int((spatial.xmax - spatial.xmin) / resolution.x_resolution)
        height_px = int((spatial.ymax - spatial.ymin) / resolution.y_resolution)

        with geoc.ApiClient(session.configuration) as api_client:
            wms_api = geoc.OGCWMSApi(api_client)
            response = wms_api.wms_map_handler(
                workflow=str(self),
                version=geoc.WmsVersion(geoc.WmsVersion.ENUM_1_DOT_3_DOT_0),
                service=geoc.WmsService(geoc.WmsService.WMS),
                request=geoc.GetMapRequest(geoc.GetMapRequest.GETMAP),
                width=width_px,
                height=height_px,
                bbox=bbox.bbox_ogc_str,
                format=geoc.GetMapFormat(geoc.GetMapFormat.IMAGE_SLASH_PNG),
                layers=str(self),
                styles='custom:' + raster_colorizer.to_api_dict().to_json(),
                crs=bbox.srs,
                time=bbox.time_str
            )

        # the server signals failures as an OGC XML error document
        if OGCXMLError.is_ogc_error(response):
            raise OGCXMLError(response)

        return Image.open(BytesIO(response))
296

297
    def plot_json(self, bbox: QueryRectangleWithResolution, timeout: int = 3600) -> geoc.WrappedPlotOutput:
        '''
        Execute the plot workflow and return the raw `WrappedPlotOutput`
        '''

        if not self.__result_descriptor.is_plot_result():
            raise MethodNotCalledOnPlotException()

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            return geoc.PlotsApi(api_client).get_plot_handler(
                bbox.bbox_str,
                bbox.time_str,
                str(bbox.spatial_resolution),
                str(self.__workflow_id),
                bbox.srs,
                _request_timeout=timeout
            )
317

318
    def plot_chart(self, bbox: QueryRectangleWithResolution, timeout: int = 3600) -> VegaLite:
        '''
        Query a plot workflow and render the result as a Vega-Lite chart
        '''

        plot_output = self.plot_json(bbox, timeout)

        # the Vega specification arrives as an embedded JSON string
        spec: VegaSpec = json.loads(plot_output.data['vegaString'])

        return VegaLite(spec)
327

328
    def __request_wcs(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        file_format: str = 'image/tiff',
        force_no_data_value: Optional[float] = None
    ) -> ResponseWrapper:
        '''
        Query a workflow and return the coverage

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # TODO: properly build CRS string for bbox
        # e.g. 'EPSG:4326' -> 'urn:ogc:def:crs:EPSG::4326'
        crs = f'urn:ogc:def:crs:{bbox.srs.replace(":", "::")}'

        # OWSLib performs the WCS request; auth is delegated to the session's bearer token
        wcs_url = f'{session.server_url}/wcs/{self.__workflow_id}'
        wcs = WebCoverageService(
            wcs_url,
            version='1.1.1',
            auth=Authentication(auth_delegate=session.requests_bearer_auth()),
        )

        # resolution is only sent when the rectangle carries one
        resx = None
        resy = None
        if isinstance(bbox, QueryRectangleWithResolution):
            [resx, resy] = bbox.resolution_ogc

        # optional parameters are collected and splatted into `getCoverage`
        kwargs = {}

        if force_no_data_value is not None:
            kwargs["nodatavalue"] = str(float(force_no_data_value))
        if resx is not None:
            kwargs["resx"] = str(resx)
        if resy is not None:
            kwargs["resy"] = str(resy)

        return wcs.getCoverage(
            identifier=f'{self.__workflow_id}',
            bbox=bbox.bbox_ogc,
            time=[bbox.time_str],
            format=file_format,
            crs=crs,
            timeout=timeout,
            **kwargs
        )
385

386
    def __get_wcs_tiff_as_memory_file(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        force_no_data_value: Optional[float] = None
    ) -> rasterio.io.MemoryFile:
        '''
        Query a workflow and wrap the GeoTiff result in a rasterio memory file

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        # errors were already raised via `raise_on_error` in `getCoverage` / `openUrl`
        tiff_bytes = self.__request_wcs(bbox, timeout, 'image/tiff', force_no_data_value).read()

        return rasterio.io.MemoryFile(tiff_bytes)
410

411
    def get_array(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        force_no_data_value: Optional[float] = None
    ) -> np.ndarray:
        '''
        Query a workflow and return the raster result as a numpy array

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        memory_file = self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value)

        with memory_file as memfile, memfile.open() as dataset:
            # rasterio band indices are 1-based; read the first band
            return dataset.read(1)
436

437
    def get_xarray(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        force_no_data_value: Optional[float] = None
    ) -> xr.DataArray:
        '''
        Query a workflow and return the raster result as a georeferenced xarray

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        with self.__get_wcs_tiff_as_memory_file(
            bbox,
            timeout,
            force_no_data_value
        ) as memfile, memfile.open() as dataset:
            data_array = rioxarray.open_rasterio(dataset)

            # helping mypy with inference
            assert isinstance(data_array, xr.DataArray)

            # copy the georeferencing info from the rio accessor into the
            # array's attrs so it survives later operations
            # NOTE(review): the `xr.DataArray` annotation on `rio` looks
            # imprecise — `.rio` is the rioxarray accessor; confirm
            rio: xr.DataArray = data_array.rio
            rio.update_attrs({
                'crs': rio.crs,
                'res': rio.resolution(),
                'transform': rio.transform(),
            }, inplace=True)

            # TODO: add time information to dataset
            # load eagerly — the backing memory file is closed when the
            # `with` block exits
            return data_array.load()
473

474
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def download_raster(
        self,
        bbox: QueryRectangle,
        file_path: str,
        timeout=3600,
        file_format: str = 'image/tiff',
        force_no_data_value: Optional[float] = None
    ) -> None:
        '''
        Query a workflow and write the raster result to a file on disk

        Parameters
        ----------
        bbox : A bounding box for the query
        file_path : The path to the file to save the raster to
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        '''

        coverage = self.__request_wcs(bbox, timeout, file_format, force_no_data_value)

        with open(file_path, 'wb') as out_file:
            out_file.write(coverage.read())
500

501
    def get_provenance(self, timeout: int = 60) -> List[ProvenanceEntry]:
        '''
        Fetch the provenance information of this workflow's sources
        '''

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            workflows_api = geoc.WorkflowsApi(api_client)
            items = workflows_api.get_workflow_provenance_handler(
                str(self.__workflow_id),
                _request_timeout=timeout
            )

        return list(map(ProvenanceEntry.from_response, items))
513

514
    def metadata_zip(self, path: Union[PathLike, BytesIO], timeout: int = 60) -> None:
        '''
        Query workflow metadata and citations and store them as a zip file at `path`
        '''

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            workflows_api = geoc.WorkflowsApi(api_client)
            response = workflows_api.get_workflow_all_metadata_zip_handler(
                str(self.__workflow_id),
                _request_timeout=timeout
            )

        # `path` is either an in-memory buffer or a filesystem path
        if isinstance(path, BytesIO):
            path.write(response)
            return

        with open(path, 'wb') as file:
            file.write(response)
533

534
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def save_as_dataset(
            self,
            query_rectangle: QueryRectangle,
            name: Optional[str],
            display_name: str,
            description: str = '',
            timeout: int = 3600) -> Task:
        '''
        Init task to store the workflow result as a layer

        Parameters
        ----------
        query_rectangle : spatial and temporal bounds of the data to store
        name : dataset name
        display_name : human-readable dataset name
        description : dataset description
        timeout : HTTP request timeout in seconds

        Returns
        -------
        The `Task` handle of the server-side storage task.

        Raises
        ------
        MethodNotCalledOnRasterException if the workflow result is not a raster
        '''

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # NOTE(review): this warns when the argument is NOT a `QueryRectangle`;
        # verify whether a subclass check (e.g. a rectangle with resolution)
        # was intended instead.
        if not isinstance(query_rectangle, QueryRectangle):
            print("save_as_dataset ignores params other than spatial and temporal bounds.")

        # only the spatial and temporal bounds are forwarded to the server
        qrect = geoc.models.raster_to_dataset_query_rectangle.RasterToDatasetQueryRectangle(
            spatial_bounds=SpatialPartition2D.from_bounding_box(query_rectangle.spatial_bounds).to_api_dict(),
            time_interval=query_rectangle.time.to_api_dict()
        )

        with geoc.ApiClient(session.configuration) as api_client:
            workflows_api = geoc.WorkflowsApi(api_client)
            response = workflows_api.dataset_from_workflow_handler(
                str(self.__workflow_id),
                geoc.RasterDatasetFromWorkflow(
                    name=name,
                    display_name=display_name,
                    description=description,
                    query=qrect
                ),
                _request_timeout=timeout
            )

        return Task(TaskId.from_response(response))
572

573
    async def raster_stream(
        self,
        query_rectangle: QueryRectangle,
        open_timeout: int = 60,
        bands: Optional[List[int]] = None  # TODO: move into query rectangle?
    ) -> AsyncIterator[RasterTile2D]:
        '''
        Stream the workflow result as series of RasterTile2D (transformable to numpy and xarray)

        Parameters
        ----------
        query_rectangle : spatial and temporal bounds of the query
        open_timeout : timeout in seconds for opening the websocket connection
        bands : band indices to query; defaults to `[0]`

        Raises
        ------
        InputException if no band is specified or the websocket url is invalid
        MethodNotCalledOnRasterException if the workflow result is not a raster
        GeoEngineException if the server sends an error message
        '''

        # NOTE(review): this warns when the argument is NOT a `QueryRectangle`;
        # verify whether a subclass check was intended instead.
        if not isinstance(query_rectangle, QueryRectangle):
            print("raster_stream ignores params other than spatial and temporal bounds.")

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException('At least one band must be specified')

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # build the query url via `requests` to get proper parameter encoding
        url = req.Request(
            'GET',
            url=f'{session.server_url}/workflow/{self.__workflow_id}/rasterStream',
            params={
                'resultType': 'arrow',
                'spatialBounds': query_rectangle.bbox_str,
                'timeInterval': query_rectangle.time_str,
                'attributes': ','.join(map(str, bands))
            },
        ).prepare().url

        if url is None:
            raise InputException('Invalid websocket url')

        async with websockets.client.connect(
            uri=self.__replace_http_with_ws(url),
            extra_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,
        ) as websocket:

            # defined once outside the loop — it only closes over `websocket`
            async def read_new_bytes() -> Optional[bytes]:
                '''Request the next tile; `None` when the stream has ended'''
                # already send the next request to speed up the process
                try:
                    await websocket.send("NEXT")
                except websockets.exceptions.ConnectionClosed:
                    # the websocket connection is already closed, we cannot read anymore
                    return None

                try:
                    data: Union[str, bytes] = await websocket.recv()

                    if isinstance(data, str):
                        # the server sent an error message
                        raise GeoEngineException({'error': data})

                    return data
                except websockets.exceptions.ConnectionClosedOK:
                    # the websocket connection closed gracefully, so we stop reading
                    return None

            tile_bytes: Optional[bytes] = None

            while websocket.open:
                # fetch the next tile and decode the previous one concurrently
                (tile_bytes, tile) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, tile_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(RasterStreamProcessing.process_bytes, tile_bytes),
                )

                if tile is not None:
                    yield tile

            # process the last tile
            tile = RasterStreamProcessing.process_bytes(tile_bytes)

            if tile is not None:
                yield tile
654

655
    async def raster_stream_into_xarray(
        self,
        query_rectangle: QueryRectangle,
        clip_to_query_rectangle: bool = False,
        open_timeout: int = 60,
        bands: Optional[List[int]] = None  # TODO: move into query rectangle?
    ) -> xr.DataArray:
        '''
        Stream the workflow result into memory and output a single xarray.

        NOTE: You can run out of memory if the query rectangle is too large.

        Parameters
        ----------
        query_rectangle : spatial and temporal bounds of the query
        clip_to_query_rectangle : If True, clip each tile to the query's spatial bounds
        open_timeout : timeout in seconds for opening the websocket connection
        bands : band indices to query; defaults to `[0]`

        Raises
        ------
        InputException if no band is specified
        '''

        if bands is None:
            bands = [0]

        if len(bands) == 0:
            raise InputException('At least one band must be specified')

        tile_stream = self.raster_stream(
            query_rectangle,
            open_timeout=open_timeout,
            bands=bands
        )

        # one merged array per time step, concatenated along 'time' at the end
        timestep_xarrays: List[xr.DataArray] = []

        spatial_clip_bounds = query_rectangle.spatial_bounds if clip_to_query_rectangle else None

        async def read_tiles(
            remainder_tile: Optional[RasterTile2D]
        ) -> tuple[List[xr.DataArray], Optional[RasterTile2D]]:
            # Collect all tiles that share one time step. When a tile with a
            # new time step arrives, it is returned as the remainder so the
            # next call can start with it.
            last_timestep: Optional[np.datetime64] = None
            tiles = []

            if remainder_tile is not None:
                last_timestep = remainder_tile.time_start_ms
                xr_tile = remainder_tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            async for tile in tile_stream:
                timestep: np.datetime64 = tile.time_start_ms
                if last_timestep is None:
                    last_timestep = timestep
                elif last_timestep != timestep:
                    # first tile of the next time step — hand it back as remainder
                    return tiles, tile

                xr_tile = tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            # this seems to be the last time step, so just return tiles
            return tiles, None

        (tiles, remainder_tile) = await read_tiles(None)

        while len(tiles):
            # pipeline: read the next time step while merging the previous one
            # in a worker thread
            ((new_tiles, new_remainder_tile), new_timestep_xarray) = await asyncio.gather(
                read_tiles(remainder_tile),
                backports.to_thread(RasterStreamProcessing.merge_tiles, tiles)
                # asyncio.to_thread(merge_tiles, tiles), # TODO: use this when min Python version is 3.9
            )

            tiles = new_tiles
            remainder_tile = new_remainder_tile

            if new_timestep_xarray is not None:
                timestep_xarrays.append(new_timestep_xarray)

        # concatenate all time steps into the final array (off-thread as well)
        output: xr.DataArray = cast(
            xr.DataArray,
            # await asyncio.to_thread( # TODO: use this when min Python version is 3.9
            await backports.to_thread(
                xr.concat,
                # TODO: This is a typings error, since the method accepts also a `xr.DataArray` and returns one
                cast(List[xr.Dataset], timestep_xarrays),
                dim='time'
            )
        )

        return output
735

736
    async def vector_stream(
            self,
            query_rectangle: QueryRectangle,
            time_start_column: str = 'time_start',
            time_end_column: str = 'time_end',
            open_timeout: int = 60) -> AsyncIterator[gpd.GeoDataFrame]:
        '''
        Stream the workflow result as a series of `GeoDataFrame`s.

        The chunks are transferred over a websocket connection in the Apache Arrow
        IPC file format, one record batch per binary message.

        Parameters
        ----------
        query_rectangle : spatial/temporal bounds of the query; if it is a
            `QueryRectangleWithResolution`, the resolution is forwarded to the server
        time_start_column : name of the output column holding the interval start
        time_end_column : name of the output column holding the interval end
        open_timeout : timeout (in seconds) for opening the websocket connection

        Raises
        ------
        MethodNotCalledOnVectorException : if the workflow result is not vector data
        GeoEngineException : if the server sends an error message over the websocket
        InputException : if no valid websocket url could be constructed
        '''

        def read_arrow_ipc(arrow_ipc: bytes) -> pa.RecordBatch:
            '''Deserialize one Arrow IPC file message into its single record batch.'''
            reader = pa.ipc.open_file(arrow_ipc)
            # We know from the backend that there is only one record batch
            record_batch = reader.get_record_batch(0)
            return record_batch

        def create_geo_data_frame(record_batch: pa.RecordBatch,
                                  time_start_column: str,
                                  time_end_column: str) -> gpd.GeoDataFrame:
            '''Convert one record batch into a `GeoDataFrame` with parsed geometry and time columns.'''
            # the CRS is transported as schema metadata, not as a column
            metadata = record_batch.schema.metadata
            spatial_reference = metadata[b'spatialReference'].decode('utf-8')

            data_frame = record_batch.to_pandas()

            # geometries arrive as WKT strings in a regular column
            geometry = gpd.GeoSeries.from_wkt(data_frame[api.GEOMETRY_COLUMN_NAME])
            del data_frame[api.GEOMETRY_COLUMN_NAME]  # delete the duplicated column

            geo_data_frame = gpd.GeoDataFrame(
                data_frame,
                geometry=geometry,
                crs=spatial_reference,
            )

            # split time column (a [start, end] pair) into two separate columns
            geo_data_frame[[time_start_column, time_end_column]] = geo_data_frame[api.TIME_COLUMN_NAME].tolist()
            del geo_data_frame[api.TIME_COLUMN_NAME]  # delete the duplicated column

            # parse time columns (epoch milliseconds) into UTC datetimes
            for time_column in [time_start_column, time_end_column]:
                geo_data_frame[time_column] = pd.to_datetime(
                    geo_data_frame[time_column],
                    utc=True,
                    unit='ms',
                    # TODO: solve time conversion problem from Geo Engine to Python for large (+/-) time instances
                    errors='coerce',
                )

            return geo_data_frame

        def process_bytes(batch_bytes: Optional[bytes]) -> Optional[gpd.GeoDataFrame]:
            '''Turn one raw websocket message into a `GeoDataFrame`; `None` passes through.'''
            if batch_bytes is None:
                return None

            # process the received data
            record_batch = read_arrow_ipc(batch_bytes)
            tile = create_geo_data_frame(
                record_batch,
                time_start_column=time_start_column,
                time_end_column=time_end_column,
            )

            return tile

        # this streaming endpoint is only available for vector results
        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        params = {
            'resultType': 'arrow',
            'spatialBounds': query_rectangle.bbox_str,
            'timeInterval': query_rectangle.time_str
        }

        if isinstance(query_rectangle, QueryRectangleWithResolution):
            params['spatialResolution'] = str(query_rectangle.spatial_resolution)

        # build the query url via a prepared (but never sent) HTTP request
        url = req.Request(
            'GET',
            url=f'{session.server_url}/workflow/{self.__workflow_id}/vectorStream',
            params=params
        ).prepare().url

        if url is None:
            raise InputException('Invalid websocket url')

        async with websockets.client.connect(
            uri=self.__replace_http_with_ws(url),
            extra_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,  # allow arbitrary large messages, since it is capped by the server's chunk size
        ) as websocket:

            # holds the bytes of the previously received message; decoded one
            # iteration later so that receive and decode can overlap
            batch_bytes: Optional[bytes] = None

            while websocket.open:
                async def read_new_bytes() -> Optional[bytes]:
                    '''Request and await the next message; `None` signals a closed connection.'''
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: Union[str, bytes] = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({'error': data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # pipeline: fetch the next message while decoding the previous
                # one in a worker thread
                (batch_bytes, batch) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, batch_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(process_bytes, batch_bytes),
                )

                if batch is not None:
                    yield batch

            # process the last tile
            batch = process_bytes(batch_bytes)

            if batch is not None:
                yield batch
    async def vector_stream_into_geopandas(
            self,
            query_rectangle: QueryRectangle,
            time_start_column: str = 'time_start',
            time_end_column: str = 'time_end',
            open_timeout: int = 60) -> Optional[gpd.GeoDataFrame]:
        '''
        Stream the workflow result into memory and output a single geo data frame.

        Consumes `vector_stream` and concatenates all chunks. Returns `None` if
        the stream yields no chunks at all.

        Parameters
        ----------
        query_rectangle : spatial/temporal bounds of the query
        time_start_column : name of the output column holding the interval start
        time_end_column : name of the output column holding the interval end
        open_timeout : timeout (in seconds) for opening the websocket connection

        NOTE: You can run out of memory if the query rectangle is too large.
        '''

        chunk_stream = self.vector_stream(
            query_rectangle,
            time_start_column=time_start_column,
            time_end_column=time_end_column,
            open_timeout=open_timeout,
        )

        # accumulated result and the most recently read chunk
        data_frame: Optional[gpd.GeoDataFrame] = None
        chunk: Optional[gpd.GeoDataFrame] = None

        async def read_dataframe() -> Optional[gpd.GeoDataFrame]:
            '''Fetch the next chunk; `None` signals an exhausted stream.'''
            try:
                return await chunk_stream.__anext__()
            except StopAsyncIteration:
                return None

        def merge_dataframes(
            df_a: Optional[gpd.GeoDataFrame],
            df_b: Optional[gpd.GeoDataFrame]
        ) -> Optional[gpd.GeoDataFrame]:
            '''Concatenate two optional frames; a `None` side is treated as empty.'''
            if df_a is None:
                return df_b

            if df_b is None:
                return df_a

            return pd.concat([df_a, df_b], ignore_index=True)

        while True:
            # pipeline: read the next chunk while merging the previous one in a
            # worker thread
            (chunk, data_frame) = await asyncio.gather(
                read_dataframe(),
                backports.to_thread(merge_dataframes, data_frame, chunk),
                # TODO: use this when min Python version is 3.9
                # asyncio.to_thread(merge_dataframes, data_frame, chunk),
            )

            # we can stop when the chunk stream is exhausted
            if chunk is None:
                break

        return data_frame
    def __replace_http_with_ws(self, url: str) -> str:
        '''
        Replace the protocol in the url from `http(s)` to `ws(s)`.

        For the websockets library, it is necessary that the url starts with `ws://`.
        For HTTPS, we need to use `wss://` instead.

        Raises a `ValueError` if the url contains no `://` separator.
        '''

        [protocol, url_part] = url.split('://', maxsplit=1)

        # Only secure protocols map to `wss://`. The previous check
        # (`'s' in protocol`) would also have upgraded a plain `ws://` url
        # (and any other protocol containing an `s`) to `wss://`.
        ws_prefix = 'wss://' if protocol.lower() in ('https', 'wss') else 'ws://'

        return f'{ws_prefix}{url_part}'

934

935
def register_workflow(workflow: Union[Dict[str, Any], WorkflowBuilderOperator], timeout: int = 60) -> Workflow:
    '''
    Register a workflow in Geo Engine and receive a `WorkflowId`.

    Accepts either a raw workflow dictionary or a `WorkflowBuilderOperator`,
    which is converted to a dictionary first. Raises an `InputException` if the
    definition cannot be parsed into a workflow model.
    '''

    # normalize the input to a plain workflow dictionary
    workflow_dict = workflow.to_workflow_dict() if isinstance(workflow, WorkflowBuilderOperator) else workflow

    model = geoc.Workflow.from_dict(workflow_dict)
    if model is None:
        raise InputException("Invalid workflow definition")

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        response = geoc.WorkflowsApi(api_client).register_workflow_handler(model, _request_timeout=timeout)

    return Workflow(WorkflowId.from_response(response))
def workflow_by_id(workflow_id: UUID) -> Workflow:
    '''Build a `Workflow` handle from an existing workflow id.'''

    # TODO: check that workflow exists

    wrapped_id = WorkflowId(workflow_id)
    return Workflow(wrapped_id)
def get_quota(user_id: Optional[UUID] = None, timeout: int = 60) -> geoc.Quota:
    '''
    Gets a user's quota. Only admins can get other users' quota.

    Without a `user_id`, the quota of the current session's user is returned.
    '''

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        user_api = geoc.UserApi(api_client)

        if user_id is not None:
            # admin-only: look up another user's quota
            return user_api.get_user_quota_handler(str(user_id), _request_timeout=timeout)

        return user_api.quota_handler(_request_timeout=timeout)
def update_quota(user_id: UUID, new_available_quota: int, timeout: int = 60) -> None:
    '''
    Update a user's quota. Only admins can perform this operation.
    '''

    session = get_session()

    quota_update = geoc.UpdateQuota(available=new_available_quota)

    with geoc.ApiClient(session.configuration) as api_client:
        geoc.UserApi(api_client).update_user_quota_handler(
            str(user_id),
            quota_update,
            _request_timeout=timeout,
        )
def data_usage(offset: int = 0, limit: int = 10, timeout: int = 60) -> pd.DataFrame:
    '''
    Get data usage records as a `pandas.DataFrame`.

    Parameters
    ----------
    offset : pagination offset into the usage records
    limit : maximum number of records to fetch
    timeout : request timeout in seconds

    Returns
    -------
    A data frame with one row per usage record; the `timestamp` column (if
    present) is parsed into timezone-aware (UTC) datetimes.
    '''
    # NOTE: the return annotation previously claimed `List[geoc.DataUsage]`,
    # but this function has always returned a data frame (cf. `data_usage_summary`).

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        user_api = geoc.UserApi(api_client)
        response = user_api.data_usage_handler(offset=offset, limit=limit, _request_timeout=timeout)

        # create dataframe from response
        usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
        df = pd.DataFrame(usage_dicts)
        if 'timestamp' in df.columns:
            df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)

    return df
def data_usage_summary(granularity: geoc.UsageSummaryGranularity,
                       dataset: Optional[str] = None,
                       offset: int = 0, limit: int = 10,
                       timeout: int = 60) -> pd.DataFrame:
    '''
    Get an aggregated data usage summary as a `pandas.DataFrame`.

    Parameters
    ----------
    granularity : temporal granularity of the usage aggregation
    dataset : optionally restrict the summary to a single dataset
    offset : pagination offset into the summary records
    limit : maximum number of records to fetch
    timeout : request timeout in seconds (new, defaults to the previous
        client-default behavior for existing callers)

    Returns
    -------
    A data frame with one row per summary record; the `timestamp` column (if
    present) is parsed into timezone-aware (UTC) datetimes.
    '''

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        user_api = geoc.UserApi(api_client)
        response = user_api.data_usage_summary_handler(dataset=dataset, granularity=granularity,
                                                       offset=offset, limit=limit,
                                                       _request_timeout=timeout)

        # create dataframe from response
        usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
        df = pd.DataFrame(usage_dicts)
        if 'timestamp' in df.columns:
            df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)

    return df
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc