• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine / 10699374903

04 Sep 2024 09:11AM UTC coverage: 91.006% (-0.1%) from 91.122%
10699374903

push

github

web-flow
Merge pull request #977 from geo-engine/model_db

db for ml models (wip)

460 of 843 new or added lines in 25 files covered. (54.57%)

16 existing lines in 10 files now uncovered.

133668 of 146878 relevant lines covered (91.01%)

52516.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.62
/operators/src/engine/execution_context.rs
1
use super::query::QueryAbortRegistration;
2
use super::{
3
    CreateSpan, InitializedPlotOperator, InitializedRasterOperator, InitializedVectorOperator,
4
    MockQueryContext, WorkflowOperatorPath,
5
};
6
use crate::engine::{
7
    ChunkByteSize, RasterResultDescriptor, ResultDescriptor, VectorResultDescriptor,
8
};
9
use crate::error::Error;
10
use crate::mock::MockDatasetDataSourceLoadingInfo;
11
use crate::source::{GdalLoadingInfo, OgrSourceDataset};
12
use crate::util::{create_rayon_thread_pool, Result};
13
use async_trait::async_trait;
14
use core::any::TypeId;
15
use geoengine_datatypes::dataset::{DataId, NamedData};
16
use geoengine_datatypes::machine_learning::{MlModelMetadata, MlModelName};
17
use geoengine_datatypes::primitives::{RasterQueryRectangle, VectorQueryRectangle};
18
use geoengine_datatypes::raster::TilingSpecification;
19
use geoengine_datatypes::util::test::TestDefault;
20
use rayon::ThreadPool;
21
use serde::{Deserialize, Serialize};
22
use std::any::Any;
23
use std::collections::HashMap;
24
use std::fmt::Debug;
25
use std::marker::PhantomData;
26
use std::sync::Arc;
27

28
/// A context that provides certain utility access during operator initialization
29
#[async_trait::async_trait]
30
pub trait ExecutionContext: Send
31
    + Sync
32
    + MetaDataProvider<MockDatasetDataSourceLoadingInfo, VectorResultDescriptor, VectorQueryRectangle>
33
    + MetaDataProvider<OgrSourceDataset, VectorResultDescriptor, VectorQueryRectangle>
34
    + MetaDataProvider<GdalLoadingInfo, RasterResultDescriptor, RasterQueryRectangle>
35
{
36
    fn thread_pool(&self) -> &Arc<ThreadPool>;
37
    fn tiling_specification(&self) -> TilingSpecification;
38

39
    fn wrap_initialized_raster_operator(
40
        &self,
41
        op: Box<dyn InitializedRasterOperator>,
42
        span: CreateSpan,
43
        path: WorkflowOperatorPath, // TODO: remove and allow operators to tell its path
44
    ) -> Box<dyn InitializedRasterOperator>;
45

46
    fn wrap_initialized_vector_operator(
47
        &self,
48
        op: Box<dyn InitializedVectorOperator>,
49
        span: CreateSpan,
50
        path: WorkflowOperatorPath,
51
    ) -> Box<dyn InitializedVectorOperator>;
52

53
    fn wrap_initialized_plot_operator(
54
        &self,
55
        op: Box<dyn InitializedPlotOperator>,
56
        span: CreateSpan,
57
        path: WorkflowOperatorPath,
58
    ) -> Box<dyn InitializedPlotOperator>;
59

60
    async fn resolve_named_data(&self, data: &NamedData) -> Result<DataId>;
61

62
    async fn ml_model_metadata(&self, name: &MlModelName) -> Result<MlModelMetadata>;
63

64
    /// get the `ExecutionContextExtensions` that contain additional information
65
    fn extensions(&self) -> &ExecutionContextExtensions;
66
}
67

68
/// A type map for attaching extra information to an `ExecutionContext`.
///
/// At most one value is kept per Rust type; a value is retrieved again by naming its type.
#[derive(Default)]
pub struct ExecutionContextExtensions {
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl ExecutionContextExtensions {
    /// Stores `val` under its type `T`.
    ///
    /// Returns the value of type `T` that was stored before, if there was one.
    pub fn insert<T: 'static + Send + Sync>(&mut self, val: T) -> Option<T> {
        let key = TypeId::of::<T>();
        let previous = self.map.insert(key, Box::new(val))?;
        downcast_owned(previous)
    }

    /// Borrows the stored value of type `T`, or returns `None` if none was inserted.
    pub fn get<T: 'static + Send + Sync>(&self) -> Option<&T> {
        let entry = self.map.get(&TypeId::of::<T>())?;
        entry.downcast_ref::<T>()
    }
}

/// Moves the value out of a type-erased box, or returns `None` if it is not a `T`.
fn downcast_owned<T: 'static + Send + Sync>(boxed: Box<dyn Any + Send + Sync>) -> Option<T> {
    match boxed.downcast::<T>() {
        Ok(concrete) => Some(*concrete),
        Err(_) => None,
    }
}
92

93
#[async_trait]
94
pub trait MetaDataProvider<L, R, Q>
95
where
96
    R: ResultDescriptor,
97
{
98
    async fn meta_data(&self, id: &DataId) -> Result<Box<dyn MetaData<L, R, Q>>>;
99
}
100

101
#[async_trait]
102
pub trait MetaData<L, R, Q>: Debug + Send + Sync
103
where
104
    R: ResultDescriptor,
105
{
106
    async fn loading_info(&self, query: Q) -> Result<L>;
107
    async fn result_descriptor(&self) -> Result<R>;
108

109
    fn box_clone(&self) -> Box<dyn MetaData<L, R, Q>>;
110
}
111

112
impl<L, R, Q> Clone for Box<dyn MetaData<L, R, Q>>
113
where
114
    R: ResultDescriptor,
115
{
116
    fn clone(&self) -> Box<dyn MetaData<L, R, Q>> {
125✔
117
        self.box_clone()
125✔
118
    }
125✔
119
}
120

121
pub struct MockExecutionContext {
122
    pub thread_pool: Arc<ThreadPool>,
123
    pub meta_data: HashMap<DataId, Box<dyn Any + Send + Sync>>,
124
    pub named_data: HashMap<NamedData, DataId>,
125
    pub ml_models: HashMap<MlModelName, MlModelMetadata>,
126
    pub tiling_specification: TilingSpecification,
127
    pub extensions: ExecutionContextExtensions,
128
}
129

130
impl TestDefault for MockExecutionContext {
131
    fn test_default() -> Self {
144✔
132
        Self {
144✔
133
            thread_pool: create_rayon_thread_pool(0),
144✔
134
            meta_data: HashMap::default(),
144✔
135
            named_data: HashMap::default(),
144✔
136
            ml_models: HashMap::default(),
144✔
137
            tiling_specification: TilingSpecification::test_default(),
144✔
138
            extensions: Default::default(),
144✔
139
        }
144✔
140
    }
144✔
141
}
142

143
impl MockExecutionContext {
144
    pub fn new_with_tiling_spec(tiling_specification: TilingSpecification) -> Self {
131✔
145
        Self {
131✔
146
            thread_pool: create_rayon_thread_pool(0),
131✔
147
            meta_data: HashMap::default(),
131✔
148
            named_data: HashMap::default(),
131✔
149
            ml_models: HashMap::default(),
131✔
150
            tiling_specification,
131✔
151
            extensions: Default::default(),
131✔
152
        }
131✔
153
    }
131✔
154

155
    pub fn new_with_tiling_spec_and_thread_count(
×
156
        tiling_specification: TilingSpecification,
×
157
        num_threads: usize,
×
158
    ) -> Self {
×
159
        Self {
×
160
            thread_pool: create_rayon_thread_pool(num_threads),
×
161
            meta_data: HashMap::default(),
×
162
            named_data: HashMap::default(),
×
NEW
163
            ml_models: HashMap::default(),
×
164
            tiling_specification,
×
165
            extensions: Default::default(),
×
166
        }
×
167
    }
×
168

169
    pub fn add_meta_data<L, R, Q>(
49✔
170
        &mut self,
49✔
171
        data: DataId,
49✔
172
        named_data: NamedData,
49✔
173
        meta_data: Box<dyn MetaData<L, R, Q>>,
49✔
174
    ) where
49✔
175
        L: Send + Sync + 'static,
49✔
176
        R: Send + Sync + 'static + ResultDescriptor,
49✔
177
        Q: Send + Sync + 'static,
49✔
178
    {
49✔
179
        self.meta_data.insert(
49✔
180
            data.clone(),
49✔
181
            Box::new(meta_data) as Box<dyn Any + Send + Sync>,
49✔
182
        );
49✔
183

49✔
184
        self.named_data.insert(named_data, data);
49✔
185
    }
49✔
186

187
    pub fn delete_meta_data(&mut self, named_data: &NamedData) {
2✔
188
        let data = self.named_data.remove(named_data);
2✔
189
        if let Some(data) = data {
2✔
190
            self.meta_data.remove(&data);
2✔
191
        }
2✔
192
    }
2✔
193

194
    pub fn mock_query_context(&self, chunk_byte_size: ChunkByteSize) -> MockQueryContext {
3✔
195
        let (abort_registration, abort_trigger) = QueryAbortRegistration::new();
3✔
196
        MockQueryContext {
3✔
197
            chunk_byte_size,
3✔
198
            thread_pool: self.thread_pool.clone(),
3✔
199
            extensions: Default::default(),
3✔
200
            abort_registration,
3✔
201
            abort_trigger: Some(abort_trigger),
3✔
202
        }
3✔
203
    }
3✔
204
}
205

206
#[async_trait::async_trait]
207
impl ExecutionContext for MockExecutionContext {
208
    fn thread_pool(&self) -> &Arc<ThreadPool> {
×
209
        &self.thread_pool
×
210
    }
×
211

212
    fn tiling_specification(&self) -> TilingSpecification {
247✔
213
        self.tiling_specification
247✔
214
    }
247✔
215

216
    fn wrap_initialized_raster_operator(
320✔
217
        &self,
320✔
218
        op: Box<dyn InitializedRasterOperator>,
320✔
219
        _span: CreateSpan,
320✔
220
        _path: WorkflowOperatorPath,
320✔
221
    ) -> Box<dyn InitializedRasterOperator> {
320✔
222
        op
320✔
223
    }
320✔
224

225
    fn wrap_initialized_vector_operator(
168✔
226
        &self,
168✔
227
        op: Box<dyn InitializedVectorOperator>,
168✔
228
        _span: CreateSpan,
168✔
229
        _path: WorkflowOperatorPath,
168✔
230
    ) -> Box<dyn InitializedVectorOperator> {
168✔
231
        op
168✔
232
    }
168✔
233

234
    fn wrap_initialized_plot_operator(
53✔
235
        &self,
53✔
236
        op: Box<dyn InitializedPlotOperator>,
53✔
237
        _span: CreateSpan,
53✔
238
        _path: WorkflowOperatorPath,
53✔
239
    ) -> Box<dyn InitializedPlotOperator> {
53✔
240
        op
53✔
241
    }
53✔
242

243
    async fn resolve_named_data(&self, data: &NamedData) -> Result<DataId> {
53✔
244
        self.named_data
53✔
245
            .get(data)
53✔
246
            .cloned()
53✔
247
            .ok_or_else(|| Error::UnknownDatasetName { name: data.clone() })
53✔
248
    }
53✔
249

250
    async fn ml_model_metadata(&self, name: &MlModelName) -> Result<MlModelMetadata> {
2✔
251
        self.ml_models
2✔
252
            .get(name)
2✔
253
            .cloned()
2✔
254
            .ok_or_else(|| Error::UnknownMlModelName { name: name.clone() })
2✔
255
    }
2✔
256

257
    fn extensions(&self) -> &ExecutionContextExtensions {
×
258
        &self.extensions
×
259
    }
×
260
}
261

262
#[async_trait]
263
impl<L, R, Q> MetaDataProvider<L, R, Q> for MockExecutionContext
264
where
265
    L: 'static,
266
    R: 'static + ResultDescriptor,
267
    Q: 'static,
268
{
269
    async fn meta_data(&self, id: &DataId) -> Result<Box<dyn MetaData<L, R, Q>>> {
53✔
270
        let meta_data = self
53✔
271
            .meta_data
53✔
272
            .get(id)
53✔
273
            .ok_or(Error::UnknownDataId)?
53✔
274
            .downcast_ref::<Box<dyn MetaData<L, R, Q>>>()
53✔
275
            .ok_or(Error::InvalidMetaDataType)?;
53✔
276

53✔
277
        Ok(meta_data.clone())
53✔
278
    }
53✔
279
}
280

281
#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)]
3✔
282
#[serde(rename_all = "camelCase")]
283
pub struct StaticMetaData<L, R, Q>
284
where
285
    L: Debug + Clone + Send + Sync + 'static,
286
    R: Debug + Send + Sync + 'static + ResultDescriptor,
287
    Q: Debug + Clone + Send + Sync + 'static,
288
{
289
    pub loading_info: L,
290
    pub result_descriptor: R,
291
    #[serde(skip)]
292
    pub phantom: PhantomData<Q>,
293
}
294

295
#[async_trait]
296
impl<L, R, Q> MetaData<L, R, Q> for StaticMetaData<L, R, Q>
297
where
298
    L: Debug + Clone + Send + Sync + 'static,
299
    R: Debug + Send + Sync + 'static + ResultDescriptor,
300
    Q: Debug + Clone + Send + Sync + 'static,
301
{
302
    async fn loading_info(&self, _query: Q) -> Result<L> {
61✔
303
        Ok(self.loading_info.clone())
61✔
304
    }
61✔
305

306
    async fn result_descriptor(&self) -> Result<R> {
36✔
307
        Ok(self.result_descriptor.clone())
36✔
308
    }
36✔
309

310
    fn box_clone(&self) -> Box<dyn MetaData<L, R, Q>> {
49✔
311
        Box::new(self.clone())
49✔
312
    }
49✔
313
}
314

315
mod db_types {
316
    use geoengine_datatypes::delegate_from_to_sql;
317
    use postgres_types::{FromSql, ToSql};
318

319
    use super::*;
320

321
    pub type MockMetaData = StaticMetaData<
322
        MockDatasetDataSourceLoadingInfo,
323
        VectorResultDescriptor,
324
        VectorQueryRectangle,
325
    >;
326

327
    #[derive(Debug, ToSql, FromSql)]
485✔
328
    #[postgres(name = "MockMetaData")]
329
    pub struct MockMetaDataDbType {
330
        pub loading_info: MockDatasetDataSourceLoadingInfo,
331
        pub result_descriptor: VectorResultDescriptor,
332
    }
333

334
    impl From<&MockMetaData> for MockMetaDataDbType {
335
        fn from(other: &MockMetaData) -> Self {
2✔
336
            Self {
2✔
337
                loading_info: other.loading_info.clone(),
2✔
338
                result_descriptor: other.result_descriptor.clone(),
2✔
339
            }
2✔
340
        }
2✔
341
    }
342

343
    impl TryFrom<MockMetaDataDbType> for MockMetaData {
344
        type Error = Error;
345

346
        fn try_from(other: MockMetaDataDbType) -> Result<Self, Self::Error> {
2✔
347
            Ok(Self {
2✔
348
                loading_info: other.loading_info,
2✔
349
                result_descriptor: other.result_descriptor,
2✔
350
                phantom: PhantomData,
2✔
351
            })
2✔
352
        }
2✔
353
    }
354

355
    pub type OgrMetaData =
356
        StaticMetaData<OgrSourceDataset, VectorResultDescriptor, VectorQueryRectangle>;
357

358
    #[derive(Debug, ToSql, FromSql)]
497✔
359
    #[postgres(name = "OgrMetaData")]
360
    pub struct OgrMetaDataDbType {
361
        pub loading_info: OgrSourceDataset,
362
        pub result_descriptor: VectorResultDescriptor,
363
    }
364

365
    impl From<&StaticMetaData<OgrSourceDataset, VectorResultDescriptor, VectorQueryRectangle>>
366
        for OgrMetaDataDbType
367
    {
368
        fn from(other: &OgrMetaData) -> Self {
33✔
369
            Self {
33✔
370
                loading_info: other.loading_info.clone(),
33✔
371
                result_descriptor: other.result_descriptor.clone(),
33✔
372
            }
33✔
373
        }
33✔
374
    }
375

376
    impl TryFrom<OgrMetaDataDbType> for OgrMetaData {
377
        type Error = Error;
378

379
        fn try_from(other: OgrMetaDataDbType) -> Result<Self, Self::Error> {
14✔
380
            Ok(Self {
14✔
381
                loading_info: other.loading_info,
14✔
382
                result_descriptor: other.result_descriptor,
14✔
383
                phantom: PhantomData,
14✔
384
            })
14✔
385
        }
14✔
386
    }
387

388
    delegate_from_to_sql!(MockMetaData, MockMetaDataDbType);
389
    delegate_from_to_sql!(OgrMetaData, OgrMetaDataDbType);
390
}
391

392
#[cfg(test)]
mod tests {
    use super::*;
    use geoengine_datatypes::collections::VectorDataType;
    use geoengine_datatypes::spatial_reference::SpatialReferenceOption;

    /// An unreferenced `Data` descriptor without columns, time or bounding box.
    fn plain_descriptor() -> VectorResultDescriptor {
        VectorResultDescriptor {
            data_type: VectorDataType::Data,
            spatial_reference: SpatialReferenceOption::Unreferenced,
            columns: Default::default(),
            time: None,
            bbox: None,
        }
    }

    /// A `StaticMetaData` boxed as `dyn MetaData` and then erased to `dyn Any` (the way
    /// `MockExecutionContext::add_meta_data` stores it) can be downcast back and still
    /// reports its result descriptor.
    #[tokio::test]
    async fn boxed_static_meta_data_survives_any_round_trip() {
        let static_meta: StaticMetaData<i32, VectorResultDescriptor, VectorQueryRectangle> =
            StaticMetaData {
                loading_info: 1_i32,
                result_descriptor: plain_descriptor(),
                phantom: PhantomData,
            };

        let typed: Box<dyn MetaData<i32, VectorResultDescriptor, VectorQueryRectangle>> =
            Box::new(static_meta);
        let erased: Box<dyn Any + Send + Sync> = Box::new(typed);

        let restored = erased
            .downcast_ref::<Box<dyn MetaData<i32, VectorResultDescriptor, VectorQueryRectangle>>>()
            .unwrap();

        let descriptor = restored.result_descriptor().await.unwrap();
        assert_eq!(descriptor, plain_descriptor());
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc