• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine-python / 16367912334

18 Jul 2025 10:06AM UTC coverage: 76.934% (+0.1%) from 76.806%
16367912334

push

github

web-flow
ci: use Ruff as new formatter and linter (#233)

* wip

* pycodestyle

* update dependencies

* skl2onnx

* use ruff

* apply formatter

* apply lint auto fixes

* manually apply lints

* change check

* ruff ci from branch

2805 of 3646 relevant lines covered (76.93%)

0.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.11
geoengine/raster.py
1
"""Raster data types"""
2

3
from __future__ import annotations
1✔
4

5
from collections.abc import AsyncIterator
1✔
6
from typing import Literal, cast
1✔
7

8
import geoengine_openapi_client
1✔
9
import numpy as np
1✔
10
import pyarrow as pa
1✔
11
import xarray as xr
1✔
12

13
import geoengine.types as gety
1✔
14
from geoengine.util import clamp_datetime_ms_ns
1✔
15

16

17
# Ordered lookup table used by ge_type_to_np: each entry pairs a Geo Engine
# raster data type literal with the numpy scalar type it maps to.
_GE_TO_NUMPY_TYPES = (
    ("U8", np.uint8),
    ("U16", np.uint16),
    ("U32", np.uint32),
    ("U64", np.uint64),
    ("I8", np.int8),
    ("I16", np.int16),
    ("I32", np.int32),
    ("I64", np.int64),
    ("F32", np.float32),
    ("F64", np.float64),
)


def ge_type_to_np(res_dt: Literal["U8", "U16", "U32", "U64", "I8", "I16", "I32", "I64", "F32", "F64"]):
    """
    Map a Geo Engine raster data type literal to the matching numpy scalar type.

    Raises:
        TypeError: if `res_dt` is not one of the known type literals.
    """
    # Compare in table order with `==`, exactly like a chain of if-statements would.
    for literal, numpy_type in _GE_TO_NUMPY_TYPES:
        if res_dt == literal:
            return numpy_type
    raise TypeError("Unknown type literal")
42

43

44
class RasterTile2D:
    """A 2D raster tile as produced by the Geo Engine"""

    size_x: int  # number of pixel columns
    size_y: int  # number of pixel rows
    data: pa.Array  # pixel values; reshaped to (size_y, size_x) when exported to numpy
    geo_transform: gety.GeoTransform
    crs: str
    time: gety.TimeInterval
    band: int

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        shape: tuple[int, int],
        data: pa.Array,
        geo_transform: gety.GeoTransform,
        crs: str,
        time: gety.TimeInterval,
        band: int,
    ):
        """
        Create a RasterTile2D object.

        `shape` is given in numpy order, i.e. `(y_size, x_size)`.
        """
        self.size_y, self.size_x = shape
        self.data = data
        self.geo_transform = geo_transform
        self.crs = crs
        self.time = time
        self.band = band

    @property
    def shape(self) -> tuple[int, int]:
        """Return the shape of the raster tile in numpy order (y_size, x_size)"""
        return (self.size_y, self.size_x)

    @property
    def data_type(self) -> pa.DataType:
        """Return the arrow data type of the raster tile"""
        return self.data.type

    @property
    def numpy_data_type(self) -> np.dtype:
        """Return the numpy dtype of the raster tile"""
        return self.data_type.to_pandas_dtype()

    @property
    def has_null_values(self) -> bool:
        """Return whether the raster tile has null values"""
        return self.data.null_count > 0

    @property
    def time_start_ms(self) -> np.datetime64:
        """Return the start of the tile's time interval cast to `datetime64[ms]`"""
        return self.time.start.astype("datetime64[ms]")

    @property
    def time_end_ms(self) -> np.datetime64 | None:
        """Return the end of the tile's time interval cast to `datetime64[ms]`, or None if it has no end"""
        return None if self.time.end is None else self.time.end.astype("datetime64[ms]")

    @property
    def pixel_size(self) -> tuple[float, float]:
        """Return the pixel size as `(x_pixel_size, y_pixel_size)` taken from the geo transform"""
        return (self.geo_transform.x_pixel_size, self.geo_transform.y_pixel_size)

    def to_numpy_data_array(self, fill_null_value=0) -> np.ndarray:
        """
        Return the raster tile as a numpy array.
        Caution: this will not mask nodata values but replace them with the provided value !
        """
        nulled_array = self.data.fill_null(fill_null_value)
        return nulled_array.to_numpy(
            zero_copy_only=True,  # data was already copied when creating the "null filled" array
        ).reshape(self.shape)

    def to_numpy_mask_array(self, nan_is_null=False) -> np.ndarray | None:
        """
        Return the raster tiles mask as a numpy array.
        True means no data, False means data.
        If the raster tile has no null values, None is returned.
        It is possible to specify whether NaN values should be considered as no data when creating the mask.
        """
        numpy_mask = None
        if self.has_null_values:
            numpy_mask = (
                self.data.is_null(
                    nan_is_null=nan_is_null  # nan is not no data
                )
                .to_numpy(
                    zero_copy_only=False  # cannot zero-copy with bools
                )
                .reshape(self.shape)
            )
        return numpy_mask

    def to_numpy_masked_array(self, nan_is_null=False) -> np.ma.MaskedArray:
        """Return the raster tile as a masked numpy array"""
        numpy_data = self.to_numpy_data_array()
        maybe_numpy_mask = self.to_numpy_mask_array(nan_is_null=nan_is_null)

        assert maybe_numpy_mask is None or maybe_numpy_mask.shape == numpy_data.shape

        numpy_mask: np.ndarray | np.ma.MaskType = np.ma.nomask if maybe_numpy_mask is None else maybe_numpy_mask

        numpy_masked_data: np.ma.MaskedArray = np.ma.masked_array(numpy_data, mask=numpy_mask)

        return numpy_masked_data

    def coords_x(self, pixel_center=False) -> np.ndarray:
        """
        Return the x coordinates of the raster tile.

        If pixel_center is True, the coordinates will be the center of the pixels.
        Otherwise they will be the upper left edges.

        Exactly `size_x` values are returned, computed as `start + i * x_pixel_size`.
        `np.arange` with a float step is avoided on purpose: rounding can make its
        result one element too long (e.g. `np.arange(1.0, 1.3, 0.1)` has 4 values).
        """
        start = self.geo_transform.x_min

        if pixel_center:
            start += self.geo_transform.x_half_pixel_size

        return start + np.arange(self.size_x) * self.geo_transform.x_pixel_size

    def coords_y(self, pixel_center=False) -> np.ndarray:
        """
        Return the y coordinates of the raster tile.

        If pixel_center is True, the coordinates will be the center of the pixels.
        Otherwise they will be the upper left edges.

        Exactly `size_y` values are returned, computed as `start + i * y_pixel_size`
        (see `coords_x` for why `np.arange` with a float step is not used).
        """
        start = self.geo_transform.y_max

        if pixel_center:
            start += self.geo_transform.y_half_pixel_size

        return start + np.arange(self.size_y) * self.geo_transform.y_pixel_size

    def to_xarray(self, clip_with_bounds: gety.SpatialBounds | None = None) -> xr.DataArray:
        """
        Return the raster tile as an xarray.DataArray.

        Note:
            - Xarray does not support masked arrays.
                - Masked pixels are converted to NaNs and the nodata value is set to NaN as well.
            - Xarray uses numpy's datetime64[ns] which only covers the years from 1678 to 2262.
                - Date times that are outside of the defined range are clipped to the limits of the range.
        """

        # clamp the dates to the min and max range
        clamped_date = clamp_datetime_ms_ns(self.time_start_ms)

        array = xr.DataArray(
            self.to_numpy_masked_array(),
            dims=["y", "x"],
            coords={
                "x": self.coords_x(pixel_center=True),
                "y": self.coords_y(pixel_center=True),
                "time": clamped_date,  # TODO: incorporate time end?
                "band": self.band,
            },
        )
        array.rio.write_crs(self.crs, inplace=True)

        if clip_with_bounds is not None:
            array = array.rio.clip_box(*clip_with_bounds.as_bbox_tuple(), auto_expand=True)
            array = cast(xr.DataArray, array)

        return array

    def spatial_partition(self) -> gety.SpatialPartition2D:
        """Return the spatial partition of the raster tile"""
        return gety.SpatialPartition2D(
            self.geo_transform.x_min,
            self.geo_transform.y_min(self.size_y),
            self.geo_transform.x_max(self.size_x),
            self.geo_transform.y_max,
        )

    def spatial_resolution(self) -> gety.SpatialResolution:
        """Return the spatial resolution as reported by the tile's geo transform"""
        return self.geo_transform.spatial_resolution()

    def is_empty(self) -> bool:
        """Returns true if the tile is empty, i.e. every pixel is null"""
        num_pixels = self.size_x * self.size_y
        num_nulls = self.data.null_count
        return num_pixels == num_nulls

    @staticmethod
    def from_ge_record_batch(record_batch: pa.RecordBatch) -> RasterTile2D:
        """
        Create a RasterTile2D from an Arrow record batch received from the Geo Engine.

        Tile metadata (geoTransform, xSize, ySize, spatialReference, time, band) is read
        from the batch's schema metadata; the pixel values are the batch's first column.
        """
        metadata = record_batch.schema.metadata
        inner = geoengine_openapi_client.GdalDatasetGeoTransform.from_json(metadata[b"geoTransform"])
        assert inner is not None, "Failed to parse geoTransform"
        geo_transform = gety.GeoTransform.from_response(inner)
        x_size = int(metadata[b"xSize"])
        y_size = int(metadata[b"ySize"])
        spatial_reference = metadata[b"spatialReference"].decode("utf-8")
        # We know from the backend that there is only one array a.k.a. one column
        arrow_array = record_batch.column(0)

        inner_time = geoengine_openapi_client.TimeInterval.from_json(metadata[b"time"])
        assert inner_time is not None, "Failed to parse time"
        time = gety.TimeInterval.from_response(inner_time)

        band = int(metadata[b"band"])

        return RasterTile2D(
            (y_size, x_size),
            arrow_array,
            geo_transform,
            spatial_reference,
            time,
            band,
        )
259

260

261
class RasterTileStack2D:
    """
    All bands of one raster tile, stacked together, as produced by the Geo Engine.

    `data` and `bands` are parallel lists: `data[i]` holds the pixels of band `bands[i]`.
    """

    size_y: int
    size_x: int
    geo_transform: gety.GeoTransform
    crs: str
    time: gety.TimeInterval
    data: list[pa.Array]
    bands: list[int]

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        tile_shape: tuple[int, int],
        data: list[pa.Array],
        geo_transform: gety.GeoTransform,
        crs: str,
        time: gety.TimeInterval,
        bands: list[int],
    ):
        """Build a stack from per-band arrays that share one shape, geo transform, crs and time"""
        rows, cols = tile_shape  # numpy order: (y_size, x_size)
        self.size_y = rows
        self.size_x = cols
        self.data = data
        self.bands = bands
        self.geo_transform = geo_transform
        self.crs = crs
        self.time = time

    def single_band(self, index: int) -> RasterTile2D:
        """Extract the band at list position `index` as a standalone RasterTile2D"""
        return RasterTile2D(
            shape=(self.size_y, self.size_x),
            data=self.data[index],
            geo_transform=self.geo_transform,
            crs=self.crs,
            time=self.time,
            band=self.bands[index],
        )

    def to_numpy_masked_array_stack(self) -> np.ma.MaskedArray:
        """Stack the masked numpy array of every band along a new leading axis (axis 0)"""
        per_band = [self.single_band(i).to_numpy_masked_array() for i in range(len(self.data))]
        return np.ma.stack(per_band, axis=0)

    def to_xarray(self, clip_with_bounds: gety.SpatialBounds | None = None) -> xr.DataArray:
        """Convert every band to an (optionally clipped) xarray.DataArray and concatenate them along the band dimension"""
        return xr.concat(
            [self.single_band(i).to_xarray(clip_with_bounds) for i in range(len(self.data))],
            dim="band",
        )
×
312

313

314
def _stack_tiles(tiles: list[RasterTile2D]) -> RasterTileStack2D:
    """Combine buffered single-band tiles into one stack; shape, geo transform, crs and time come from the first tile"""
    first = tiles[0]
    return RasterTileStack2D(
        first.shape,
        [t.data for t in tiles],
        first.geo_transform,
        first.crs,
        first.time,
        [t.band for t in tiles],
    )


async def tile_stream_to_stack_stream(raster_stream: AsyncIterator[RasterTile2D]) -> AsyncIterator[RasterTileStack2D]:
    """
    Convert a stream of raster tiles to stream of stacked tiles.

    Tiles are buffered until a tile carrying the band of the very first tile in the
    stream arrives again. The buffered tiles are then emitted as one RasterTileStack2D,
    and the new tile starts the next buffer. Tiles still buffered when the stream ends
    are emitted as a final stack.

    Raises:
        AssertionError: if tile shapes or crs differ, if the start time of a tile
            with the first band decreases, or if a tile of another band has a time
            different from the buffered tiles.
    """
    store: list[RasterTile2D] = []
    first_band: int = -1

    async for tile in raster_stream:
        if not store:
            # Only the first tile of the stream finds an empty buffer; its band marks
            # the beginning of every stack.
            first_band = tile.band
            store.append(tile)
            continue

        # check things that should be the same for all tiles
        assert tile.shape == store[0].shape, "Tile shapes do not match"
        # TODO: geo transform should be the same for all tiles
        #       tiles should have a tile position or global pixel position

        # assert tile.geo_transform == store[0].geo_transform, 'Tile geo_transforms do not match'
        assert tile.crs == store[0].crs, "Tile crs do not match"

        if tile.band == first_band:
            # The first band came around again: the buffered stack is complete.
            assert tile.time.start >= store[0].time.start, "Tile time intervals must be equal or increasing"
            completed = _stack_tiles(store)
            store = [tile]
            yield completed
        else:
            assert tile.time == store[0].time, f"Time mismatch. {store[0].time} != {tile.time}"
            store.append(tile)

    if store:
        yield _stack_tiles(store)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc