• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 19573839061

21 Nov 2025 02:37PM UTC coverage: 78.668% (-1.1%) from 79.74%
19573839061

Pull #221

github

web-flow
Merge 4012cbccd into e06d48b64
Pull Request #221: Pixel_based_queries_rewrite

3083 of 3919 relevant lines covered (78.67%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.4
geoengine/workflow.py
1
"""
2
A workflow representation and methods on workflows
3
"""
4
# pylint: disable=too-many-lines
5
# TODO: split into multiple files
6

7
from __future__ import annotations
1✔
8

9
import asyncio
1✔
10
import json
1✔
11
from collections import defaultdict
1✔
12
from collections.abc import AsyncIterator
1✔
13
from io import BytesIO
1✔
14
from logging import debug
1✔
15
from os import PathLike
1✔
16
from typing import Any, TypedDict, cast
1✔
17
from uuid import UUID
1✔
18

19
import geoengine_openapi_client as geoc
1✔
20
import geopandas as gpd
1✔
21
import numpy as np
1✔
22
import pandas as pd
1✔
23
import pyarrow as pa
1✔
24
import rasterio.io
1✔
25
import requests as req
1✔
26
import rioxarray
1✔
27
import websockets
1✔
28
import websockets.asyncio.client
1✔
29
import xarray as xr
1✔
30
from owslib.util import Authentication, ResponseWrapper
1✔
31
from owslib.wcs import WebCoverageService
1✔
32
from PIL import Image
1✔
33
from vega import VegaLite
1✔
34

35
from geoengine import api, backports
1✔
36
from geoengine.auth import get_session
1✔
37
from geoengine.error import (
1✔
38
    GeoEngineException,
39
    InputException,
40
    MethodNotCalledOnPlotException,
41
    MethodNotCalledOnRasterException,
42
    MethodNotCalledOnVectorException,
43
    OGCXMLError,
44
)
45
from geoengine.raster import RasterTile2D
1✔
46
from geoengine.tasks import Task, TaskId
1✔
47
from geoengine.types import (
1✔
48
    ClassificationMeasurement,
49
    ProvenanceEntry,
50
    QueryRectangle,
51
    RasterColorizer,
52
    RasterQueryRectangle,
53
    RasterResultDescriptor,
54
    ResultDescriptor,
55
    SpatialPartition2D,
56
    SpatialResolution,
57
    VectorResultDescriptor,
58
)
59
from geoengine.workflow_builder.operators import Operator as WorkflowBuilderOperator
1✔
60

61
# TODO: Define as recursive type when supported in mypy: https://github.com/python/mypy/issues/731
62
JsonType = dict[str, Any] | list[Any] | int | str | float | bool | type[None]
1✔
63

64

65
class Axis(TypedDict):
    """Axis section of a Vega-Lite histogram encoding."""

    title: str
67

68

69
class Bin(TypedDict):
    """Binning section of a Vega-Lite histogram encoding."""

    binned: bool
    step: float
72

73

74
class Field(TypedDict):
    """A single field reference inside a Vega-Lite encoding."""

    field: str
76

77

78
class DatasetIds(TypedDict):
    """Pair of upload and dataset UUIDs returned when creating a dataset."""

    upload: UUID
    dataset: UUID
81

82

83
class Values(TypedDict):
    """One histogram bar: bin boundaries plus its frequency count."""

    binStart: float
    binEnd: float
    Frequency: int
87

88

89
class X(TypedDict):
    """Vega-Lite `x` channel: field, binning, and axis configuration."""

    field: Field
    bin: Bin
    axis: Axis
93

94

95
class X2(TypedDict):
    """Vega-Lite `x2` channel: the secondary field reference."""

    field: Field
97

98

99
class Y(TypedDict):
    """Vega-Lite `y` channel: field reference plus its data type."""

    field: Field
    type: str
102

103

104
class Encoding(TypedDict):
    """Vega-Lite encoding block combining the x, x2, and y channels."""

    x: X
    x2: X2
    y: Y
108

109

110
# Functional `TypedDict` syntax is required here because "$schema" is not a valid identifier.
VegaSpec = TypedDict(
    "VegaSpec",
    {"$schema": str, "data": list[Values], "mark": str, "encoding": Encoding},
)
111

112

113
class WorkflowId:
    """A typed wrapper around the UUID that identifies a workflow."""

    __workflow_id: UUID

    def __init__(self, workflow_id: UUID | str) -> None:
        """Wrap `workflow_id`, parsing it from a string when necessary."""

        self.__workflow_id = workflow_id if isinstance(workflow_id, UUID) else UUID(workflow_id)

    @classmethod
    def from_response(cls, response: geoc.IdResponse) -> WorkflowId:
        """Build a `WorkflowId` from an HTTP id response."""
        return cls(response.id)

    def __str__(self) -> str:
        return str(self.__workflow_id)

    def __repr__(self) -> str:
        return str(self)

    def to_dict(self) -> UUID:
        """Return the wrapped UUID (used when serializing API requests)."""
        return self.__workflow_id
143

144

145
class RasterStreamProcessing:
    """Helper methods for decoding and combining raster stream data."""

    @classmethod
    def read_arrow_ipc(cls, arrow_ipc: bytes) -> pa.RecordBatch:
        """Decode the record batch from an Arrow IPC file given as bytes."""

        # The backend always sends exactly one record batch per IPC file.
        return pa.ipc.open_file(arrow_ipc).get_record_batch(0)

    @classmethod
    def process_bytes(cls, tile_bytes: bytes | None) -> RasterTile2D | None:
        """Turn raw tile bytes into a `RasterTile2D`; `None` passes through."""

        if tile_bytes is None:
            return None

        batch = cls.read_arrow_ipc(tile_bytes)
        return RasterTile2D.from_ge_record_batch(batch)

    @classmethod
    def merge_tiles(cls, tiles: list[xr.DataArray]) -> xr.DataArray | None:
        """Merge a list of spatial tiles (possibly several bands) into one xarray."""

        if not tiles:
            return None

        # Bucket tiles by their band; `band` is assumed to be a single-valued coordinate.
        per_band: dict[int, list[xr.DataArray]] = defaultdict(list)
        for current in tiles:
            per_band[current.band.item()].append(current)

        # Spatially combine each band's tiles ...
        merged_bands = []
        for band_tiles in per_band.values():
            merged = xr.combine_by_coords(band_tiles)
            # `combine_by_coords` always returns a `DataArray` for single-variable
            # input arrays; the assertion narrows the type for mypy.
            assert isinstance(merged, xr.DataArray)
            merged_bands.append(merged)

        # ... then stack the per-band arrays into a single array with a band dimension.
        return xr.concat(merged_bands, dim="band")
198

199

200
class Workflow:
1✔
201
    """
202
    Holds a workflow id and allows querying data
203
    """
204

205
    __workflow_id: WorkflowId
1✔
206
    __result_descriptor: ResultDescriptor
1✔
207

208
    def __init__(self, workflow_id: WorkflowId) -> None:
        """Store the id and eagerly fetch the result descriptor from the backend."""
        self.__workflow_id = workflow_id
        self.__result_descriptor = self.__query_result_descriptor()
211

212
    def __str__(self) -> str:
        """Render the workflow as its id string."""
        return str(self.__workflow_id)
214

215
    def __repr__(self) -> str:
        """Mirror the id's repr for debugging output."""
        return repr(self.__workflow_id)
217

218
    def __query_result_descriptor(self, timeout: int = 60) -> ResultDescriptor:
        """
        Query the metadata of the workflow result from the backend.

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            response = geoc.WorkflowsApi(api_client).get_workflow_metadata_handler(
                self.__workflow_id.to_dict(), _request_timeout=timeout
            )

        debug(response)

        return ResultDescriptor.from_response(response)
234

235
    def get_result_descriptor(self) -> ResultDescriptor:
        """Return the cached metadata of the workflow result (fetched at construction)."""
        return self.__result_descriptor
241

242
    def workflow_definition(self, timeout: int = 60) -> geoc.Workflow:
        """
        Load and return this workflow's definition from the backend.

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            return geoc.WorkflowsApi(api_client).load_workflow_handler(
                self.__workflow_id.to_dict(), _request_timeout=timeout
            )
252

253
    def get_dataframe(
        self, bbox: QueryRectangle, timeout: int = 3600, resolve_classifications: bool = False
    ) -> gpd.GeoDataFrame:
        """
        Query a workflow via WFS and return the result as a GeoPandas `GeoDataFrame`.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        resolve_classifications : If True, replace classification codes by their class names
        """

        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            wfs_api = geoc.OGCWFSApi(api_client)
            response = wfs_api.wfs_feature_handler(
                workflow=self.__workflow_id.to_dict(),
                service=geoc.WfsService(geoc.WfsService.WFS),
                request=geoc.GetFeatureRequest(geoc.GetFeatureRequest.GETFEATURE),
                type_names=str(self.__workflow_id),
                bbox=bbox.bbox_str,
                version=geoc.WfsVersion(geoc.WfsVersion.ENUM_2_DOT_0_DOT_0),
                time=bbox.time_str,
                srs_name=bbox.srs,
                _request_timeout=timeout,
            )

        def parse_geo_json(geo_json) -> gpd.GeoDataFrame:
            # GeoJSON has no standard for time, so the backend's per-feature
            # `when` object is parsed separately and attached as `start`/`end` columns.
            frame = gpd.GeoDataFrame.from_features(geo_json)
            frame = frame.set_crs(bbox.srs, allow_override=True)

            starts = [feature["when"]["start"] for feature in geo_json["features"]]
            ends = [feature["when"]["end"] for feature in geo_json["features"]]

            # TODO: find a good way to infer BoT/EoT
            frame["start"] = gpd.pd.to_datetime(starts, errors="coerce")
            frame["end"] = gpd.pd.to_datetime(ends, errors="coerce")

            return frame

        def resolve_classes(frame: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
            # Map classification codes to class names for every classified column.
            descriptor: VectorResultDescriptor = self.__result_descriptor  # type: ignore
            for column, info in descriptor.columns.items():
                if isinstance(info.measurement, ClassificationMeasurement):
                    classes = info.measurement.classes
                    frame[column] = frame[column].apply(lambda value, classes=classes: classes[value])  # pylint: disable=cell-var-from-loop
            return frame

        result = parse_geo_json(response.to_dict())

        if resolve_classifications:
            result = resolve_classes(result)

        return result
315

316
    def wms_get_map_as_image(
        self,
        bbox: QueryRectangle,
        raster_colorizer: RasterColorizer,
        # TODO: allow to use width height
        spatial_resolution: SpatialResolution,
    ) -> Image.Image:
        """Fetch the workflow result via WMS and return it as a PIL image."""

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # derive the requested pixel dimensions from the bounds and resolution
        bounds = bbox.spatial_bounds
        width_px = int((bounds.xmax - bounds.xmin) / spatial_resolution.x_resolution)
        height_px = int((bounds.ymax - bounds.ymin) / spatial_resolution.y_resolution)

        with geoc.ApiClient(session.configuration) as api_client:
            response = geoc.OGCWMSApi(api_client).wms_map_handler(
                workflow=self.__workflow_id.to_dict(),
                version=geoc.WmsVersion(geoc.WmsVersion.ENUM_1_DOT_3_DOT_0),
                service=geoc.WmsService(geoc.WmsService.WMS),
                request=geoc.GetMapRequest(geoc.GetMapRequest.GETMAP),
                width=width_px,
                height=height_px,
                bbox=bbox.bbox_ogc_str,
                format=geoc.GetMapFormat(geoc.GetMapFormat.IMAGE_SLASH_PNG),
                layers=str(self),
                styles="custom:" + raster_colorizer.to_api_dict().to_json(),
                crs=bbox.srs,
                time=bbox.time_str,
            )

        # the server may report errors as an OGC XML document in the body
        if OGCXMLError.is_ogc_error(response):
            raise OGCXMLError(response)

        return Image.open(BytesIO(response))
351

352
    def plot_json(
        self, bbox: QueryRectangle, spatial_resolution: SpatialResolution | None = None, timeout: int = 3600
    ) -> geoc.WrappedPlotOutput:
        """
        Query a workflow and return the plot chart result as `WrappedPlotOutput`.

        Parameters
        ----------
        bbox : A bounding box for the query
        spatial_resolution : resolution to compute the plot at
        timeout : HTTP request timeout in seconds
        """

        if not self.__result_descriptor.is_plot_result():
            raise MethodNotCalledOnPlotException()

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            plots_api = geoc.PlotsApi(api_client)
            # NOTE(review): when `spatial_resolution` is None this sends the literal
            # string "None" as the resolution parameter — confirm the backend accepts it.
            return plots_api.get_plot_handler(
                bbox.bbox_str,
                bbox.time_str,
                str(spatial_resolution),
                self.__workflow_id.to_dict(),
                bbox.srs,
                _request_timeout=timeout,
            )
374

375
    def plot_chart(
        self, bbox: QueryRectangle, spatial_resolution: SpatialResolution | None = None, timeout: int = 3600
    ) -> VegaLite:
        """Query a workflow and return the plot chart result as a Vega-Lite chart."""

        plot_output = self.plot_json(bbox, spatial_resolution, timeout)
        spec: VegaSpec = json.loads(plot_output.data["vegaString"])
        return VegaLite(spec)
386

387
    def __request_wcs(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
        spatial_resolution: SpatialResolution | None = None,
    ) -> ResponseWrapper:
        """
        Query a workflow via WCS and return the coverage response.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # TODO: properly build CRS string for bbox
        crs = f"urn:ogc:def:crs:{bbox.srs.replace(':', '::')}"

        wcs = WebCoverageService(
            f"{session.server_url}/wcs/{self.__workflow_id}",
            version="1.1.1",
            auth=Authentication(auth_delegate=session.requests_bearer_auth()),
        )

        extra_params: dict[str, str] = {}

        # TODO: allow subset of bands from RasterQueryRectangle
        if force_no_data_value is not None:
            extra_params["nodatavalue"] = str(float(force_no_data_value))

        if spatial_resolution is not None:
            [resx, resy] = spatial_resolution.resolution_ogc(bbox.srs)
            if resx is not None:
                extra_params["resx"] = str(resx)
            if resy is not None:
                extra_params["resy"] = str(resy)

        return wcs.getCoverage(
            identifier=f"{self.__workflow_id}",
            bbox=bbox.bbox_ogc,
            time=[bbox.time_str],
            format=file_format,
            crs=crs,
            timeout=timeout,
            **extra_params,
        )
446

447
    def __get_wcs_tiff_as_memory_file(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        force_no_data_value: float | None = None,
        spatial_resolution: SpatialResolution | None = None,
    ) -> rasterio.io.MemoryFile:
        """
        Query a workflow and return the raster result as a memory-mapped GeoTiff.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        # errors are raised via `raise_on_error` inside `getCoverage` / `openUrl`
        tiff_bytes = self.__request_wcs(bbox, timeout, "image/tiff", force_no_data_value, spatial_resolution).read()

        return rasterio.io.MemoryFile(tiff_bytes)
472

473
    def get_array(
        self,
        bbox: QueryRectangle,
        spatial_resolution: SpatialResolution | None = None,
        timeout=3600,
        force_no_data_value: float | None = None,
    ) -> np.ndarray:
        """
        Query a workflow and return the raster result as a numpy array.

        Only the first band of the returned GeoTiff is read.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        with (
            self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value, spatial_resolution) as memfile,
            memfile.open() as dataset,
        ):
            return dataset.read(1)
498

499
    def get_xarray(
        self,
        bbox: QueryRectangle,
        spatial_resolution: SpatialResolution | None = None,
        timeout=3600,
        force_no_data_value: float | None = None,
    ) -> xr.DataArray:
        """
        Query a workflow and return the raster result as a georeferenced xarray.

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        with (
            self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value, spatial_resolution) as memfile,
            memfile.open() as dataset,
        ):
            array = rioxarray.open_rasterio(dataset)

            # `open_rasterio` returns a `DataArray` here; the assertion helps mypy
            assert isinstance(array, xr.DataArray)

            rio_accessor = array.rio
            # copy the georeferencing information into plain attributes
            rio_accessor.update_attrs(
                {
                    "crs": rio_accessor.crs,
                    "res": rio_accessor.resolution(),
                    "transform": rio_accessor.transform(),
                },
                inplace=True,
            )

            # TODO: add time information to dataset
            return array.load()
538

539
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def download_raster(
        self,
        bbox: QueryRectangle,
        file_path: str,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
        spatial_resolution: SpatialResolution | None = None,
    ) -> None:
        """
        Query a workflow and save the raster result to a file on disk.

        Parameters
        ----------
        bbox : A bounding box for the query
        file_path : The path to the file to save the raster to
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        coverage = self.__request_wcs(bbox, timeout, file_format, force_no_data_value, spatial_resolution)

        with open(file_path, "wb") as out_file:
            out_file.write(coverage.read())
566

567
    def get_provenance(self, timeout: int = 60) -> list[ProvenanceEntry]:
        """
        Query and return the provenance entries of the workflow.

        Parameters
        ----------
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            response = geoc.WorkflowsApi(api_client).get_workflow_provenance_handler(
                self.__workflow_id.to_dict(), _request_timeout=timeout
            )

        return [ProvenanceEntry.from_response(entry) for entry in response]
581

582
    def metadata_zip(self, path: PathLike | BytesIO, timeout: int = 60) -> None:
        """
        Query workflow metadata and citations and store them as a zip file at `path`.

        Parameters
        ----------
        path : destination — either a filesystem path or a `BytesIO` buffer
        timeout : HTTP request timeout in seconds
        """

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            zip_bytes = geoc.WorkflowsApi(api_client).get_workflow_all_metadata_zip_handler(
                self.__workflow_id.to_dict(), _request_timeout=timeout
            )

        if isinstance(path, BytesIO):
            path.write(zip_bytes)
            return

        with open(path, "wb") as file:
            file.write(zip_bytes)
600

601
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def save_as_dataset(
        self,
        query_rectangle: QueryRectangle,
        name: None | str,
        display_name: str,
        description: str = "",
        timeout: int = 3600,
    ) -> Task:
        """
        Start a task that stores the workflow result as a dataset.

        Currently this only works for raster results.

        Parameters
        ----------
        query_rectangle : spatial and temporal bounds of the data to store
        name : internal dataset name; None lets the backend choose one
        display_name : human-readable dataset name
        description : dataset description
        timeout : HTTP request timeout in seconds

        Returns
        -------
        The `Task` handle for the running store operation.
        """

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        if not isinstance(query_rectangle, QueryRectangle):
            # fixed typos in the warning message ("other then" / "tmporal")
            print("save_as_dataset ignores params other than spatial and temporal bounds.")

        # only the spatial and temporal bounds of the query rectangle are used
        qrect = geoc.models.raster_to_dataset_query_rectangle.RasterToDatasetQueryRectangle(
            spatial_bounds=SpatialPartition2D.from_bounding_box(query_rectangle.spatial_bounds).to_api_dict(),
            time_interval=query_rectangle.time.to_api_dict(),
        )

        with geoc.ApiClient(session.configuration) as api_client:
            workflows_api = geoc.WorkflowsApi(api_client)
            response = workflows_api.dataset_from_workflow_handler(
                self.__workflow_id.to_dict(),
                geoc.RasterDatasetFromWorkflow(
                    name=name, display_name=display_name, description=description, query=qrect
                ),
                _request_timeout=timeout,
            )

        return Task(TaskId.from_response(response))
637

638
    async def raster_stream(
        self,
        query_rectangle: QueryRectangle | RasterQueryRectangle,
        open_timeout: int = 60,
    ) -> AsyncIterator[RasterTile2D]:
        """Stream the workflow result as a series of `RasterTile2D`s (convertible to numpy and xarray)."""

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        result_descriptor = cast(RasterResultDescriptor, self.__result_descriptor)

        if not isinstance(query_rectangle, RasterQueryRectangle):
            query_rectangle = query_rectangle.with_raster_bands(
                # TODO: all bands or first band?
                list(range(0, len(result_descriptor.bands)))
            )

        session = get_session()

        # build the query URL via `requests` to get proper parameter encoding
        url = (
            req.Request(
                "GET",
                url=f"{session.server_url}/workflow/{self.__workflow_id}/rasterStream",
                params={
                    "resultType": "arrow",
                    "spatialBounds": query_rectangle.bbox_str,
                    "timeInterval": query_rectangle.time_str,
                    "attributes": ",".join(map(str, query_rectangle.raster_bands)),
                },
            )
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            additional_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,
        ) as websocket:
            tile_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    # eagerly request the next tile so network I/O overlaps with decoding
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # connection already closed — nothing more to read
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # graceful close — stop reading
                        return None

                # fetch the next tile's bytes while decoding the previous ones in a thread
                (tile_bytes, tile) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, tile_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(RasterStreamProcessing.process_bytes, tile_bytes),
                )

                if tile is not None:
                    yield tile

            # process the last tile
            tile = RasterStreamProcessing.process_bytes(tile_bytes)

            if tile is not None:
                yield tile
721

722
    async def raster_stream_into_xarray(
        self,
        query_rectangle: RasterQueryRectangle,
        clip_to_query_rectangle: bool = False,
        open_timeout: int = 60,
    ) -> xr.DataArray:
        """
        Stream the workflow result into memory and output a single xarray.

        NOTE: You can run out of memory if the query rectangle is too large.
        """

        tile_stream = self.raster_stream(query_rectangle, open_timeout=open_timeout)

        timestep_xarrays: list[xr.DataArray] = []

        spatial_clip_bounds = query_rectangle.spatial_bounds if clip_to_query_rectangle else None

        async def read_tiles(
            remainder_tile: RasterTile2D | None,
        ) -> tuple[list[xr.DataArray], RasterTile2D | None]:
            # Collect all tiles belonging to one time step; when a tile with a new
            # timestamp arrives it is handed back as the remainder for the next call.
            last_timestep: np.datetime64 | None = None
            collected: list[xr.DataArray] = []

            if remainder_tile is not None:
                last_timestep = remainder_tile.time_start_ms
                collected.append(remainder_tile.to_xarray(clip_with_bounds=spatial_clip_bounds))

            async for tile in tile_stream:
                timestep: np.datetime64 = tile.time_start_ms
                if last_timestep is None:
                    last_timestep = timestep
                elif last_timestep != timestep:
                    # a new time step started — return what we have plus the new tile
                    return collected, tile

                collected.append(tile.to_xarray(clip_with_bounds=spatial_clip_bounds))

            # stream exhausted: this was the last time step
            return collected, None

        (tiles, remainder_tile) = await read_tiles(None)

        while len(tiles):
            # read the next time step while merging the current one in a worker thread
            ((new_tiles, new_remainder_tile), new_timestep_xarray) = await asyncio.gather(
                read_tiles(remainder_tile),
                backports.to_thread(RasterStreamProcessing.merge_tiles, tiles),
                # asyncio.to_thread(merge_tiles, tiles), # TODO: use this when min Python version is 3.9
            )

            tiles = new_tiles
            remainder_tile = new_remainder_tile

            if new_timestep_xarray is not None:
                timestep_xarrays.append(new_timestep_xarray)

        output: xr.DataArray = cast(
            xr.DataArray,
            # await asyncio.to_thread( # TODO: use this when min Python version is 3.9
            await backports.to_thread(
                xr.concat,
                # TODO: This is a typings error, since the method accepts also a `xr.DataArray` and returns one
                cast(list[xr.Dataset], timestep_xarrays),
                dim="time",
            ),
        )

        return output
791

792
    async def vector_stream(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> AsyncIterator[gpd.GeoDataFrame]:
        """
        Stream the workflow result as a series of `GeoDataFrame`s.

        Opens a websocket connection to the workflow's `/vectorStream` endpoint
        and yields one `GeoDataFrame` per Arrow record batch received. Reading
        the next batch and converting the previous one run concurrently.

        Args:
            query_rectangle: Spatial and temporal bounds of the query.
            time_start_column: Name of the output column holding interval starts.
            time_end_column: Name of the output column holding interval ends.
            open_timeout: Timeout in seconds for opening the websocket connection.

        Raises:
            MethodNotCalledOnVectorException: If the workflow result is not a vector result.
            InputException: If the websocket url could not be constructed.
            GeoEngineException: If the server sends an error message over the websocket.
        """

        def read_arrow_ipc(arrow_ipc: bytes) -> pa.RecordBatch:
            """Deserialize a single record batch from an Arrow IPC file buffer."""
            reader = pa.ipc.open_file(arrow_ipc)
            # We know from the backend that there is only one record batch
            record_batch = reader.get_record_batch(0)
            return record_batch

        def create_geo_data_frame(
            record_batch: pa.RecordBatch, time_start_column: str, time_end_column: str
        ) -> gpd.GeoDataFrame:
            """Convert an Arrow record batch into a `GeoDataFrame` with parsed time columns."""
            # the backend attaches the CRS as schema metadata
            metadata = record_batch.schema.metadata
            spatial_reference = metadata[b"spatialReference"].decode("utf-8")

            data_frame = record_batch.to_pandas()

            # geometries arrive as WKT strings in a dedicated column
            geometry = gpd.GeoSeries.from_wkt(data_frame[api.GEOMETRY_COLUMN_NAME])
            # delete the duplicated column
            del data_frame[api.GEOMETRY_COLUMN_NAME]

            geo_data_frame = gpd.GeoDataFrame(
                data_frame,
                geometry=geometry,
                crs=spatial_reference,
            )

            # split the (start, end) time column into two separate columns
            geo_data_frame[[time_start_column, time_end_column]] = geo_data_frame[api.TIME_COLUMN_NAME].tolist()
            # delete the duplicated column
            del geo_data_frame[api.TIME_COLUMN_NAME]

            # parse time columns from epoch milliseconds to timezone-aware datetimes
            for time_column in [time_start_column, time_end_column]:
                geo_data_frame[time_column] = pd.to_datetime(
                    geo_data_frame[time_column],
                    utc=True,
                    unit="ms",
                    # TODO: solve time conversion problem from Geo Engine to Python for large (+/-) time instances
                    errors="coerce",
                )

            return geo_data_frame

        def process_bytes(batch_bytes: bytes | None) -> gpd.GeoDataFrame | None:
            """Turn raw Arrow IPC bytes into a `GeoDataFrame`; pass `None` through."""
            if batch_bytes is None:
                return None

            # process the received data
            record_batch = read_arrow_ipc(batch_bytes)
            tile = create_geo_data_frame(
                record_batch,
                time_start_column=time_start_column,
                time_end_column=time_end_column,
            )

            return tile

        # this streaming endpoint only works for vector results
        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        params = {
            "resultType": "arrow",
            "spatialBounds": query_rectangle.bbox_str,
            "timeInterval": query_rectangle.time_str,
        }

        # build the query url via a prepared request to get correct parameter encoding
        url = (
            req.Request("GET", url=f"{session.server_url}/workflow/{self.__workflow_id}/vectorStream", params=params)
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            additional_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,  # allow arbitrary large messages, since it is capped by the server's chunk size
        ) as websocket:
            # bytes of the previously received batch; processed while the next one is fetched
            batch_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    """Request and await the next batch; `None` signals a closed connection."""
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # fetch the next batch while converting the previous one in a worker thread
                (batch_bytes, batch) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, batch_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(process_bytes, batch_bytes),
                )

                if batch is not None:
                    yield batch

            # process the last batch that was read before the connection closed
            batch = process_bytes(batch_bytes)

            if batch is not None:
                yield batch
921

922
    async def vector_stream_into_geopandas(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> gpd.GeoDataFrame | None:
        """
        Stream the workflow result into memory and output a single geo data frame.

        Consumes `vector_stream` chunk by chunk, merging the previous chunk into
        the accumulated frame while the next chunk is being fetched. Returns
        `None` if the stream produced no chunks at all.

        NOTE: You can run out of memory if the query rectangle is too large.
        """

        chunk_stream = self.vector_stream(
            query_rectangle,
            time_start_column=time_start_column,
            time_end_column=time_end_column,
            open_timeout=open_timeout,
        )

        # accumulated result and the most recently fetched chunk
        data_frame: gpd.GeoDataFrame | None = None
        chunk: gpd.GeoDataFrame | None = None

        async def read_dataframe() -> gpd.GeoDataFrame | None:
            """Fetch the next chunk; `None` signals an exhausted stream."""
            try:
                return await chunk_stream.__anext__()
            except StopAsyncIteration:
                return None

        def merge_dataframes(df_a: gpd.GeoDataFrame | None, df_b: gpd.GeoDataFrame | None) -> gpd.GeoDataFrame | None:
            """Concatenate two frames, treating `None` as the neutral element."""
            if df_a is None:
                return df_b

            if df_b is None:
                return df_a

            return pd.concat([df_a, df_b], ignore_index=True)

        while True:
            # pipeline: fetch the next chunk while merging the previous one in a worker thread
            (chunk, data_frame) = await asyncio.gather(
                read_dataframe(),
                backports.to_thread(merge_dataframes, data_frame, chunk),
                # TODO: use this when min Python version is 3.9
                # asyncio.to_thread(merge_dataframes, data_frame, chunk),
            )

            # we can stop when the chunk stream is exhausted
            if chunk is None:
                break

        return data_frame
973

974
    def __replace_http_with_ws(self, url: str) -> str:
        """
        Rewrite the scheme of `url` from HTTP(S) to the matching WebSocket scheme.

        The websockets library requires urls starting with `ws://`;
        for HTTPS the secure variant `wss://` is used instead.
        """

        scheme, remainder = url.split("://", maxsplit=1)

        # an "s" in the scheme (https) indicates a secure connection
        if "s" in scheme.lower():
            return f"wss://{remainder}"

        return f"ws://{remainder}"
987

988

989
def register_workflow(workflow: dict[str, Any] | WorkflowBuilderOperator, timeout: int = 60) -> Workflow:
    """
    Register a workflow in Geo Engine and receive a `WorkflowId`

    Accepts either a raw workflow dictionary or a builder operator,
    validates it against the API model, and submits it to the backend.

    Raises:
        InputException: If the workflow definition cannot be parsed into a valid model.
    """

    # builder operators are first lowered to their dictionary representation
    workflow_dict = workflow.to_workflow_dict() if isinstance(workflow, WorkflowBuilderOperator) else workflow

    workflow_model = geoc.Workflow.from_dict(workflow_dict)
    if workflow_model is None:
        raise InputException("Invalid workflow definition")

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        response = geoc.WorkflowsApi(api_client).register_workflow_handler(workflow_model, _request_timeout=timeout)

    return Workflow(WorkflowId.from_response(response))
1✔
1009

1010

1011
def workflow_by_id(workflow_id: UUID | str) -> Workflow:
    """
    Create a workflow object from a workflow id
    """

    # TODO: check that workflow exists

    wrapped_id = WorkflowId(workflow_id)

    return Workflow(wrapped_id)
1✔
1019

1020

1021
def get_quota(user_id: UUID | None = None, timeout: int = 60) -> geoc.Quota:
    """
    Gets a user's quota. Only admins can get other users' quota.

    Args:
        user_id: Query this user's quota; `None` queries the current session's quota.
        timeout: Request timeout in seconds.
    """

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        user_api = geoc.UserApi(api_client)

        if user_id is not None:
            # querying another user's quota (admin-only endpoint)
            return user_api.get_user_quota_handler(user_id, _request_timeout=timeout)

        return user_api.quota_handler(_request_timeout=timeout)
1✔
1035

1036

1037
def update_quota(user_id: UUID, new_available_quota: int, timeout: int = 60) -> None:
    """
    Update a user's quota. Only admins can perform this operation.

    Args:
        user_id: The user whose quota is updated.
        new_available_quota: The new available quota value.
        timeout: Request timeout in seconds.
    """

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        quota_update = geoc.UpdateQuota(available=new_available_quota)
        geoc.UserApi(api_client).update_user_quota_handler(user_id, quota_update, _request_timeout=timeout)
1049

1050

1051
def data_usage(offset: int = 0, limit: int = 10) -> pd.DataFrame:
    """
    Get data usage as a data frame.

    Fix: the return annotation previously claimed `list[geoc.DataUsage]`,
    but the function builds and returns a `pandas.DataFrame` (matching its
    sibling `data_usage_summary`); the annotation now reflects reality.

    Args:
        offset: Number of entries to skip (pagination).
        limit: Maximum number of entries to return.

    Returns:
        A `pandas.DataFrame` with one row per usage entry; a `timestamp`
        column, if present, is parsed to timezone-aware UTC datetimes.
    """

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        user_api = geoc.UserApi(api_client)
        response = user_api.data_usage_handler(offset=offset, limit=limit)

        # create dataframe from response
        usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
        df = pd.DataFrame(usage_dicts)
        if "timestamp" in df.columns:
            df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)

    return df
1✔
1069

1070

1071
def data_usage_summary(
    granularity: geoc.UsageSummaryGranularity, dataset: str | None = None, offset: int = 0, limit: int = 10
) -> pd.DataFrame:
    """
    Get data usage summary

    Args:
        granularity: Time granularity of the summary buckets.
        dataset: Optionally restrict the summary to a single dataset.
        offset: Number of entries to skip (pagination).
        limit: Maximum number of entries to return.
    """

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        user_api = geoc.UserApi(api_client)
        response = user_api.data_usage_summary_handler(
            dataset=dataset, granularity=granularity, offset=offset, limit=limit
        )

        # build the result frame from the raw API models
        records = [entry.model_dump(by_alias=True) for entry in response]
        df = pd.DataFrame(records)
        if "timestamp" in df.columns:
            df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)

    return df
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc