• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 17064361033

19 Aug 2025 08:43AM UTC coverage: 76.088% (-0.9%) from 76.961%
17064361033

Pull #221

github

web-flow
Merge 78613fd6d into 798243b77
Pull Request #221: Pixel_based_queries_rewrite

2921 of 3839 relevant lines covered (76.09%)

0.76 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.59
geoengine/workflow.py
1
"""
2
A workflow representation and methods on workflows
3
"""
4
# pylint: disable=too-many-lines
5
# TODO: split into multiple files
6

7
from __future__ import annotations
1✔
8

9
import asyncio
1✔
10
import json
1✔
11
from collections import defaultdict
1✔
12
from collections.abc import AsyncIterator
1✔
13
from io import BytesIO
1✔
14
from logging import debug
1✔
15
from os import PathLike
1✔
16
from typing import Any, TypedDict, cast
1✔
17
from uuid import UUID
1✔
18

19
import geoengine_openapi_client as geoc
1✔
20
import geopandas as gpd
1✔
21
import numpy as np
1✔
22
import pandas as pd
1✔
23
import pyarrow as pa
1✔
24
import rasterio.io
1✔
25
import requests as req
1✔
26
import rioxarray
1✔
27
import websockets
1✔
28
import websockets.asyncio.client
1✔
29
import xarray as xr
1✔
30
from owslib.util import Authentication, ResponseWrapper
1✔
31
from owslib.wcs import WebCoverageService
1✔
32
from PIL import Image
1✔
33
from vega import VegaLite
1✔
34

35
from geoengine import api, backports
1✔
36
from geoengine.auth import get_session
1✔
37
from geoengine.error import (
1✔
38
    GeoEngineException,
39
    InputException,
40
    MethodNotCalledOnPlotException,
41
    MethodNotCalledOnRasterException,
42
    MethodNotCalledOnVectorException,
43
    OGCXMLError,
44
)
45
from geoengine.raster import RasterTile2D
1✔
46
from geoengine.tasks import Task, TaskId
1✔
47
from geoengine.types import (
1✔
48
    ClassificationMeasurement,
49
    ProvenanceEntry,
50
    QueryRectangle,
51
    RasterColorizer,
52
    RasterQueryRectangle,
53
    RasterResultDescriptor,
54
    ResultDescriptor,
55
    SpatialPartition2D,
56
    SpatialResolution,
57
    VectorResultDescriptor,
58
)
59
from geoengine.workflow_builder.operators import Operator as WorkflowBuilderOperator
1✔
60

61
# TODO: Define as recursive type when supported in mypy: https://github.com/python/mypy/issues/731
62
JsonType = dict[str, Any] | list[Any] | int | str | float | bool | type[None]
1✔
63

64

65
class Axis(TypedDict):
    """Axis part of a Vega-Lite x-encoding (see `X`): the displayed axis title."""

    title: str
67

68

69
class Bin(TypedDict):
    """Bin configuration of a Vega-Lite x-encoding (see `X`)."""

    # whether the field values are (already) binned
    binned: bool
    # bin step width
    step: float
72

73

74
class Field(TypedDict):
    """Reference to a data field, used inside the Vega-Lite encodings `X`, `X2` and `Y`."""

    field: str
76

77

78
class DatasetIds(TypedDict):
    """A pair of upload and dataset UUIDs that belong together."""

    upload: UUID
    dataset: UUID
81

82

83
class Values(TypedDict):
    """One histogram bin of a plot result: bin boundaries plus its frequency.

    Key names are camel-/title-case because they mirror the server's JSON.
    """

    binStart: float
    binEnd: float
    Frequency: int
87

88

89
class X(TypedDict):
    """Vega-Lite x-channel encoding: data field, binning, and axis settings."""

    field: Field
    bin: Bin
    axis: Axis
93

94

95
class X2(TypedDict):
    """Vega-Lite x2-channel encoding (secondary x position), holding only a field reference."""

    field: Field
97

98

99
class Y(TypedDict):
    """Vega-Lite y-channel encoding: data field and its measurement type."""

    field: Field
    type: str
102

103

104
class Encoding(TypedDict):
    """Vega-Lite encoding block combining the x, x2 and y channels."""

    x: X
    x2: X2
    y: Y
108

109

110
# Vega-Lite specification as consumed by `Workflow.plot_chart`.
# Declared via the functional `TypedDict` syntax because "$schema" is not a valid Python identifier.
VegaSpec = TypedDict("VegaSpec", {"$schema": str, "data": list[Values], "mark": str, "encoding": Encoding})
111

112

113
class WorkflowId:
    """
    A wrapper around a workflow UUID
    """

    # the wrapped id; always stored as a `UUID` instance
    __workflow_id: UUID

    def __init__(self, workflow_id: UUID | str) -> None:
        """Create a new WorkflowId from an UUID or uuid as str"""

        if not isinstance(workflow_id, UUID):
            workflow_id = UUID(workflow_id)

        self.__workflow_id = workflow_id

    @classmethod
    def from_response(cls, response: geoc.IdResponse) -> WorkflowId:
        """
        Create a `WorkflowId` from an http response
        """
        return WorkflowId(UUID(response.id))

    def __str__(self) -> str:
        return str(self.__workflow_id)

    def __repr__(self) -> str:
        return str(self)

    def __eq__(self, other: object) -> bool:
        """Two `WorkflowId`s are equal iff they wrap the same UUID."""
        return isinstance(other, WorkflowId) and self.__workflow_id == other.__workflow_id

    def __hash__(self) -> int:
        """Hash by the wrapped UUID so ids can be used as dict keys and set members."""
        return hash(self.__workflow_id)
140

141

142
class RasterStreamProcessing:
    """
    Helper class to process raster stream data
    """

    @classmethod
    def read_arrow_ipc(cls, arrow_ipc: bytes) -> pa.RecordBatch:
        """Read an Arrow IPC file from a byte array"""

        ipc_reader = pa.ipc.open_file(arrow_ipc)
        # The backend always sends exactly one record batch per file
        return ipc_reader.get_record_batch(0)

    @classmethod
    def process_bytes(cls, tile_bytes: bytes | None) -> RasterTile2D | None:
        """Process a tile from a byte array"""

        if tile_bytes is None:
            return None

        # decode the Arrow IPC payload and turn it into a tile
        batch = RasterStreamProcessing.read_arrow_ipc(tile_bytes)
        return RasterTile2D.from_ge_record_batch(batch)

    @classmethod
    def merge_tiles(cls, tiles: list[xr.DataArray]) -> xr.DataArray | None:
        """Merge a list of tiles into a single xarray"""

        if not tiles:
            return None

        # bucket the tiles by their band value
        # ('band' is assumed to be a coordinate holding a single value)
        per_band: dict[int, list[xr.DataArray]] = defaultdict(list)
        for tile in tiles:
            per_band[tile.band.item()].append(tile)

        # spatially combine the tiles of each band
        merged_bands: list[xr.DataArray] = []
        for band_tiles in per_band.values():
            spatial_merge = xr.combine_by_coords(band_tiles)
            # `combine_by_coords` always returns a `DataArray` for single variable
            # input arrays; the assertion narrows the type for mypy
            assert isinstance(spatial_merge, xr.DataArray)
            merged_bands.append(spatial_merge)

        # stack the per-band arrays into one array with a band dimension
        return xr.concat(merged_bands, dim="band")
195

196

197
class Workflow:
1✔
198
    """
199
    Holds a workflow id and allows querying data
200
    """
201

202
    # identifier of the workflow on the Geo Engine instance
    __workflow_id: WorkflowId
    # result metadata, fetched once in `__init__` and cached for this object's lifetime
    __result_descriptor: ResultDescriptor
204

205
    def __init__(self, workflow_id: WorkflowId) -> None:
        """Wrap the given workflow id and eagerly fetch the result metadata."""
        self.__workflow_id = workflow_id
        # query the result descriptor once up front; later calls read the cached value
        self.__result_descriptor = self.__query_result_descriptor()
208

209
    def __str__(self) -> str:
1✔
210
        return str(self.__workflow_id)
1✔
211

212
    def __repr__(self) -> str:
1✔
213
        return repr(self.__workflow_id)
1✔
214

215
    def __query_result_descriptor(self, timeout: int = 60) -> ResultDescriptor:
        """
        Query the metadata of the workflow result
        """

        session = get_session()

        # fetch the raw metadata from the workflows endpoint
        with geoc.ApiClient(session.configuration) as api_client:
            raw_metadata = geoc.WorkflowsApi(api_client).get_workflow_metadata_handler(
                str(self.__workflow_id), _request_timeout=timeout
            )

        debug(raw_metadata)

        return ResultDescriptor.from_response(raw_metadata)
229

230
    def get_result_descriptor(self) -> ResultDescriptor:
        """
        Return the metadata of the workflow result
        """

        # served from the value cached in `__init__`; no server round-trip happens here
        return self.__result_descriptor
236

237
    def workflow_definition(self, timeout: int = 60) -> geoc.Workflow:
        """Return the workflow definition for this workflow"""

        session = get_session()

        # load the operator graph that was registered under this id
        with geoc.ApiClient(session.configuration) as api_client:
            definition = geoc.WorkflowsApi(api_client).load_workflow_handler(
                str(self.__workflow_id), _request_timeout=timeout
            )

        return definition
247

248
    def get_dataframe(
        self, bbox: QueryRectangle, timeout: int = 3600, resolve_classifications: bool = False
    ) -> gpd.GeoDataFrame:
        """
        Query a workflow and return the WFS result as a GeoPandas `GeoDataFrame`

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        resolve_classifications : If True, replace values of classification columns by their class names

        Raises
        ------
        MethodNotCalledOnVectorException
            If the workflow result is not a vector result
        """

        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            wfs_api = geoc.OGCWFSApi(api_client)
            response = wfs_api.wfs_feature_handler(
                workflow=str(self.__workflow_id),
                service=geoc.WfsService(geoc.WfsService.WFS),
                request=geoc.GetFeatureRequest(geoc.GetFeatureRequest.GETFEATURE),
                type_names=str(self.__workflow_id),
                bbox=bbox.bbox_str,
                version=geoc.WfsVersion(geoc.WfsVersion.ENUM_2_DOT_0_DOT_0),
                time=bbox.time_str,
                srs_name=bbox.srs,
                _request_timeout=timeout,
            )

        def geo_json_with_time_to_geopandas(geo_json):
            """
            GeoJson has no standard for time, so we parse the when field
            separately and attach it to the data frame as columns `start`
            and `end`.
            """

            data = gpd.GeoDataFrame.from_features(geo_json)
            data = data.set_crs(bbox.srs, allow_override=True)

            start = [f["when"]["start"] for f in geo_json["features"]]
            end = [f["when"]["end"] for f in geo_json["features"]]

            # TODO: find a good way to infer BoT/EoT

            # unparsable time instants become NaT instead of raising
            data["start"] = gpd.pd.to_datetime(start, errors="coerce")
            data["end"] = gpd.pd.to_datetime(end, errors="coerce")

            return data

        def transform_classifications(data: gpd.GeoDataFrame):
            # Replace numeric class values by their class names for every column
            # whose measurement is a `ClassificationMeasurement`.
            result_descriptor: VectorResultDescriptor = self.__result_descriptor  # type: ignore
            for column, info in result_descriptor.columns.items():
                if isinstance(info.measurement, ClassificationMeasurement):
                    measurement: ClassificationMeasurement = info.measurement
                    classes = measurement.classes
                    # NOTE(review): values missing from `classes` will raise a KeyError — confirm intended
                    data[column] = data[column].apply(lambda x, classes=classes: classes[x])  # pylint: disable=cell-var-from-loop

            return data

        result = geo_json_with_time_to_geopandas(response.to_dict())

        if resolve_classifications:
            result = transform_classifications(result)

        return result
310

311
    def wms_get_map_as_image(
        self,
        bbox: QueryRectangle,
        raster_colorizer: RasterColorizer,
        # TODO: allow to use width height
        spatial_resolution: SpatialResolution,
    ) -> Image.Image:
        """Return the result of a WMS request as a PIL Image"""

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # derive the pixel dimensions of the requested image from bounds and resolution
        bounds = bbox.spatial_bounds
        width_px = int((bounds.xmax - bounds.xmin) / spatial_resolution.x_resolution)
        height_px = int((bounds.ymax - bounds.ymin) / spatial_resolution.y_resolution)

        with geoc.ApiClient(session.configuration) as api_client:
            wms_api = geoc.OGCWMSApi(api_client)
            response = wms_api.wms_map_handler(
                workflow=str(self),
                version=geoc.WmsVersion(geoc.WmsVersion.ENUM_1_DOT_3_DOT_0),
                service=geoc.WmsService(geoc.WmsService.WMS),
                request=geoc.GetMapRequest(geoc.GetMapRequest.GETMAP),
                width=width_px,
                height=height_px,
                bbox=bbox.bbox_ogc_str,
                format=geoc.GetMapFormat(geoc.GetMapFormat.IMAGE_SLASH_PNG),
                layers=str(self),
                styles="custom:" + raster_colorizer.to_api_dict().to_json(),
                crs=bbox.srs,
                time=bbox.time_str,
            )

        # the server signals failures as an OGC XML document in the response body
        if OGCXMLError.is_ogc_error(response):
            raise OGCXMLError(response)

        return Image.open(BytesIO(response))
346

347
    def plot_json(
        self, bbox: QueryRectangle, spatial_resolution: SpatialResolution | None = None, timeout: int = 3600
    ) -> geoc.WrappedPlotOutput:
        """
        Query a workflow and return the plot chart result as WrappedPlotOutput

        Parameters
        ----------
        bbox : A bounding box for the query
        spatial_resolution : Resolution forwarded to the plot endpoint
        timeout : HTTP request timeout in seconds

        Raises
        ------
        MethodNotCalledOnPlotException
            If the workflow result is not a plot result
        """

        if not self.__result_descriptor.is_plot_result():
            raise MethodNotCalledOnPlotException()

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            plots_api = geoc.PlotsApi(api_client)
            return plots_api.get_plot_handler(
                bbox.bbox_str,
                bbox.time_str,
                # TODO: why does it need a resolution?
                # NOTE(review): when `spatial_resolution` is None, the literal string "None"
                # is sent to the server — confirm the endpoint tolerates this
                str(spatial_resolution),
                str(self.__workflow_id),
                bbox.srs,
                _request_timeout=timeout,
            )
370

371
    def plot_chart(
        self, bbox: QueryRectangle, spatial_resolution: SpatialResolution | None = None, timeout: int = 3600
    ) -> VegaLite:
        """
        Query a workflow and return the plot chart result as a vega plot
        """

        # fetch the wrapped plot output and decode the embedded Vega-Lite spec
        plot_output = self.plot_json(bbox, spatial_resolution, timeout)
        spec: VegaSpec = json.loads(plot_output.data["vegaString"])

        return VegaLite(spec)
382

383
    def __request_wcs(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
        spatial_resolution: SpatialResolution | None = None,
    ) -> ResponseWrapper:
        """
        Query a workflow and return the coverage

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        spatial_resolution : If not None, request the coverage in this resolution
        """

        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        # TODO: properly build CRS string for bbox
        # e.g. "EPSG:4326" becomes "urn:ogc:def:crs:EPSG::4326"
        crs = f"urn:ogc:def:crs:{bbox.srs.replace(':', '::')}"

        wcs_url = f"{session.server_url}/wcs/{self.__workflow_id}"
        wcs = WebCoverageService(
            wcs_url,
            version="1.1.1",
            auth=Authentication(auth_delegate=session.requests_bearer_auth()),
        )

        # optional resolution parameters, converted for the OGC request
        resx = None
        resy = None
        if spatial_resolution is not None:
            [resx, resy] = spatial_resolution.resolution_ogc(bbox.srs)

        # optional request parameters are collected here and splatted below
        kwargs = {}

        # TODO: allow subset of bands from RasterQueryRectangle
        if force_no_data_value is not None:
            kwargs["nodatavalue"] = str(float(force_no_data_value))
        if resx is not None:
            kwargs["resx"] = str(resx)
        if resy is not None:
            kwargs["resy"] = str(resy)

        return wcs.getCoverage(
            identifier=f"{self.__workflow_id}",
            bbox=bbox.bbox_ogc,
            time=[bbox.time_str],
            format=file_format,
            crs=crs,
            timeout=timeout,
            **kwargs,
        )
442

443
    def __get_wcs_tiff_as_memory_file(
        self,
        bbox: QueryRectangle,
        timeout=3600,
        force_no_data_value: float | None = None,
        spatial_resolution: SpatialResolution | None = None,
    ) -> rasterio.io.MemoryFile:
        """
        Query a workflow and return the raster result as a memory mapped GeoTiff

        Parameters
        ----------
        bbox : A bounding box for the query
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        # errors are raised inside `getCoverage` / `openUrl` via `raise_on_error`,
        # so the bytes read here are known to be the coverage payload
        tiff_bytes = self.__request_wcs(bbox, timeout, "image/tiff", force_no_data_value, spatial_resolution).read()

        return rasterio.io.MemoryFile(tiff_bytes)
468

469
    def get_array(
        self,
        bbox: QueryRectangle,
        spatial_resolution: SpatialResolution | None = None,
        timeout=3600,
        force_no_data_value: float | None = None,
    ) -> np.ndarray:
        """
        Query a workflow and return the raster result as a numpy array

        Parameters
        ----------
        bbox : A bounding box for the query
        spatial_resolution : If not None, request the raster in this resolution
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        memory_file = self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value, spatial_resolution)

        with memory_file as memfile, memfile.open() as dataset:
            # only the first band is read
            return dataset.read(1)
494

495
    def get_xarray(
        self,
        bbox: QueryRectangle,
        spatial_resolution: SpatialResolution | None = None,
        timeout=3600,
        force_no_data_value: float | None = None,
    ) -> xr.DataArray:
        """
        Query a workflow and return the raster result as a georeferenced xarray

        Parameters
        ----------
        bbox : A bounding box for the query
        spatial_resolution : If not None, request the raster in this resolution
        timeout : HTTP request timeout in seconds
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        with (
            self.__get_wcs_tiff_as_memory_file(bbox, timeout, force_no_data_value, spatial_resolution) as memfile,
            memfile.open() as dataset,
        ):
            data_array = rioxarray.open_rasterio(dataset)

            # helping mypy with inference
            assert isinstance(data_array, xr.DataArray)

            # copy the georeferencing info into the array's attributes
            rio: xr.DataArray = data_array.rio
            rio.update_attrs(
                {
                    "crs": rio.crs,
                    "res": rio.resolution(),
                    "transform": rio.transform(),
                },
                inplace=True,
            )

            # TODO: add time information to dataset
            # load eagerly: the backing memory file is closed when the `with` block exits
            return data_array.load()
534

535
    # pylint: disable=too-many-arguments,too-many-positional-arguments
536
    def download_raster(
        self,
        bbox: QueryRectangle,
        file_path: str,
        timeout=3600,
        file_format: str = "image/tiff",
        force_no_data_value: float | None = None,
        spatial_resolution: SpatialResolution | None = None,
    ) -> None:
        """
        Query a workflow and save the raster result as a file on disk

        Parameters
        ----------
        bbox : A bounding box for the query
        file_path : The path to the file to save the raster to
        timeout : HTTP request timeout in seconds
        file_format : The format of the returned raster
        force_no_data_value: If not None, use this value as no data value for the requested raster data. \
            Otherwise, the Geo Engine will produce masked rasters.
        """

        coverage = self.__request_wcs(bbox, timeout, file_format, force_no_data_value, spatial_resolution)

        # stream the coverage bytes straight to disk
        with open(file_path, "wb") as out_file:
            out_file.write(coverage.read())
562

563
    def get_provenance(self, timeout: int = 60) -> list[ProvenanceEntry]:
        """
        Query the provenance of the workflow
        """

        session = get_session()

        # fetch the raw provenance entries from the workflows endpoint
        with geoc.ApiClient(session.configuration) as api_client:
            entries = geoc.WorkflowsApi(api_client).get_workflow_provenance_handler(
                str(self.__workflow_id), _request_timeout=timeout
            )

        return [ProvenanceEntry.from_response(entry) for entry in entries]
575

576
    def metadata_zip(self, path: PathLike | BytesIO, timeout: int = 60) -> None:
        """
        Query workflow metadata and citations and stores it as zip file to `path`
        """

        session = get_session()

        with geoc.ApiClient(session.configuration) as api_client:
            workflows_api = geoc.WorkflowsApi(api_client)
            zip_bytes = workflows_api.get_workflow_all_metadata_zip_handler(
                str(self.__workflow_id), _request_timeout=timeout
            )

        # in-memory target: write the bytes directly
        if isinstance(path, BytesIO):
            path.write(zip_bytes)
            return

        # otherwise treat `path` as a filesystem location
        with open(path, "wb") as file:
            file.write(zip_bytes)
594

595
    # pylint: disable=too-many-arguments,too-many-positional-arguments
596
    def save_as_dataset(
        self,
        query_rectangle: QueryRectangle,
        name: None | str,
        display_name: str,
        description: str = "",
        timeout: int = 3600,
    ) -> Task:
        """Init task to store the workflow result as a layer

        Parameters
        ----------
        query_rectangle : Spatial and temporal bounds of the stored dataset
        name : Optional internal name for the new dataset
        display_name : Human-readable name of the new dataset
        description : Description of the new dataset
        timeout : HTTP request timeout in seconds

        Returns
        -------
        A `Task` handle for the asynchronous dataset-creation job

        Raises
        ------
        MethodNotCalledOnRasterException
            If the workflow result is not a raster result
        """

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        session = get_session()

        if not isinstance(query_rectangle, QueryRectangle):
            print("save_as_dataset ignores params other than spatial and temporal bounds.")

        # only spatial and temporal bounds are forwarded to the server
        qrect = geoc.models.raster_to_dataset_query_rectangle.RasterToDatasetQueryRectangle(
            spatial_bounds=SpatialPartition2D.from_bounding_box(query_rectangle.spatial_bounds).to_api_dict(),
            time_interval=query_rectangle.time.to_api_dict(),
        )

        with geoc.ApiClient(session.configuration) as api_client:
            workflows_api = geoc.WorkflowsApi(api_client)
            response = workflows_api.dataset_from_workflow_handler(
                str(self.__workflow_id),
                geoc.RasterDatasetFromWorkflow(
                    name=name, display_name=display_name, description=description, query=qrect
                ),
                _request_timeout=timeout,
            )

        return Task(TaskId.from_response(response))
631

632
    async def raster_stream(
        self,
        query_rectangle: QueryRectangle | RasterQueryRectangle,
        open_timeout: int = 60,
    ) -> AsyncIterator[RasterTile2D]:
        """Stream the workflow result as series of RasterTile2D (transformable to numpy and xarray)

        Parameters
        ----------
        query_rectangle : Bounds of the query; a plain `QueryRectangle` is extended to all bands
        open_timeout : Timeout in seconds for opening the websocket connection

        Raises
        ------
        MethodNotCalledOnRasterException
            If the workflow result is not a raster result
        GeoEngineException
            If the server sends an error message over the websocket
        InputException
            If no valid websocket url could be constructed
        """

        # Currently, it only works for raster results
        if not self.__result_descriptor.is_raster_result():
            raise MethodNotCalledOnRasterException()

        result_descriptor = cast(RasterResultDescriptor, self.__result_descriptor)

        if not isinstance(query_rectangle, RasterQueryRectangle):
            query_rectangle = query_rectangle.with_raster_bands(
                # TODO: all bands or first band?
                list(range(0, len(result_descriptor.bands)))
            )

        session = get_session()

        # build the query url via a prepared (but never sent) request to get proper parameter encoding
        url = (
            req.Request(
                "GET",
                url=f"{session.server_url}/workflow/{self.__workflow_id}/rasterStream",
                params={
                    "resultType": "arrow",
                    "spatialBounds": query_rectangle.bbox_str,
                    "timeInterval": query_rectangle.time_str,
                    "attributes": ",".join(map(str, query_rectangle.raster_bands)),
                },
            )
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            additional_headers=session.auth_header,
            open_timeout=open_timeout,
            # no message size limit: tiles can be large
            max_size=None,
        ) as websocket:
            # bytes of the previously received tile; decoded while the next tile is requested
            tile_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # overlap the network read of the next tile with decoding of the previous one
                (tile_bytes, tile) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, tile_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(RasterStreamProcessing.process_bytes, tile_bytes),
                )

                if tile is not None:
                    yield tile

            # process the last tile
            tile = RasterStreamProcessing.process_bytes(tile_bytes)

            if tile is not None:
                yield tile
715

716
    async def raster_stream_into_xarray(
        self,
        query_rectangle: RasterQueryRectangle,
        clip_to_query_rectangle: bool = False,
        open_timeout: int = 60,
    ) -> xr.DataArray:
        """
        Stream the workflow result into memory and output a single xarray.

        NOTE: You can run out of memory if the query rectangle is too large.

        Parameters
        ----------
        query_rectangle : Bounds of the query
        clip_to_query_rectangle : If True, clip each tile to the query's spatial bounds
        open_timeout : Timeout in seconds for opening the websocket connection
        """

        tile_stream = self.raster_stream(query_rectangle, open_timeout=open_timeout)

        # one merged array per time step; concatenated along "time" at the end
        timestep_xarrays: list[xr.DataArray] = []

        spatial_clip_bounds = query_rectangle.spatial_bounds if clip_to_query_rectangle else None

        async def read_tiles(
            remainder_tile: RasterTile2D | None,
        ) -> tuple[list[xr.DataArray], RasterTile2D | None]:
            # Collect all tiles that belong to one time step. Returns those tiles
            # plus the first tile of the NEXT time step (which had to be read to
            # detect the boundary), or None when the stream is exhausted.
            last_timestep: np.datetime64 | None = None
            tiles = []

            if remainder_tile is not None:
                last_timestep = remainder_tile.time_start_ms
                xr_tile = remainder_tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            async for tile in tile_stream:
                timestep: np.datetime64 = tile.time_start_ms
                if last_timestep is None:
                    last_timestep = timestep
                elif last_timestep != timestep:
                    # a new time step started: hand this tile back as the remainder
                    return tiles, tile

                xr_tile = tile.to_xarray(clip_with_bounds=spatial_clip_bounds)
                tiles.append(xr_tile)

            # this seems to be the last time step, so just return tiles
            return tiles, None

        (tiles, remainder_tile) = await read_tiles(None)

        while len(tiles):
            # overlap reading the next time step with merging the current one
            ((new_tiles, new_remainder_tile), new_timestep_xarray) = await asyncio.gather(
                read_tiles(remainder_tile),
                backports.to_thread(RasterStreamProcessing.merge_tiles, tiles),
                # asyncio.to_thread(merge_tiles, tiles), # TODO: use this when min Python version is 3.9
            )

            tiles = new_tiles
            remainder_tile = new_remainder_tile

            if new_timestep_xarray is not None:
                timestep_xarrays.append(new_timestep_xarray)

        # concatenate all per-time-step arrays along a new "time" dimension
        output: xr.DataArray = cast(
            xr.DataArray,
            # await asyncio.to_thread( # TODO: use this when min Python version is 3.9
            await backports.to_thread(
                xr.concat,
                # TODO: This is a typings error, since the method accepts also a `xr.DataArray` and returns one
                cast(list[xr.Dataset], timestep_xarrays),
                dim="time",
            ),
        )

        return output
785

786
    async def vector_stream(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> AsyncIterator[gpd.GeoDataFrame]:
        """
        Stream the workflow result as a series of `GeoDataFrame`s.

        Opens a websocket to the backend's `vectorStream` endpoint and yields
        one `GeoDataFrame` per received Arrow IPC batch. Reading the next
        batch from the socket and converting the previous batch run
        concurrently (see the `asyncio.gather` below).

        Args:
            query_rectangle: Spatial/temporal bounds of the query.
            time_start_column: Name of the output column holding interval starts.
            time_end_column: Name of the output column holding interval ends.
            open_timeout: Timeout (seconds) for opening the websocket connection.

        Raises:
            MethodNotCalledOnVectorException: If this workflow does not produce a vector result.
            InputException: If the prepared websocket url is invalid.
            GeoEngineException: If the server sends an error message over the socket.
        """

        def read_arrow_ipc(arrow_ipc: bytes) -> pa.RecordBatch:
            """Decode one Arrow IPC file payload into its single record batch."""
            reader = pa.ipc.open_file(arrow_ipc)
            # We know from the backend that there is only one record batch
            record_batch = reader.get_record_batch(0)
            return record_batch

        def create_geo_data_frame(
            record_batch: pa.RecordBatch, time_start_column: str, time_end_column: str
        ) -> gpd.GeoDataFrame:
            """Convert a record batch into a `GeoDataFrame` with parsed geometry and time columns."""
            metadata = record_batch.schema.metadata
            # CRS is transported in the Arrow schema metadata
            spatial_reference = metadata[b"spatialReference"].decode("utf-8")

            data_frame = record_batch.to_pandas()

            # geometry arrives as WKT strings in a dedicated column
            geometry = gpd.GeoSeries.from_wkt(data_frame[api.GEOMETRY_COLUMN_NAME])
            # delete the duplicated column
            del data_frame[api.GEOMETRY_COLUMN_NAME]

            geo_data_frame = gpd.GeoDataFrame(
                data_frame,
                geometry=geometry,
                crs=spatial_reference,
            )

            # split time column (each entry is a [start, end] pair)
            geo_data_frame[[time_start_column, time_end_column]] = geo_data_frame[api.TIME_COLUMN_NAME].tolist()
            # delete the duplicated column
            del geo_data_frame[api.TIME_COLUMN_NAME]

            # parse time columns from epoch milliseconds to UTC datetimes
            for time_column in [time_start_column, time_end_column]:
                geo_data_frame[time_column] = pd.to_datetime(
                    geo_data_frame[time_column],
                    utc=True,
                    unit="ms",
                    # TODO: solve time conversion problem from Geo Engine to Python for large (+/-) time instances
                    errors="coerce",
                )

            return geo_data_frame

        def process_bytes(batch_bytes: bytes | None) -> gpd.GeoDataFrame | None:
            """Turn a raw websocket message into a `GeoDataFrame`; passes `None` through."""
            if batch_bytes is None:
                return None

            # process the received data
            record_batch = read_arrow_ipc(batch_bytes)
            tile = create_geo_data_frame(
                record_batch,
                time_start_column=time_start_column,
                time_end_column=time_end_column,
            )

            return tile

        # This streaming endpoint is only defined for vector results
        if not self.__result_descriptor.is_vector_result():
            raise MethodNotCalledOnVectorException()

        session = get_session()

        params = {
            "resultType": "arrow",
            "spatialBounds": query_rectangle.bbox_str,
            "timeInterval": query_rectangle.time_str,
        }

        # use `requests` only to build the url with properly encoded query params
        url = (
            req.Request("GET", url=f"{session.server_url}/workflow/{self.__workflow_id}/vectorStream", params=params)
            .prepare()
            .url
        )

        if url is None:
            raise InputException("Invalid websocket url")

        async with websockets.asyncio.client.connect(
            uri=self.__replace_http_with_ws(url),
            additional_headers=session.auth_header,
            open_timeout=open_timeout,
            max_size=None,  # allow arbitrary large messages, since it is capped by the server's chunk size
        ) as websocket:
            # holds the most recently received (not yet converted) message
            batch_bytes: bytes | None = None

            while websocket.state == websockets.protocol.State.OPEN:

                async def read_new_bytes() -> bytes | None:
                    # already send the next request to speed up the process
                    try:
                        await websocket.send("NEXT")
                    except websockets.exceptions.ConnectionClosed:
                        # the websocket connection is already closed, we cannot read anymore
                        return None

                    try:
                        data: str | bytes = await websocket.recv()

                        if isinstance(data, str):
                            # the server sent an error message
                            raise GeoEngineException({"error": data})

                        return data
                    except websockets.exceptions.ConnectionClosedOK:
                        # the websocket connection closed gracefully, so we stop reading
                        return None

                # pipeline: fetch the next message while converting the previous one off-thread
                (batch_bytes, batch) = await asyncio.gather(
                    read_new_bytes(),
                    # asyncio.to_thread(process_bytes, batch_bytes), # TODO: use this when min Python version is 3.9
                    backports.to_thread(process_bytes, batch_bytes),
                )

                if batch is not None:
                    yield batch

            # process the last tile (the loop exits with one unconverted message pending)
            batch = process_bytes(batch_bytes)

            if batch is not None:
                yield batch
916
    async def vector_stream_into_geopandas(
        self,
        query_rectangle: QueryRectangle,
        time_start_column: str = "time_start",
        time_end_column: str = "time_end",
        open_timeout: int = 60,
    ) -> gpd.GeoDataFrame | None:
        """
        Stream the workflow result into memory and output a single geo data frame.

        Consumes `vector_stream` chunk by chunk, concatenating as it goes;
        reading the next chunk and merging the previous one run concurrently.
        Returns `None` if the stream yields no chunks at all.

        Args:
            query_rectangle: Spatial/temporal bounds of the query.
            time_start_column: Name of the output column holding interval starts.
            time_end_column: Name of the output column holding interval ends.
            open_timeout: Timeout (seconds) for opening the websocket connection.

        NOTE: You can run out of memory if the query rectangle is too large.
        """

        chunk_stream = self.vector_stream(
            query_rectangle,
            time_start_column=time_start_column,
            time_end_column=time_end_column,
            open_timeout=open_timeout,
        )

        # accumulated result and the most recently fetched (unmerged) chunk
        data_frame: gpd.GeoDataFrame | None = None
        chunk: gpd.GeoDataFrame | None = None

        async def read_dataframe() -> gpd.GeoDataFrame | None:
            """Fetch the next chunk; `None` signals that the stream is exhausted."""
            try:
                return await chunk_stream.__anext__()
            except StopAsyncIteration:
                return None

        def merge_dataframes(df_a: gpd.GeoDataFrame | None, df_b: gpd.GeoDataFrame | None) -> gpd.GeoDataFrame | None:
            """Concatenate two optional frames; `None` acts as the neutral element."""
            if df_a is None:
                return df_b

            if df_b is None:
                return df_a

            return pd.concat([df_a, df_b], ignore_index=True)

        while True:
            # pipeline: fetch the next chunk while merging the previous one off-thread
            (chunk, data_frame) = await asyncio.gather(
                read_dataframe(),
                backports.to_thread(merge_dataframes, data_frame, chunk),
                # TODO: use this when min Python version is 3.9
                # asyncio.to_thread(merge_dataframes, data_frame, chunk),
            )

            # we can stop when the chunk stream is exhausted
            if chunk is None:
                break

        return data_frame
968
    def __replace_http_with_ws(self, url: str) -> str:
        """
        Rewrite the scheme of `url` from HTTP(S) to its websocket equivalent.

        The websockets library requires urls starting with `ws://`; secure
        schemes (those containing an "s", e.g. `https`) map to `wss://`.
        """

        scheme, remainder = url.split("://", maxsplit=1)

        if "s" in scheme.lower():
            prefix = "wss://"
        else:
            prefix = "ws://"

        return prefix + remainder

982

983
def register_workflow(workflow: dict[str, Any] | WorkflowBuilderOperator, timeout: int = 60) -> Workflow:
    """
    Register a workflow in Geo Engine and receive a `WorkflowId`.

    Accepts either a raw workflow dictionary or a `WorkflowBuilderOperator`,
    which is first converted into its dictionary form.

    Raises:
        InputException: If the definition cannot be parsed into a workflow model.
    """

    # normalize builder input to its dictionary representation
    definition = workflow.to_workflow_dict() if isinstance(workflow, WorkflowBuilderOperator) else workflow

    model = geoc.Workflow.from_dict(definition)
    if model is None:
        raise InputException("Invalid workflow definition")

    session = get_session()

    with geoc.ApiClient(session.configuration) as client:
        response = geoc.WorkflowsApi(client).register_workflow_handler(model, _request_timeout=timeout)

    return Workflow(WorkflowId.from_response(response))

1004

1005
def workflow_by_id(workflow_id: UUID | str) -> Workflow:
    """
    Wrap an existing workflow id in a `Workflow` object.
    """

    # TODO: check that workflow exists

    wrapped_id = WorkflowId(workflow_id)
    return Workflow(wrapped_id)

1014

1015
def get_quota(user_id: UUID | None = None, timeout: int = 60) -> geoc.Quota:
    """
    Get a user's quota.

    Without a `user_id`, the quota of the current session's user is returned.
    Only admins can get other users' quota.
    """

    session = get_session()

    with geoc.ApiClient(session.configuration) as client:
        user_api = geoc.UserApi(client)

        if user_id is not None:
            # admin-only endpoint: quota of another user
            return user_api.get_user_quota_handler(str(user_id), _request_timeout=timeout)

        return user_api.quota_handler(_request_timeout=timeout)

1030

1031
def update_quota(user_id: UUID, new_available_quota: int, timeout: int = 60) -> None:
    """
    Update a user's available quota. Only admins can perform this operation.
    """

    session = get_session()
    quota_update = geoc.UpdateQuota(available=new_available_quota)

    with geoc.ApiClient(session.configuration) as client:
        geoc.UserApi(client).update_user_quota_handler(str(user_id), quota_update, _request_timeout=timeout)

1044

1045
def data_usage(offset: int = 0, limit: int = 10) -> pd.DataFrame:
    """
    Get data usage records for the current user as a `pandas.DataFrame`.

    Fetches usage entries from the backend (paged via `offset`/`limit`); the
    `timestamp` column, if present, is parsed into timezone-aware (UTC)
    datetimes.

    Args:
        offset: Number of records to skip.
        limit: Maximum number of records to return.
    """
    # NOTE: the return annotation previously claimed `list[geoc.DataUsage]`,
    # but this function has always returned a `DataFrame` — now consistent
    # with `data_usage_summary`.

    session = get_session()

    with geoc.ApiClient(session.configuration) as api_client:
        user_api = geoc.UserApi(api_client)
        response = user_api.data_usage_handler(offset=offset, limit=limit)

        # create dataframe from response
        usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
        df = pd.DataFrame(usage_dicts)
        if "timestamp" in df.columns:
            df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)

    return df
1063

1064

1065
def data_usage_summary(
    granularity: geoc.UsageSummaryGranularity, dataset: str | None = None, offset: int = 0, limit: int = 10
) -> pd.DataFrame:
    """
    Get a data usage summary as a `pandas.DataFrame`.

    Aggregates usage at the given `granularity`, optionally restricted to a
    single `dataset`, paged via `offset`/`limit`. The `timestamp` column, if
    present, is parsed into timezone-aware (UTC) datetimes.
    """

    session = get_session()

    with geoc.ApiClient(session.configuration) as client:
        records = geoc.UserApi(client).data_usage_summary_handler(
            dataset=dataset, granularity=granularity, offset=offset, limit=limit
        )

        # build a dataframe from the response models
        frame = pd.DataFrame([entry.model_dump(by_alias=True) for entry in records])
        if "timestamp" in frame.columns:
            frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)

    return frame
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc