• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine / 13195982647

07 Feb 2025 08:25AM UTC coverage: 90.031% (-0.04%) from 90.073%
13195982647

Pull #1000

github

web-flow
Merge 2410e0a4e into 22fe2f6bb
Pull Request #1000: add tensor shape to ml model

109 of 194 new or added lines in 10 files covered. (56.19%)

3 existing lines in 3 files now uncovered.

126035 of 139990 relevant lines covered (90.03%)

57499.02 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.09
/operators/src/machine_learning/onnx.rs
1
use crate::engine::{
2
    CanonicOperatorName, ExecutionContext, InitializedRasterOperator, InitializedSources, Operator,
3
    OperatorName, QueryContext, RasterBandDescriptor, RasterOperator, RasterQueryProcessor,
4
    RasterResultDescriptor, SingleRasterSource, TypedRasterQueryProcessor, WorkflowOperatorPath,
5
};
6
use crate::error;
7
use crate::machine_learning::error::{
8
    InputBandsMismatch, InputTypeMismatch, InvalidInputShape, Ort,
9
};
10
use crate::machine_learning::MachineLearningError;
11
use crate::util::Result;
12
use async_trait::async_trait;
13
use futures::stream::BoxStream;
14
use futures::StreamExt;
15
use geoengine_datatypes::machine_learning::{MlModelMetadata, MlModelName};
16
use geoengine_datatypes::primitives::{Measurement, RasterQueryRectangle};
17
use geoengine_datatypes::raster::{
18
    Grid, GridIdx2D, GridIndexAccess, GridShapeAccess, GridSize, Pixel, RasterTile2D,
19
};
20
use ndarray::{Array2, Array4};
21
use ort::tensor::{IntoTensorElementType, PrimitiveTensorElementType};
22
use serde::{Deserialize, Serialize};
23
use snafu::{ensure, ResultExt};
24

25
/// Parameters of the [`Onnx`] operator.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26
#[serde(rename_all = "camelCase")]
27
pub struct OnnxParams {
28
    /// Name of the ML model. It is used to look up the model metadata in the execution context
    /// when the operator is initialized.
    pub model: MlModelName,
29
}
30

31
/// This operator applies an ML model in ONNX format to all bands of its input raster series.
32
/// The model input (and output) has to be either a single pixel or match the tile shape (y, x).
33
pub type Onnx = Operator<OnnxParams, SingleRasterSource>;
34

35
impl OperatorName for Onnx {
36
    /// Type name of the operator; also returned by `InitializedOnnx::name`.
    const TYPE_NAME: &'static str = "Onnx";
37
}
38

39
#[typetag::serde]
×
40
#[async_trait]
41
impl RasterOperator for Onnx {
42
    async fn _initialize(
43
        self: Box<Self>,
44
        path: WorkflowOperatorPath,
45
        context: &dyn ExecutionContext,
46
    ) -> Result<Box<dyn InitializedRasterOperator>> {
2✔
47
        let name = CanonicOperatorName::from(&self);
2✔
48

49
        let source = self
2✔
50
            .sources
2✔
51
            .initialize_sources(path.clone(), context)
2✔
52
            .await?
2✔
53
            .raster;
54

55
        let in_descriptor = source.result_descriptor();
2✔
56

57
        let model_metadata = context.ml_model_metadata(&self.params.model).await?;
2✔
58

59
        let tiling_shape = context.tiling_specification().tile_size_in_pixels;
2✔
60

2✔
61
        // check that we can use the model input shape with the operator
2✔
62
        ensure!(
2✔
63
            model_metadata.input_is_single_pixel()
2✔
NEW
64
                || model_metadata
×
NEW
65
                    .input_shape
×
NEW
66
                    .yx_matches_tile_shape(&tiling_shape),
×
NEW
67
            InvalidInputShape {
×
NEW
68
                tensor_shape: model_metadata.input_shape,
×
NEW
69
                tiling_shape
×
NEW
70
            }
×
71
        );
72

73
        // check that we can use the model output shape with the operator
74
        ensure!(
2✔
75
            model_metadata.output_is_single_pixel()
2✔
NEW
76
                || model_metadata
×
NEW
77
                    .output_shape
×
NEW
78
                    .yx_matches_tile_shape(&tiling_shape),
×
NEW
79
            InvalidInputShape {
×
NEW
80
                tensor_shape: model_metadata.output_shape,
×
NEW
81
                tiling_shape
×
NEW
82
            }
×
83
        );
84

85
        // check that number of input bands fits number of model features
86
        ensure!(
2✔
87
            model_metadata.num_input_bands() == in_descriptor.bands.count(),
2✔
88
            InputBandsMismatch {
×
NEW
89
                model_input_bands: model_metadata.num_input_bands(),
×
90
                source_bands: in_descriptor.bands.count(),
×
91
            }
×
92
        );
93

94
        // check that input type fits model input type
95
        ensure!(
2✔
96
            model_metadata.input_type == in_descriptor.data_type,
2✔
97
            InputTypeMismatch {
×
98
                model_input_type: model_metadata.input_type,
×
99
                source_type: in_descriptor.data_type,
×
100
            }
×
101
        );
102

103
        let out_descriptor = RasterResultDescriptor {
2✔
104
            data_type: model_metadata.output_type,
2✔
105
            spatial_reference: in_descriptor.spatial_reference,
2✔
106
            time: in_descriptor.time,
2✔
107
            bbox: in_descriptor.bbox,
2✔
108
            resolution: in_descriptor.resolution,
2✔
109
            bands: vec![RasterBandDescriptor::new(
2✔
110
                "prediction".to_string(), // TODO: parameter of the operator?
2✔
111
                Measurement::Unitless,    // TODO: get output measurement from model metadata
2✔
112
            )]
2✔
113
            .try_into()?,
2✔
114
        };
115

116
        Ok(Box::new(InitializedOnnx {
2✔
117
            name,
2✔
118
            path,
2✔
119
            result_descriptor: out_descriptor,
2✔
120
            source,
2✔
121
            model_metadata,
2✔
122
        }))
2✔
123
    }
4✔
124

125
    span_fn!(Onnx);
126
}
127

128
/// The initialized form of the [`Onnx`] operator, created by `Onnx::_initialize`.
pub struct InitializedOnnx {
129
    /// Canonic name derived from the uninitialized operator.
    name: CanonicOperatorName,
130
    /// Path of this operator within the workflow.
    path: WorkflowOperatorPath,
131
    /// Descriptor of the prediction raster produced by this operator.
    result_descriptor: RasterResultDescriptor,
132
    /// Initialized raster source that delivers the model input bands.
    source: Box<dyn InitializedRasterOperator>,
133
    /// Metadata of the model (file path, input/output types and shapes).
    model_metadata: MlModelMetadata,
134
}
135

136
impl InitializedRasterOperator for InitializedOnnx {
137
    /// Returns the descriptor of the prediction raster.
    fn result_descriptor(&self) -> &RasterResultDescriptor {
×
138
        &self.result_descriptor
×
139
    }
×
140

141
    /// Creates an `OnnxProcessor` on top of the source's query processor.
    ///
    /// The outer macro dispatches on the source processor, the inner macro on the
    /// model's output type.
    /// NOTE(review): exact macro semantics are defined elsewhere — confirm.
    fn query_processor(&self) -> Result<TypedRasterQueryProcessor> {
2✔
142
        let source = self.source.query_processor()?;
2✔
143
        Ok(call_on_generic_raster_processor!(
×
144
            source, input => {
2✔
145
                call_generic_raster_processor!(
×
146
                    self.model_metadata.output_type,
×
147
                    OnnxProcessor::new(
×
148
                        input,
×
149
                        self.result_descriptor.clone(),
×
NEW
150
                        self.model_metadata.clone(),
×
151
                    )
×
152
                    .boxed()
×
153
                )
154
            }
155
        ))
156
    }
2✔
157

158
    /// Returns a copy of the canonic operator name.
    fn canonic_name(&self) -> CanonicOperatorName {
×
159
        self.name.clone()
×
160
    }
×
161

162
    /// Returns the type name of the `Onnx` operator.
    fn name(&self) -> &'static str {
×
163
        Onnx::TYPE_NAME
×
164
    }
×
165

166
    /// Returns a copy of the operator's path within the workflow.
    fn path(&self) -> WorkflowOperatorPath {
×
167
        self.path.clone()
×
168
    }
×
169
}
170

171
/// Query processor that runs an ONNX model on the tiles of its source.
///
/// `TIn` is the pixel type of the source tiles, `TOut` the pixel type of the predictions.
pub(crate) struct OnnxProcessor<TIn, TOut> {
172
    /// Source processor that delivers the input tiles with pixel type `TIn`.
    source: Box<dyn RasterQueryProcessor<RasterType = TIn>>,
173
    /// Descriptor of the prediction raster.
    result_descriptor: RasterResultDescriptor,
174
    /// Metadata of the model; provides the model file path and the input shape.
    model_metadata: MlModelMetadata,
175
    /// Marker for the output pixel type, which is not stored in any field.
    phantom: std::marker::PhantomData<TOut>,
176
}
177

178
impl<TIn, TOut> OnnxProcessor<TIn, TOut> {
179
    pub fn new(
2✔
180
        source: Box<dyn RasterQueryProcessor<RasterType = TIn>>,
2✔
181
        result_descriptor: RasterResultDescriptor,
2✔
182
        model_metadata: MlModelMetadata,
2✔
183
    ) -> Self {
2✔
184
        Self {
2✔
185
            source,
2✔
186
            result_descriptor,
2✔
187
            model_metadata,
2✔
188
            phantom: Default::default(),
2✔
189
        }
2✔
190
    }
2✔
191
}
192

193
#[async_trait]
194
impl<TIn, TOut> RasterQueryProcessor for OnnxProcessor<TIn, TOut>
195
where
196
    TIn: Pixel + NoDataValue,
197
    TOut: Pixel + IntoTensorElementType + PrimitiveTensorElementType,
198
    ort::value::Value: std::convert::TryFrom<
199
        ndarray::ArrayBase<ndarray::OwnedRepr<TIn>, ndarray::Dim<[usize; 2]>>,
200
    >,
201
    ort::value::Value: std::convert::TryFrom<
202
        ndarray::ArrayBase<ndarray::OwnedRepr<TIn>, ndarray::Dim<[usize; 4]>>,
203
    >,
204
    ort::Error: std::convert::From<
205
        <ort::value::Value as std::convert::TryFrom<
206
            ndarray::ArrayBase<ndarray::OwnedRepr<TIn>, ndarray::Dim<[usize; 2]>>,
207
        >>::Error,
208
    >,
209
    ort::Error: From<
210
        <ort::value::Value as std::convert::TryFrom<
211
            ndarray::ArrayBase<ndarray::OwnedRepr<TIn>, ndarray::Dim<[usize; 4]>>,
212
        >>::Error,
213
    >,
214
{
215
    type RasterType = TOut;
216

217
    async fn raster_query<'a>(
218
        &'a self,
219
        query: RasterQueryRectangle,
220
        ctx: &'a dyn QueryContext,
221
    ) -> Result<BoxStream<'a, Result<RasterTile2D<TOut>>>> {
2✔
222
        let num_bands = self.source.raster_result_descriptor().bands.count() as usize;
2✔
223

2✔
224
        let mut source_query = query.clone();
2✔
225
        source_query.attributes = (0..num_bands as u32).collect::<Vec<u32>>().try_into()?;
2✔
226

227
        // TODO: re-use session accross queries?
228
        let session = ort::session::Session::builder()
2✔
229
            .context(Ort)?
2✔
230
            .commit_from_file(&self.model_metadata.file_path)
2✔
231
            .context(Ort)
2✔
232
            .inspect_err(|e| {
2✔
233
                tracing::debug!(
×
234
                    "Could not create ONNX session for {:?}. Error: {}",
×
NEW
235
                    self.model_metadata.file_path.file_name(),
×
236
                    e
237
                );
238
            })?;
2✔
239

240
        tracing::debug!(
2✔
241
            "Created ONNX session for {:?}",
×
NEW
242
            &self.model_metadata.file_path.file_name()
×
243
        );
244

245
        let stream = self
2✔
246
            .source
2✔
247
            .raster_query(source_query, ctx)
2✔
248
            .await?
2✔
249
            .chunks(num_bands) // chunk the tiles to get all bands for a spatial index at once
2✔
250
            // TODO: this does not scale for large number of bands.
2✔
251
            //       In that case we would need to collect only a fixed number of pixel from each each,
2✔
252
            //       and repeat the process until the whole tile is finished
2✔
253
            .map(move |chunk| {
4✔
254
                // TODO: spawn task and await
4✔
255

4✔
256
                if chunk.len() != num_bands {
4✔
257
                    // if there are not exactly N tiles, it should mean the last tile was an error and the chunker ended prematurely
258
                    if let Some(Err(e)) = chunk.into_iter().last() {
×
259
                        return Err(e);
×
260
                    }
×
261
                    // if there is no error, the source did not produce all bands, which likely means a bug in an operator
×
262
                    return Err(error::Error::MustNotHappen {
×
263
                        message: "source did not produce all bands".to_string(),
×
264
                    });
×
265
                }
4✔
266

267
                let tiles = chunk.into_iter().collect::<Result<Vec<_>>>()?;
4✔
268

269
                let first_tile = &tiles[0];
4✔
270
                let time = first_tile.time;
4✔
271
                let tile_position = first_tile.tile_position;
4✔
272
                let global_geo_transform = first_tile.global_geo_transform;
4✔
273
                let cache_hint = first_tile.cache_hint;
4✔
274

4✔
275
                let tile_shape = tiles[0].grid_shape();
4✔
276
                let width = tile_shape.axis_size_x();
4✔
277
                let height = tile_shape.axis_size_y();
4✔
278

4✔
279
                // TODO: collect into a ndarray directly
4✔
280

4✔
281
                // TODO: use flat array instead of nested Vecs
4✔
282
                let mut pixels: Vec<Vec<TIn>> = vec![vec![TIn::zero(); num_bands]; width * height];
4✔
283

284
                for (tile_index, tile) in tiles.into_iter().enumerate() {
10✔
285
                    // TODO: use map_elements or map_elements_parallel to avoid the double loop
286
                    for y in 0..height {
20✔
287
                        for x in 0..width {
40✔
288
                            let pixel_index = y * width + x;
40✔
289
                            let pixel_value = tile
40✔
290
                                .get_at_grid_index(GridIdx2D::from([y as isize, x as isize]))?
40✔
291
                                .unwrap_or(TIn::NO_DATA); // TODO: properly handle missing values or skip the pixel entirely instead
40✔
292
                            pixels[pixel_index][tile_index] = pixel_value;
40✔
293
                        }
294
                    }
295
                }
296

297
                let pixels = pixels.into_iter().flatten().collect::<Vec<TIn>>();
4✔
298

299
                let outputs = if self.model_metadata.input_is_single_pixel() {
4✔
300
                    let rows = width * height;
4✔
301
                    let cols = num_bands;
4✔
302

4✔
303
                    let samples = Array2::from_shape_vec((rows, cols), pixels).expect(
4✔
304
                        "Array2 should be valid because it is created from a Vec with the correct size",
4✔
305
                    );
4✔
306

4✔
307
                    let input_name = &session.inputs[0].name;
4✔
308

309
                    let out = session
4✔
310
                        .run(ort::inputs![input_name => samples].context(Ort)?)
4✔
311
                        .context(Ort)?;
4✔
312
                    Ok(out)
4✔
NEW
313
                } else if self.model_metadata.input_shape.yx_matches_tile_shape(&tile_shape){
×
NEW
314
                    let samples = Array4::from_shape_vec((1, height, width, num_bands), pixels).expect( // y,x, attributes
×
NEW
315
                        "Array2 should be valid because it is created from a Vec with the correct size",
×
NEW
316
                    );
×
NEW
317

×
NEW
318
                    let input_name = &session.inputs[0].name;
×
319

NEW
320
                    let out = session
×
NEW
321
                        .run(ort::inputs![input_name => samples].context(Ort)?)
×
NEW
322
                        .context(Ort)?;
×
323

NEW
324
                    Ok(out)
×
325
                } else {
NEW
326
                    Err(
×
NEW
327
                        MachineLearningError::InvalidInputShape {
×
NEW
328
                            tensor_shape: self.model_metadata.input_shape,
×
NEW
329
                            tiling_shape: tile_shape
×
NEW
330
                        }
×
NEW
331
                    )
×
332
                }.map_err(error::Error::from)?;
4✔
333

334
                // assume the first output is the prediction and ignore the other outputs (e.g. probabilities for classification)
335
                // we don't access the output by name because it can vary, e.g. "output_label" vs "variable"
336
                let predictions = outputs[0].try_extract_tensor::<TOut>().context(Ort)?;
4✔
337

338
                // extract the values as a raw vector because we expect one prediction per pixel.
339
                // this works for 1d tensors as well as 2d tensors with a single column
340
                let (predictions, offset) = predictions.into_owned().into_raw_vec_and_offset();
4✔
341
                debug_assert!(offset.is_none() || offset == Some(0));
4✔
342

343
                // TODO: create no data mask from input no data masks
344
                Ok(RasterTile2D::new(
345
                    time,
4✔
346
                    tile_position,
4✔
347
                    0,
4✔
348
                    global_geo_transform,
4✔
349
                    Grid::new([width, height].into(), predictions)?.into(),
4✔
350
                    cache_hint,
4✔
351
                ))
352
            });
4✔
353

2✔
354
        Ok(stream.boxed())
2✔
355
    }
4✔
356

357
    fn raster_result_descriptor(&self) -> &RasterResultDescriptor {
2✔
358
        &self.result_descriptor
2✔
359
    }
2✔
360
}
361

362
// Workaround trait that provides a replacement value for missing (no data) pixels for all data types.
363
// TODO: this should be handled differently, like skipping the pixel entirely or using a different value for missing values
364
trait NoDataValue {
365
    /// The value that is passed to the model in place of a missing pixel.
    const NO_DATA: Self;
366
}
367

368
// Floating point types use NaN for missing pixels.
impl NoDataValue for f32 {
369
    const NO_DATA: Self = f32::NAN;
370
}
371

372
impl NoDataValue for f64 {
373
    const NO_DATA: Self = f64::NAN;
374
}
375

376
// Define a macro to implement `NoDataValue` for a comma-separated list of types with `NO_DATA` as 0
377
macro_rules! impl_no_data_value_zero {
378
    ($($t:ty),*) => {
379
        $(
380
            impl NoDataValue for $t {
381
                const NO_DATA: Self = 0;
382
            }
383
        )*
384
    };
385
}
386

387
// Use the macro to implement `NoDataValue` for all integer types from 8 to 64 bit.
388
impl_no_data_value_zero!(i8, u8, i16, u16, i32, u32, i64, u64);
389

390
#[cfg(test)]
391
mod tests {
392
    use crate::{
393
        engine::{
394
            MockExecutionContext, MockQueryContext, MultipleRasterSources, RasterBandDescriptors,
395
        },
396
        machine_learning::metadata_from_file::load_model_metadata,
397
        mock::{MockRasterSource, MockRasterSourceParams},
398
        processing::{RasterStacker, RasterStackerParams},
399
    };
400
    use approx::assert_abs_diff_eq;
401
    use geoengine_datatypes::{
402
        primitives::{CacheHint, SpatialPartition2D, SpatialResolution, TimeInterval},
403
        raster::{
404
            GridOrEmpty, GridShape, RasterDataType, RenameBands, TilesEqualIgnoringCacheHint,
405
        },
406
        spatial_reference::SpatialReference,
407
        test_data,
408
        util::test::TestDefault,
409
    };
410
    use ndarray::{arr2, array, Array1, Array2};
411

412
    use super::*;
413

414
    /// Runs the classification test model directly via `ort` on a fixed 4x2 sample array.
    #[test]
    fn ort() {
        let builder = ort::session::Session::builder().unwrap();
        let session = builder
            .commit_from_file(test_data!("ml/onnx/test_classification.onnx"))
            .unwrap();

        let input_name = &session.inputs[0].name;

        // four samples with two features each
        let samples = arr2(&[[0.1f32, 0.2], [0.2, 0.3], [0.2, 0.2], [0.3, 0.1]]);
        let model_inputs = ort::inputs![input_name => samples].unwrap();

        let outputs = session.run(model_inputs).unwrap();

        let labels = outputs["output_label"]
            .try_extract_tensor::<i64>()
            .unwrap();
        let predictions = labels.into_owned().into_dimensionality().unwrap();

        assert_eq!(predictions, &array![33i64, 33, 42, 42]);
    }
1✔
438

439
    /// Runs the classification test model on samples that are built from a flat `Vec` at runtime.
    #[test]
    fn ort_dynamic() {
        let builder = ort::session::Session::builder().unwrap();
        let session = builder
            .commit_from_file(test_data!("ml/onnx/test_classification.onnx"))
            .unwrap();

        let input_name = &session.inputs[0].name;

        // four samples with two features each, in row-major order
        let pixels: Vec<f32> = vec![0.1f32, 0.2, 0.2, 0.3, 0.2, 0.2, 0.3, 0.1];
        let (rows, cols) = (4, 2);

        let samples = Array2::from_shape_vec((rows, cols), pixels).unwrap();
        let model_inputs = ort::inputs![input_name => samples].unwrap();

        let outputs = session.run(model_inputs).unwrap();

        let labels = outputs["output_label"]
            .try_extract_tensor::<i64>()
            .unwrap();
        let predictions = labels.into_owned().into_dimensionality().unwrap();

        assert_eq!(predictions, &array![33i64, 33, 42, 42]);
    }
1✔
476

477
    /// Runs the regression test model directly via `ort` and compares the predictions
    /// with an absolute tolerance of `1e-6`.
    #[test]
    fn regression() {
        let builder = ort::session::Session::builder().unwrap();
        let session = builder
            .commit_from_file(test_data!("ml/onnx/test_regression.onnx"))
            .unwrap();

        let input_name = &session.inputs[0].name;

        // four samples with three features each, in row-major order
        let pixels: Vec<f32> = vec![
            0.1f32, 0.1, 0.2, //
            0.1, 0.2, 0.2, //
            0.2, 0.2, 0.2, //
            0.2, 0.2, 0.1,
        ];
        let (rows, cols) = (4, 3);

        let samples = Array2::from_shape_vec((rows, cols), pixels).unwrap();
        let model_inputs = ort::inputs![input_name => samples].unwrap();

        let outputs = session.run(model_inputs).unwrap();

        let tensor = outputs["variable"].try_extract_tensor::<f32>().unwrap();
        let owned = tensor.to_owned();
        let predictions: Array1<f32> = owned.to_shape((4,)).unwrap().to_owned();

        assert!(predictions.abs_diff_eq(&array![0.4f32, 0.5, 0.6, 0.5], 1e-6));
    }
1✔
515

516
    // TODO: add test using neural network model
517

518
    #[tokio::test]
519
    #[allow(clippy::too_many_lines)]
520
    async fn it_classifies_tiles() {
1✔
521
        let data: Vec<RasterTile2D<f32>> = vec![
1✔
522
            RasterTile2D {
1✔
523
                time: TimeInterval::new_unchecked(0, 5),
1✔
524
                tile_position: [-1, 0].into(),
1✔
525
                band: 0,
1✔
526
                global_geo_transform: TestDefault::test_default(),
1✔
527
                grid_array: Grid::new([2, 2].into(), vec![0.1f32, 0.1, 0.2, 0.2])
1✔
528
                    .unwrap()
1✔
529
                    .into(),
1✔
530
                properties: Default::default(),
1✔
531
                cache_hint: CacheHint::default(),
1✔
532
            },
1✔
533
            RasterTile2D {
1✔
534
                time: TimeInterval::new_unchecked(0, 5),
1✔
535
                tile_position: [-1, 1].into(),
1✔
536
                band: 0,
1✔
537
                global_geo_transform: TestDefault::test_default(),
1✔
538
                grid_array: Grid::new([2, 2].into(), vec![0.2f32, 0.2, 0.1, 0.1])
1✔
539
                    .unwrap()
1✔
540
                    .into(),
1✔
541
                properties: Default::default(),
1✔
542
                cache_hint: CacheHint::default(),
1✔
543
            },
1✔
544
        ];
1✔
545

1✔
546
        let data2: Vec<RasterTile2D<f32>> = vec![
1✔
547
            RasterTile2D {
1✔
548
                time: TimeInterval::new_unchecked(0, 5),
1✔
549
                tile_position: [-1, 0].into(),
1✔
550
                band: 0,
1✔
551
                global_geo_transform: TestDefault::test_default(),
1✔
552
                grid_array: Grid::new([2, 2].into(), vec![0.2f32, 0.2, 0.1, 0.1])
1✔
553
                    .unwrap()
1✔
554
                    .into(),
1✔
555
                properties: Default::default(),
1✔
556
                cache_hint: CacheHint::default(),
1✔
557
            },
1✔
558
            RasterTile2D {
1✔
559
                time: TimeInterval::new_unchecked(0, 5),
1✔
560
                tile_position: [-1, 1].into(),
1✔
561
                band: 0,
1✔
562
                global_geo_transform: TestDefault::test_default(),
1✔
563
                grid_array: Grid::new([2, 2].into(), vec![0.1f32, 0.1, 0.2, 0.2])
1✔
564
                    .unwrap()
1✔
565
                    .into(),
1✔
566
                properties: Default::default(),
1✔
567
                cache_hint: CacheHint::default(),
1✔
568
            },
1✔
569
        ];
1✔
570

1✔
571
        let mrs1 = MockRasterSource {
1✔
572
            params: MockRasterSourceParams {
1✔
573
                data: data.clone(),
1✔
574
                result_descriptor: RasterResultDescriptor {
1✔
575
                    data_type: RasterDataType::F32,
1✔
576
                    spatial_reference: SpatialReference::epsg_4326().into(),
1✔
577
                    time: None,
1✔
578
                    bbox: None,
1✔
579
                    resolution: None,
1✔
580
                    bands: RasterBandDescriptors::new_single_band(),
1✔
581
                },
1✔
582
            },
1✔
583
        }
1✔
584
        .boxed();
1✔
585

1✔
586
        let mrs2 = MockRasterSource {
1✔
587
            params: MockRasterSourceParams {
1✔
588
                data: data2.clone(),
1✔
589
                result_descriptor: RasterResultDescriptor {
1✔
590
                    data_type: RasterDataType::F32,
1✔
591
                    spatial_reference: SpatialReference::epsg_4326().into(),
1✔
592
                    time: None,
1✔
593
                    bbox: None,
1✔
594
                    resolution: None,
1✔
595
                    bands: RasterBandDescriptors::new_single_band(),
1✔
596
                },
1✔
597
            },
1✔
598
        }
1✔
599
        .boxed();
1✔
600

1✔
601
        let stacker = RasterStacker {
1✔
602
            params: RasterStackerParams {
1✔
603
                rename_bands: RenameBands::Default,
1✔
604
            },
1✔
605
            sources: MultipleRasterSources {
1✔
606
                rasters: vec![mrs1, mrs2],
1✔
607
            },
1✔
608
        }
1✔
609
        .boxed();
1✔
610

1✔
611
        // load a very simple model that checks whether the first band is greater than the second band
1✔
612
        let model_name = MlModelName {
1✔
613
            namespace: None,
1✔
614
            name: "test_classification".into(),
1✔
615
        };
1✔
616

1✔
617
        let onnx = Onnx {
1✔
618
            params: OnnxParams {
1✔
619
                model: model_name.clone(),
1✔
620
            },
1✔
621
            sources: SingleRasterSource { raster: stacker },
1✔
622
        }
1✔
623
        .boxed();
1✔
624

1✔
625
        let mut exe_ctx = MockExecutionContext::test_default();
1✔
626
        exe_ctx.tiling_specification.tile_size_in_pixels = GridShape {
1✔
627
            shape_array: [2, 2],
1✔
628
        };
1✔
629
        exe_ctx.ml_models.insert(
1✔
630
            model_name,
1✔
631
            load_model_metadata(test_data!("ml/onnx/test_classification.onnx")).unwrap(),
1✔
632
        );
1✔
633

1✔
634
        let query_rect = RasterQueryRectangle {
1✔
635
            spatial_bounds: SpatialPartition2D::new_unchecked((0., 1.).into(), (3., 0.).into()),
1✔
636
            time_interval: TimeInterval::new_unchecked(0, 5),
1✔
637
            spatial_resolution: SpatialResolution::one(),
1✔
638
            attributes: [0].try_into().unwrap(),
1✔
639
        };
1✔
640

1✔
641
        let query_ctx = MockQueryContext::test_default();
1✔
642

1✔
643
        let op = onnx
1✔
644
            .initialize(WorkflowOperatorPath::initialize_root(), &exe_ctx)
1✔
645
            .await
1✔
646
            .unwrap();
1✔
647

1✔
648
        let qp = op.query_processor().unwrap().get_i64().unwrap();
1✔
649

1✔
650
        let result = qp
1✔
651
            .raster_query(query_rect, &query_ctx)
1✔
652
            .await
1✔
653
            .unwrap()
1✔
654
            .collect::<Vec<_>>()
1✔
655
            .await;
1✔
656
        let result = result.into_iter().collect::<Result<Vec<_>>>().unwrap();
1✔
657

1✔
658
        let expected: Vec<RasterTile2D<i64>> = vec![
1✔
659
            RasterTile2D {
1✔
660
                time: TimeInterval::new_unchecked(0, 5),
1✔
661
                tile_position: [-1, 0].into(),
1✔
662
                band: 0,
1✔
663
                global_geo_transform: TestDefault::test_default(),
1✔
664
                grid_array: Grid::new([2, 2].into(), vec![33i64, 33, 42, 42])
1✔
665
                    .unwrap()
1✔
666
                    .into(),
1✔
667
                properties: Default::default(),
1✔
668
                cache_hint: CacheHint::default(),
1✔
669
            },
1✔
670
            RasterTile2D {
1✔
671
                time: TimeInterval::new_unchecked(0, 5),
1✔
672
                tile_position: [-1, 1].into(),
1✔
673
                band: 0,
1✔
674
                global_geo_transform: TestDefault::test_default(),
1✔
675
                grid_array: Grid::new([2, 2].into(), vec![42i64, 42, 33, 33])
1✔
676
                    .unwrap()
1✔
677
                    .into(),
1✔
678
                properties: Default::default(),
1✔
679
                cache_hint: CacheHint::default(),
1✔
680
            },
1✔
681
        ];
1✔
682

1✔
683
        assert!(expected.tiles_equal_ignoring_cache_hint(&result));
1✔
684
    }
1✔
685

686
    #[tokio::test]
687
    #[allow(clippy::too_many_lines)]
688
    async fn it_regresses_tiles() {
1✔
689
        let data: Vec<RasterTile2D<f32>> = vec![
1✔
690
            RasterTile2D {
1✔
691
                time: TimeInterval::new_unchecked(0, 5),
1✔
692
                tile_position: [-1, 0].into(),
1✔
693
                band: 0,
1✔
694
                global_geo_transform: TestDefault::test_default(),
1✔
695
                grid_array: Grid::new([2, 2].into(), vec![0.1f32, 0.2, 0.3, 0.4])
1✔
696
                    .unwrap()
1✔
697
                    .into(),
1✔
698
                properties: Default::default(),
1✔
699
                cache_hint: CacheHint::default(),
1✔
700
            },
1✔
701
            RasterTile2D {
1✔
702
                time: TimeInterval::new_unchecked(0, 5),
1✔
703
                tile_position: [-1, 1].into(),
1✔
704
                band: 0,
1✔
705
                global_geo_transform: TestDefault::test_default(),
1✔
706
                grid_array: Grid::new([2, 2].into(), vec![0.5f32, 0.6, 0.7, 0.8])
1✔
707
                    .unwrap()
1✔
708
                    .into(),
1✔
709
                properties: Default::default(),
1✔
710
                cache_hint: CacheHint::default(),
1✔
711
            },
1✔
712
        ];
1✔
713

1✔
714
        let data2: Vec<RasterTile2D<f32>> = vec![
1✔
715
            RasterTile2D {
1✔
716
                time: TimeInterval::new_unchecked(0, 5),
1✔
717
                tile_position: [-1, 0].into(),
1✔
718
                band: 0,
1✔
719
                global_geo_transform: TestDefault::test_default(),
1✔
720
                grid_array: Grid::new([2, 2].into(), vec![0.9f32, 0.8, 0.7, 0.6])
1✔
721
                    .unwrap()
1✔
722
                    .into(),
1✔
723
                properties: Default::default(),
1✔
724
                cache_hint: CacheHint::default(),
1✔
725
            },
1✔
726
            RasterTile2D {
1✔
727
                time: TimeInterval::new_unchecked(0, 5),
1✔
728
                tile_position: [-1, 1].into(),
1✔
729
                band: 0,
1✔
730
                global_geo_transform: TestDefault::test_default(),
1✔
731
                grid_array: Grid::new([2, 2].into(), vec![0.5f32, 0.4, 0.3, 0.22])
1✔
732
                    .unwrap()
1✔
733
                    .into(),
1✔
734
                properties: Default::default(),
1✔
735
                cache_hint: CacheHint::default(),
1✔
736
            },
1✔
737
        ];
1✔
738

1✔
739
        let data3: Vec<RasterTile2D<f32>> = vec![
1✔
740
            RasterTile2D {
1✔
741
                time: TimeInterval::new_unchecked(0, 5),
1✔
742
                tile_position: [-1, 0].into(),
1✔
743
                band: 0,
1✔
744
                global_geo_transform: TestDefault::test_default(),
1✔
745
                grid_array: Grid::new([2, 2].into(), vec![0.1f32, 0.2, 0.3, 0.4])
1✔
746
                    .unwrap()
1✔
747
                    .into(),
1✔
748
                properties: Default::default(),
1✔
749
                cache_hint: CacheHint::default(),
1✔
750
            },
1✔
751
            RasterTile2D {
1✔
752
                time: TimeInterval::new_unchecked(0, 5),
1✔
753
                tile_position: [-1, 1].into(),
1✔
754
                band: 0,
1✔
755
                global_geo_transform: TestDefault::test_default(),
1✔
756
                grid_array: Grid::new([2, 2].into(), vec![0.5f32, 0.6, 0.7, 0.8])
1✔
757
                    .unwrap()
1✔
758
                    .into(),
1✔
759
                properties: Default::default(),
1✔
760
                cache_hint: CacheHint::default(),
1✔
761
            },
1✔
762
        ];
1✔
763

1✔
764
        let mrs1 = MockRasterSource {
1✔
765
            params: MockRasterSourceParams {
1✔
766
                data: data.clone(),
1✔
767
                result_descriptor: RasterResultDescriptor {
1✔
768
                    data_type: RasterDataType::F32,
1✔
769
                    spatial_reference: SpatialReference::epsg_4326().into(),
1✔
770
                    time: None,
1✔
771
                    bbox: None,
1✔
772
                    resolution: None,
1✔
773
                    bands: RasterBandDescriptors::new_single_band(),
1✔
774
                },
1✔
775
            },
1✔
776
        }
1✔
777
        .boxed();
1✔
778

1✔
779
        let mrs2 = MockRasterSource {
1✔
780
            params: MockRasterSourceParams {
1✔
781
                data: data2.clone(),
1✔
782
                result_descriptor: RasterResultDescriptor {
1✔
783
                    data_type: RasterDataType::F32,
1✔
784
                    spatial_reference: SpatialReference::epsg_4326().into(),
1✔
785
                    time: None,
1✔
786
                    bbox: None,
1✔
787
                    resolution: None,
1✔
788
                    bands: RasterBandDescriptors::new_single_band(),
1✔
789
                },
1✔
790
            },
1✔
791
        }
1✔
792
        .boxed();
1✔
793

1✔
794
        let mrs3 = MockRasterSource {
1✔
795
            params: MockRasterSourceParams {
1✔
796
                data: data3.clone(),
1✔
797
                result_descriptor: RasterResultDescriptor {
1✔
798
                    data_type: RasterDataType::F32,
1✔
799
                    spatial_reference: SpatialReference::epsg_4326().into(),
1✔
800
                    time: None,
1✔
801
                    bbox: None,
1✔
802
                    resolution: None,
1✔
803
                    bands: RasterBandDescriptors::new_single_band(),
1✔
804
                },
1✔
805
            },
1✔
806
        }
1✔
807
        .boxed();
1✔
808

1✔
809
        let stacker = RasterStacker {
1✔
810
            params: RasterStackerParams {
1✔
811
                rename_bands: RenameBands::Default,
1✔
812
            },
1✔
813
            sources: MultipleRasterSources {
1✔
814
                rasters: vec![mrs1, mrs2, mrs3],
1✔
815
            },
1✔
816
        }
1✔
817
        .boxed();
1✔
818

1✔
819
        // load a very simple model that performs regression to predict the sum of the three bands
1✔
820
        let model_name = MlModelName {
1✔
821
            namespace: None,
1✔
822
            name: "test_regression".into(),
1✔
823
        };
1✔
824

1✔
825
        let onnx = Onnx {
1✔
826
            params: OnnxParams {
1✔
827
                model: model_name.clone(),
1✔
828
            },
1✔
829
            sources: SingleRasterSource { raster: stacker },
1✔
830
        }
1✔
831
        .boxed();
1✔
832

1✔
833
        let mut exe_ctx = MockExecutionContext::test_default();
1✔
834
        exe_ctx.tiling_specification.tile_size_in_pixels = GridShape {
1✔
835
            shape_array: [2, 2],
1✔
836
        };
1✔
837
        exe_ctx.ml_models.insert(
1✔
838
            model_name,
1✔
839
            load_model_metadata(test_data!("ml/onnx/test_regression.onnx")).unwrap(),
1✔
840
        );
1✔
841

1✔
842
        let query_rect = RasterQueryRectangle {
1✔
843
            spatial_bounds: SpatialPartition2D::new_unchecked((0., 1.).into(), (3., 0.).into()),
1✔
844
            time_interval: TimeInterval::new_unchecked(0, 5),
1✔
845
            spatial_resolution: SpatialResolution::one(),
1✔
846
            attributes: [0].try_into().unwrap(),
1✔
847
        };
1✔
848

1✔
849
        let query_ctx = MockQueryContext::test_default();
1✔
850

1✔
851
        let op = onnx
1✔
852
            .initialize(WorkflowOperatorPath::initialize_root(), &exe_ctx)
1✔
853
            .await
1✔
854
            .unwrap();
1✔
855

1✔
856
        let qp = op.query_processor().unwrap().get_f32().unwrap();
1✔
857

1✔
858
        let result = qp
1✔
859
            .raster_query(query_rect, &query_ctx)
1✔
860
            .await
1✔
861
            .unwrap()
1✔
862
            .collect::<Vec<_>>()
1✔
863
            .await;
1✔
864
        let result = result.into_iter().collect::<Result<Vec<_>>>().unwrap();
1✔
865

1✔
866
        assert_eq!(result.len(), 2);
1✔
867

1✔
868
        let expected = vec![vec![1.1f32, 1.2, 1.3, 1.4], vec![1.5f32, 1.6, 1.7, 1.8]];
1✔
869

1✔
870
        for (tile, expected) in result.iter().zip(expected) {
2✔
871
            let GridOrEmpty::Grid(result_array) = &tile.grid_array else {
2✔
872
                panic!("no result array")
1✔
873
            };
1✔
874

1✔
875
            assert_abs_diff_eq!(
2✔
876
                result_array.inner_grid.data.as_slice(),
2✔
877
                expected.as_slice(),
2✔
878
                epsilon = 0.1
2✔
879
            );
2✔
880
        }
1✔
881
    }
1✔
882
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc