• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

getdozer / dozer / 4310689762

pending completion
4310689762

push

github

GitHub
chore: Provide context to `ApiError`s (#1109)

34 of 34 new or added lines in 8 files covered. (100.0%)

28340 of 40095 relevant lines covered (70.68%)

66996.33 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.58
/dozer-api/src/generator/protoc/generator/implementation.rs
1
use crate::errors::GenerationError;
2
use crate::errors::GenerationError::ServiceNotFound;
3
use crate::generator::protoc::generator::{
4
    CountMethodDesc, DecimalDesc, EventDesc, OnEventMethodDesc, PointDesc, QueryMethodDesc,
5
    RecordWithIdDesc, TokenMethodDesc, TokenResponseDesc,
6
};
7
use dozer_types::log::error;
8
use dozer_types::models::api_security::ApiSecurity;
9
use dozer_types::models::flags::Flags;
10
use dozer_types::serde::{self, Deserialize, Serialize};
11
use dozer_types::types::{FieldType, Schema};
12
use handlebars::Handlebars;
13
use inflector::Inflector;
14
use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor};
15
use std::path::{Path, PathBuf};
16

17
use super::{CountResponseDesc, QueryResponseDesc, RecordDesc, ServiceDesc};
18

19
/// Full protobuf name of the point message type (looked up by `read`; expected to come from
/// `types.proto`).
const POINT_TYPE_CLASS: &str = "dozer.types.PointType";
/// Full protobuf name of the decimal message type (looked up by `read`; expected to come from
/// `types.proto`).
const DECIMAL_TYPE_CLASS: &str = "dozer.types.RustDecimal";
/// Protobuf well-known timestamp type; using it requires importing
/// `google/protobuf/timestamp.proto`.
const TIMESTAMP_TYPE_CLASS: &str = "google.protobuf.Timestamp";
22

23
#[derive(Debug, Clone, Serialize, Deserialize)]
23✔
24
#[serde(crate = "self::serde")]
25
struct ProtoMetadata {
26
    import_libs: Vec<String>,
27
    package_name: String,
28
    lower_name: String,
29
    plural_pascal_name: String,
30
    pascal_name: String,
31
    props: Vec<String>,
32
    version_field_id: usize,
33
    enable_token: bool,
34
    enable_on_event: bool,
35
}
36

37
pub struct ProtoGeneratorImpl<'a> {
38
    handlebars: Handlebars<'a>,
39
    schema: &'a dozer_types::types::Schema,
40
    names: Names,
41
    folder_path: &'a Path,
42
    security: &'a Option<ApiSecurity>,
43
    flags: &'a Option<Flags>,
44
}
45

46
impl<'a> ProtoGeneratorImpl<'a> {
47
    pub fn new(
23✔
48
        schema_name: &str,
23✔
49
        schema: &'a Schema,
23✔
50
        folder_path: &'a Path,
23✔
51
        security: &'a Option<ApiSecurity>,
23✔
52
        flags: &'a Option<Flags>,
23✔
53
    ) -> Result<Self, GenerationError> {
23✔
54
        let names = Names::new(schema_name, schema);
23✔
55
        let mut generator = Self {
23✔
56
            handlebars: Handlebars::new(),
23✔
57
            schema,
23✔
58
            names,
23✔
59
            folder_path,
23✔
60
            security,
23✔
61
            flags,
23✔
62
        };
23✔
63
        generator.register_template()?;
23✔
64
        Ok(generator)
23✔
65
    }
23✔
66

67
    fn register_template(&mut self) -> Result<(), GenerationError> {
23✔
68
        let main_template = include_str!("template/proto.tmpl");
23✔
69
        self.handlebars
23✔
70
            .register_template_string("main", main_template)
23✔
71
            .map_err(|e| GenerationError::HandlebarsTemplate(Box::new(e)))?;
23✔
72
        Ok(())
23✔
73
    }
23✔
74

75
    fn props(&self) -> Vec<String> {
23✔
76
        self.schema
23✔
77
            .fields
23✔
78
            .iter()
23✔
79
            .enumerate()
23✔
80
            .zip(&self.names.record_field_names)
23✔
81
            .map(|((idx, field), field_name)| -> String {
23✔
82
                let optional = if field.nullable { "optional " } else { "" };
75✔
83
                let proto_type = convert_dozer_type_to_proto_type(field.typ.to_owned()).unwrap();
75✔
84
                format!("{optional}{proto_type} {field_name} = {};", idx + 1)
75✔
85
            })
75✔
86
            .collect()
23✔
87
    }
23✔
88

89
    fn libs_by_type(&self) -> Result<Vec<String>, GenerationError> {
23✔
90
        let type_need_import_libs = [TIMESTAMP_TYPE_CLASS];
23✔
91
        let mut libs_import: Vec<String> = self
23✔
92
            .schema
23✔
93
            .fields
23✔
94
            .iter()
23✔
95
            .map(|field| convert_dozer_type_to_proto_type(field.to_owned().typ).unwrap())
75✔
96
            .filter(|proto_type| -> bool {
75✔
97
                type_need_import_libs.contains(&proto_type.to_owned().as_str())
75✔
98
            })
75✔
99
            .map(|proto_type| match proto_type.as_str() {
23✔
100
                TIMESTAMP_TYPE_CLASS => "google/protobuf/timestamp.proto".to_owned(),
3✔
101
                _ => "".to_owned(),
×
102
            })
23✔
103
            .collect();
23✔
104
        libs_import.push("types.proto".to_owned());
23✔
105
        libs_import.sort();
23✔
106
        libs_import.dedup();
23✔
107
        Ok(libs_import)
23✔
108
    }
23✔
109

110
    fn get_metadata(&self) -> Result<ProtoMetadata, GenerationError> {
23✔
111
        let import_libs: Vec<String> = self.libs_by_type()?;
23✔
112
        let metadata = ProtoMetadata {
23✔
113
            package_name: self.names.package_name.clone(),
23✔
114
            import_libs,
23✔
115
            lower_name: self.names.lower_name.clone(),
23✔
116
            plural_pascal_name: self.names.plural_pascal_name.clone(),
23✔
117
            pascal_name: self.names.pascal_name.clone(),
23✔
118
            props: self.props(),
23✔
119
            version_field_id: self.schema.fields.len() + 1,
23✔
120
            enable_token: self.security.is_some(),
23✔
121
            enable_on_event: self.flags.clone().unwrap_or_default().push_events,
23✔
122
        };
23✔
123
        Ok(metadata)
23✔
124
    }
23✔
125

126
    pub fn generate_proto(&self) -> Result<(String, PathBuf), GenerationError> {
23✔
127
        if !Path::new(&self.folder_path).exists() {
23✔
128
            return Err(GenerationError::DirPathNotExist(
×
129
                self.folder_path.to_path_buf(),
×
130
            ));
×
131
        }
23✔
132

133
        let metadata = self.get_metadata()?;
23✔
134

135
        let types_proto = include_str!("../../../../../dozer-types/protos/types.proto");
23✔
136

137
        let resource_proto = self.handlebars.render("main", &metadata)?;
23✔
138

×
139
        // Copy types proto file
×
140
        let types_path = self.folder_path.join("types.proto");
23✔
141
        std::fs::write(&types_path, types_proto)
23✔
142
            .map_err(|e| GenerationError::FailedToWriteToFile(types_path, e))?;
23✔
143

×
144
        let resource_path = self.folder_path.join(&self.names.proto_file_name);
23✔
145
        std::fs::write(&resource_path, &resource_proto)
23✔
146
            .map_err(|e| GenerationError::FailedToWriteToFile(resource_path.clone(), e))?;
23✔
147

×
148
        Ok((resource_proto, resource_path))
23✔
149
    }
23✔
150

×
151
    pub fn read(
16✔
152
        descriptor: &DescriptorPool,
16✔
153
        schema_name: &str,
16✔
154
    ) -> Result<ServiceDesc, GenerationError> {
16✔
155
        fn get_field(
274✔
156
            message: &MessageDescriptor,
274✔
157
            field_name: &str,
274✔
158
        ) -> Result<FieldDescriptor, GenerationError> {
274✔
159
            message
274✔
160
                .get_field_by_name(field_name)
274✔
161
                .ok_or_else(|| GenerationError::FieldNotFound {
274✔
162
                    message_name: message.name().to_string(),
×
163
                    field_name: field_name.to_string(),
×
164
                })
274✔
165
        }
274✔
166

16✔
167
        let record_desc_from_message =
16✔
168
            |message: MessageDescriptor| -> Result<RecordDesc, GenerationError> {
24✔
169
                let version_field = get_field(&message, "__dozer_record_version")?;
24✔
170

×
171
                if let Some(point_values) = descriptor.get_message_by_name(POINT_TYPE_CLASS) {
24✔
172
                    let pv = point_values;
24✔
173
                    if let Some(decimal_values) = descriptor.get_message_by_name(DECIMAL_TYPE_CLASS)
24✔
174
                    {
×
175
                        let dv = decimal_values;
24✔
176
                        Ok(RecordDesc {
24✔
177
                            message,
24✔
178
                            version_field,
24✔
179
                            point_field: PointDesc {
24✔
180
                                message: pv.clone(),
24✔
181
                                x: get_field(&pv, "x")?,
24✔
182
                                y: get_field(&pv, "y")?,
24✔
183
                            },
×
184
                            decimal_field: DecimalDesc {
×
185
                                message: dv.clone(),
24✔
186
                                flags: get_field(&dv, "flags")?,
24✔
187
                                lo: get_field(&dv, "lo")?,
24✔
188
                                mid: get_field(&dv, "mid")?,
24✔
189
                                hi: get_field(&dv, "hi")?,
24✔
190
                            },
×
191
                        })
192
                    } else {
193
                        Err(ServiceNotFound(DECIMAL_TYPE_CLASS.to_string()))
×
194
                    }
×
195
                } else {
×
196
                    Err(ServiceNotFound(POINT_TYPE_CLASS.to_string()))
×
197
                }
×
198
            };
24✔
199

200
        let names = Names::new(schema_name, &Schema::empty());
16✔
201
        let service_name = format!("{}.{}", &names.package_name, &names.plural_pascal_name);
16✔
202
        let service = descriptor
16✔
203
            .get_service_by_name(&service_name)
16✔
204
            .ok_or(GenerationError::ServiceNotFound(service_name))?;
16✔
205

206
        let mut count = None;
16✔
207
        let mut query = None;
16✔
208
        let mut on_event = None;
16✔
209
        let mut token = None;
16✔
210
        for method in service.methods() {
50✔
211
            match method.name() {
50✔
212
                "count" => {
50✔
213
                    let message = method.output();
16✔
214
                    let count_field = get_field(&message, "count")?;
16✔
215
                    count = Some(CountMethodDesc {
16✔
216
                        method,
16✔
217
                        response_desc: CountResponseDesc {
16✔
218
                            message,
16✔
219
                            count_field,
16✔
220
                        },
16✔
221
                    });
16✔
222
                }
×
223
                "query" => {
34✔
224
                    let message = method.output();
16✔
225
                    let records_field = get_field(&message, "records")?;
16✔
226
                    let records_filed_kind = records_field.kind();
16✔
227
                    let Kind::Message(record_with_id_message) = records_filed_kind else {
16✔
228
                        return Err(GenerationError::ExpectedMessageField {
×
229
                            filed_name: records_field.full_name().to_string(),
×
230
                            actual: records_filed_kind
×
231
                        });
×
232
                    };
×
233
                    let id_field = get_field(&record_with_id_message, "id")?;
16✔
234
                    let record_field = get_field(&record_with_id_message, "record")?;
16✔
235
                    let record_field_kind = record_field.kind();
16✔
236
                    let Kind::Message(record_message) = record_field_kind else {
16✔
237
                        return Err(GenerationError::ExpectedMessageField {
×
238
                            filed_name: record_field.full_name().to_string(),
×
239
                            actual: record_field_kind
×
240
                        });
×
241
                    };
×
242
                    query = Some(QueryMethodDesc {
×
243
                        method,
16✔
244
                        response_desc: QueryResponseDesc {
16✔
245
                            message,
16✔
246
                            records_field,
16✔
247
                            record_with_id_desc: RecordWithIdDesc {
16✔
248
                                message: record_with_id_message,
16✔
249
                                id_field,
16✔
250
                                record_field,
16✔
251
                                record_desc: record_desc_from_message(record_message)?,
16✔
252
                            },
×
253
                        },
×
254
                    });
×
255
                }
×
256
                "on_event" => {
18✔
257
                    let message = method.output();
8✔
258
                    let typ_field = get_field(&message, "typ")?;
8✔
259
                    let old_field = get_field(&message, "old")?;
8✔
260
                    let new_field = get_field(&message, "new")?;
8✔
261
                    let new_id_field = get_field(&message, "new_id")?;
8✔
262
                    let old_field_kind = old_field.kind();
8✔
263
                    let Kind::Message(record_message) = old_field_kind else {
8✔
264
                        return Err(GenerationError::ExpectedMessageField {
×
265
                            filed_name: old_field.full_name().to_string(),
×
266
                            actual: old_field_kind
×
267
                        });
×
268
                    };
×
269
                    on_event = Some(OnEventMethodDesc {
×
270
                        method,
8✔
271
                        response_desc: EventDesc {
8✔
272
                            message,
8✔
273
                            typ_field,
8✔
274
                            old_field,
8✔
275
                            new_field,
8✔
276
                            new_id_field,
8✔
277
                            record_desc: record_desc_from_message(record_message)?,
8✔
278
                        },
×
279
                    });
×
280
                }
×
281
                "token" => {
10✔
282
                    let message = method.output();
10✔
283
                    let token_field = get_field(&message, "token")?;
10✔
284
                    token = Some(TokenMethodDesc {
10✔
285
                        method,
10✔
286
                        response_desc: TokenResponseDesc {
10✔
287
                            message,
10✔
288
                            token_field,
10✔
289
                        },
10✔
290
                    });
10✔
291
                }
×
292
                _ => {
×
293
                    return Err(GenerationError::UnexpectedMethod(
×
294
                        method.full_name().to_string(),
×
295
                    ))
×
296
                }
×
297
            }
×
298
        }
×
299

300
        let Some(count) = count else {
16✔
301
            return Err(GenerationError::MissingCountMethod(service.full_name().to_string()));
×
302
        };
×
303
        let Some(query) = query else {
16✔
304
            return Err(GenerationError::MissingQueryMethod(service.full_name().to_string()));
×
305
        };
306

307
        Ok(ServiceDesc {
16✔
308
            service,
16✔
309
            count,
16✔
310
            query,
16✔
311
            on_event,
16✔
312
            token,
16✔
313
        })
16✔
314
    }
16✔
315
}
×
316

×
317
/// Identifiers derived from an endpoint name and its schema.
///
/// Shared by proto generation and by reading generated descriptors back, so both sides agree
/// on package, service and field names.
struct Names {
    /// File name of the generated proto: `<lower_name>.proto`.
    proto_file_name: String,
    /// Proto package: `dozer.generated.<sanitized endpoint name>`.
    package_name: String,
    /// Sanitized endpoint name, lower-cased.
    lower_name: String,
    /// Sanitized endpoint name in plural PascalCase.
    plural_pascal_name: String,
    /// Sanitized endpoint name in singular PascalCase.
    pascal_name: String,
    /// Sanitized schema field names, in schema order.
    record_field_names: Vec<String>,
}
325

326
impl Names {
327
    fn new(schema_name: &str, schema: &Schema) -> Self {
39✔
328
        if schema_name.contains('-') {
39✔
329
            error!("Name of the endpoint should not contain `-`.");
×
330
        }
39✔
331
        let schema_name = schema_name.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
195✔
332

39✔
333
        let package_name = format!("dozer.generated.{schema_name}");
39✔
334
        let lower_name = schema_name.to_lowercase();
39✔
335
        let plural_pascal_name = schema_name.to_pascal_case().to_plural();
39✔
336
        let pascal_name = schema_name.to_pascal_case().to_singular();
39✔
337
        let record_field_names = schema
39✔
338
            .fields
39✔
339
            .iter()
39✔
340
            .map(|field| {
75✔
341
                if field.name.contains('-') {
75✔
342
                    error!("Name of the field should not contain `-`.");
×
343
                }
75✔
344
                field
75✔
345
                    .name
75✔
346
                    .replace(|c: char| !c.is_ascii_alphanumeric(), "_")
513✔
347
            })
75✔
348
            .collect::<Vec<_>>();
39✔
349
        Self {
39✔
350
            proto_file_name: format!("{lower_name}.proto"),
39✔
351
            package_name,
39✔
352
            lower_name,
39✔
353
            plural_pascal_name,
39✔
354
            pascal_name,
39✔
355
            record_field_names,
39✔
356
        }
39✔
357
    }
39✔
358
}
×
359

×
360
fn convert_dozer_type_to_proto_type(field_type: FieldType) -> Result<String, GenerationError> {
150✔
361
    match field_type {
150✔
362
        FieldType::UInt => Ok("uint64".to_owned()),
52✔
363
        FieldType::Int => Ok("int64".to_owned()),
40✔
364
        FieldType::Float => Ok("double".to_owned()),
6✔
365
        FieldType::Boolean => Ok("bool".to_owned()),
×
366
        FieldType::String => Ok("string".to_owned()),
46✔
367
        FieldType::Text => Ok("string".to_owned()),
×
368
        FieldType::Binary => Ok("bytes".to_owned()),
×
369
        FieldType::Decimal => Ok(DECIMAL_TYPE_CLASS.to_owned()),
×
370
        FieldType::Timestamp => Ok(TIMESTAMP_TYPE_CLASS.to_owned()),
6✔
371
        FieldType::Date => Ok("string".to_owned()),
×
372
        FieldType::Bson => Ok("bytes".to_owned()),
×
373
        FieldType::Point => Ok(POINT_TYPE_CLASS.to_owned()),
×
374
    }
×
375
}
150✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc