• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

geo-engine / geoengine / 14891926777

07 May 2025 07:44PM UTC coverage: 89.975%. First build
14891926777

Pull #1000

github

web-flow
Merge 99d283bb1 into 4d3e935a3
Pull Request #1000: add tensor shape to ml model

109 of 194 new or added lines in 10 files covered. (56.19%)

126812 of 140942 relevant lines covered (89.97%)

57108.76 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

6.06
/services/src/machine_learning/mod.rs
1
use crate::{
2
    api::model::datatypes::{MlTensorShape3D, RasterDataType},
3
    config::{MachineLearning, get_config_element},
4
    datasets::upload::{UploadId, UploadRootPath},
5
    identifier,
6
    util::path_with_base_path,
7
};
8
use async_trait::async_trait;
9
use error::{MachineLearningError, error::CouldNotFindMlModelFileMachineLearningError};
10
use name::MlModelName;
11
use postgres_types::{FromSql, ToSql};
12
use serde::{Deserialize, Serialize};
13
use snafu::ResultExt;
14
use std::borrow::Cow;
15
use utoipa::{IntoParams, ToSchema};
16
use validator::{Validate, ValidationError};
17

18
pub mod error;
19
pub mod name;
20
mod postgres;
21

22
// Declares the `MlModelId` identifier type via the crate-level `identifier!` macro.
// NOTE(review): the generated representation and trait impls live in the macro
// definition (not visible here) — consult it before relying on specifics.
identifier!(MlModelId);
23

24
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema, FromSql, ToSql)]
×
25
#[serde(rename_all = "camelCase")]
26
pub struct MlModelIdAndName {
27
    pub id: MlModelId,
28
    pub name: MlModelName,
29
}
30

31
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema, FromSql, ToSql)]
75✔
32
#[serde(rename_all = "camelCase")]
33
pub struct MlModel {
34
    pub name: MlModelName,
35
    pub display_name: String,
36
    pub description: String,
37
    pub upload: UploadId,
38
    pub metadata: MlModelMetadata,
39
}
40

41
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema, FromSql, ToSql)]
95✔
42
#[serde(rename_all = "camelCase")]
43
pub struct MlModelMetadata {
44
    pub file_name: String,
45
    pub input_type: RasterDataType,
46
    pub output_type: RasterDataType,
47
    pub input_shape: MlTensorShape3D,
48
    pub output_shape: MlTensorShape3D,
49
    // TODO: output measurement, e.g. classification or regression, label names for classification. This would have to be provided by the model creator along the model file as it cannot be extracted from the model file(?)
50
}
51

52
impl MlModel {
53
    pub fn metadata_for_operator(
×
54
        &self,
×
55
    ) -> Result<geoengine_datatypes::machine_learning::MlModelMetadata, MachineLearningError> {
×
56
        Ok(geoengine_datatypes::machine_learning::MlModelMetadata {
×
57
            file_path: path_with_base_path(
×
58
                &self
×
59
                    .upload
×
60
                    .root_path()
×
61
                    .context(CouldNotFindMlModelFileMachineLearningError)?,
×
62
                self.metadata.file_name.as_ref(),
×
63
            )
×
64
            .context(CouldNotFindMlModelFileMachineLearningError)?,
×
65
            input_type: self.metadata.input_type.into(),
×
66
            output_type: self.metadata.output_type.into(),
×
NEW
67
            input_shape: self.metadata.input_shape.into(),
×
NEW
68
            output_shape: self.metadata.output_shape.into(),
×
69
        })
70
    }
×
71
}
72

73
#[derive(Debug, Deserialize, Serialize, Clone, IntoParams, Validate)]
×
74
#[into_params(parameter_in = Query)]
75
pub struct MlModelListOptions {
76
    #[param(example = 0)]
77
    pub offset: u32,
78
    #[param(example = 2)]
79
    #[validate(custom(function = "validate_list_limit"))]
80
    pub limit: u32,
81
}
82

83
fn validate_list_limit(value: u32) -> Result<(), ValidationError> {
×
84
    let limit = get_config_element::<MachineLearning>()
×
85
        .expect("should exist because it is defined in the default config")
×
86
        .list_limit;
×
87
    if value <= limit {
×
88
        return Ok(());
×
89
    }
×
90

×
91
    let mut err = ValidationError::new("limit (too large)");
×
92
    err.add_param::<u32>(Cow::Borrowed("max limit"), &limit);
×
93
    Err(err)
×
94
}
×
95

96
#[async_trait]
97
pub trait MlModelDb {
98
    async fn list_models(
99
        &self,
100
        options: &MlModelListOptions,
101
    ) -> Result<Vec<MlModel>, MachineLearningError>;
102

103
    async fn load_model(&self, name: &MlModelName) -> Result<MlModel, MachineLearningError>;
104

105
    async fn load_model_metadata(
106
        &self,
107
        name: &MlModelName,
108
    ) -> Result<MlModelMetadata, MachineLearningError>;
109

110
    async fn add_model(&self, model: MlModel) -> Result<MlModelIdAndName, MachineLearningError>;
111

112
    async fn resolve_model_name_to_id(
113
        &self,
114
        name: &MlModelName,
115
    ) -> Result<Option<MlModelId>, MachineLearningError>;
116
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc