• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

getdozer / dozer / 4280933931

pending completion
4280933931

push

github

GitHub
fix: `timestamp`, `point`, `decimal` support for grpc (#1060)

74 of 74 new or added lines in 4 files covered. (100.0%)

26425 of 36614 relevant lines covered (72.17%)

49036.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

72.95
/dozer-api/src/generator/protoc/generator/implementation.rs
1
use crate::errors::GenerationError;
2
use crate::errors::GenerationError::ServiceNotFound;
3
use crate::generator::protoc::generator::{
4
    CountMethodDesc, DecimalDesc, EventDesc, OnEventMethodDesc, PointDesc, QueryMethodDesc,
5
    RecordWithIdDesc, TokenMethodDesc, TokenResponseDesc,
6
};
7
use dozer_types::log::error;
8
use dozer_types::models::api_security::ApiSecurity;
9
use dozer_types::models::flags::Flags;
10
use dozer_types::serde::{self, Deserialize, Serialize};
11
use dozer_types::types::{FieldType, Schema};
12
use handlebars::Handlebars;
13
use inflector::Inflector;
14
use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor};
15
use std::path::{Path, PathBuf};
16

17
use super::{CountResponseDesc, QueryResponseDesc, RecordDesc, ServiceDesc};
18

×
19
/// Fully-qualified proto name emitted for `FieldType::Timestamp` columns; it
/// requires importing `google/protobuf/timestamp.proto`.
const TIMESTAMP_TYPE_CLASS: &str = "google.protobuf.Timestamp";
/// Fully-qualified proto message emitted for `FieldType::Point` columns
/// (package `dozer.types`, expected to come from the bundled `types.proto`).
const POINT_TYPE_CLASS: &str = "dozer.types.PointType";
/// Fully-qualified proto message emitted for `FieldType::Decimal` columns
/// (package `dozer.types`, expected to come from the bundled `types.proto`).
const DECIMAL_TYPE_CLASS: &str = "dozer.types.RustDecimal";
22

23
#[derive(Debug, Clone, Serialize, Deserialize)]
23✔
24
#[serde(crate = "self::serde")]
25
struct ProtoMetadata {
26
    import_libs: Vec<String>,
27
    package_name: String,
28
    lower_name: String,
29
    plural_pascal_name: String,
30
    pascal_name: String,
31
    props: Vec<String>,
32
    version_field_id: usize,
33
    enable_token: bool,
34
    enable_on_event: bool,
35
}
36

37
pub struct ProtoGeneratorImpl<'a> {
38
    handlebars: Handlebars<'a>,
39
    schema: &'a dozer_types::types::Schema,
40
    names: Names,
41
    folder_path: &'a Path,
42
    security: &'a Option<ApiSecurity>,
×
43
    flags: &'a Option<Flags>,
×
44
}
×
45

×
46
impl<'a> ProtoGeneratorImpl<'a> {
×
47
    pub fn new(
23✔
48
        schema_name: &str,
23✔
49
        schema: &'a Schema,
23✔
50
        folder_path: &'a Path,
23✔
51
        security: &'a Option<ApiSecurity>,
23✔
52
        flags: &'a Option<Flags>,
23✔
53
    ) -> Result<Self, GenerationError> {
23✔
54
        let names = Names::new(schema_name, schema);
23✔
55
        let mut generator = Self {
23✔
56
            handlebars: Handlebars::new(),
23✔
57
            schema,
23✔
58
            names,
23✔
59
            folder_path,
23✔
60
            security,
23✔
61
            flags,
23✔
62
        };
23✔
63
        generator.register_template()?;
23✔
64
        Ok(generator)
23✔
65
    }
23✔
66

×
67
    fn register_template(&mut self) -> Result<(), GenerationError> {
23✔
68
        let main_template = include_str!("template/proto.tmpl");
23✔
69
        self.handlebars
23✔
70
            .register_template_string("main", main_template)
23✔
71
            .map_err(|e| GenerationError::InternalError(Box::new(e)))?;
23✔
72
        Ok(())
23✔
73
    }
23✔
74

×
75
    fn props(&self) -> Vec<String> {
23✔
76
        self.schema
23✔
77
            .fields
23✔
78
            .iter()
23✔
79
            .enumerate()
23✔
80
            .zip(&self.names.record_field_names)
23✔
81
            .map(|((idx, field), field_name)| -> String {
23✔
82
                let optional = if field.nullable { "optional " } else { "" };
75✔
83
                let proto_type = convert_dozer_type_to_proto_type(field.typ.to_owned()).unwrap();
75✔
84
                format!("{optional}{proto_type} {field_name} = {};", idx + 1)
75✔
85
            })
75✔
86
            .collect()
23✔
87
    }
23✔
88

×
89
    fn libs_by_type(&self) -> Result<Vec<String>, GenerationError> {
23✔
90
        let type_need_import_libs = [TIMESTAMP_TYPE_CLASS];
23✔
91
        let mut libs_import: Vec<String> = self
23✔
92
            .schema
23✔
93
            .fields
23✔
94
            .iter()
23✔
95
            .map(|field| convert_dozer_type_to_proto_type(field.to_owned().typ).unwrap())
75✔
96
            .filter(|proto_type| -> bool {
75✔
97
                type_need_import_libs.contains(&proto_type.to_owned().as_str())
75✔
98
            })
75✔
99
            .map(|proto_type| match proto_type.as_str() {
23✔
100
                TIMESTAMP_TYPE_CLASS => "google/protobuf/timestamp.proto".to_owned(),
3✔
101
                _ => "".to_owned(),
×
102
            })
23✔
103
            .collect();
23✔
104
        libs_import.push("types.proto".to_owned());
23✔
105
        libs_import.sort();
23✔
106
        libs_import.dedup();
23✔
107
        Ok(libs_import)
23✔
108
    }
23✔
109

×
110
    fn get_metadata(&self) -> Result<ProtoMetadata, GenerationError> {
23✔
111
        let import_libs: Vec<String> = self.libs_by_type()?;
23✔
112
        let metadata = ProtoMetadata {
23✔
113
            package_name: self.names.package_name.clone(),
23✔
114
            import_libs,
23✔
115
            lower_name: self.names.lower_name.clone(),
23✔
116
            plural_pascal_name: self.names.plural_pascal_name.clone(),
23✔
117
            pascal_name: self.names.pascal_name.clone(),
23✔
118
            props: self.props(),
23✔
119
            version_field_id: self.schema.fields.len() + 1,
23✔
120
            enable_token: self.security.is_some(),
23✔
121
            enable_on_event: self.flags.clone().unwrap_or_default().push_events,
23✔
122
        };
23✔
123
        Ok(metadata)
23✔
124
    }
23✔
125

×
126
    pub fn generate_proto(&self) -> Result<(String, PathBuf), GenerationError> {
23✔
127
        if !Path::new(&self.folder_path).exists() {
23✔
128
            return Err(GenerationError::DirPathNotExist(
×
129
                self.folder_path.to_path_buf(),
×
130
            ));
×
131
        }
23✔
132

×
133
        let metadata = self.get_metadata()?;
23✔
134

×
135
        let types_proto = include_str!("../../../../../dozer-types/protos/types.proto");
23✔
136

137
        let resource_proto = self
23✔
138
            .handlebars
23✔
139
            .render("main", &metadata)
23✔
140
            .map_err(|e| GenerationError::InternalError(Box::new(e)))?;
23✔
141

×
142
        // Copy types proto file
×
143
        let mut types_file = std::fs::File::create(self.folder_path.join("types.proto"))
23✔
144
            .map_err(|e| GenerationError::InternalError(Box::new(e)))?;
23✔
145

×
146
        let resource_path = self.folder_path.join(&self.names.proto_file_name);
23✔
147
        let mut resource_file = std::fs::File::create(resource_path.clone())
23✔
148
            .map_err(|e| GenerationError::InternalError(Box::new(e)))?;
23✔
149

×
150
        std::io::Write::write_all(&mut types_file, types_proto.as_bytes())
23✔
151
            .map_err(|e| GenerationError::InternalError(Box::new(e)))?;
23✔
152

×
153
        std::io::Write::write_all(&mut resource_file, resource_proto.as_bytes())
23✔
154
            .map_err(|e| GenerationError::InternalError(Box::new(e)))?;
23✔
155

×
156
        Ok((resource_proto, resource_path))
23✔
157
    }
23✔
158

×
159
    pub fn read(
16✔
160
        descriptor: &DescriptorPool,
16✔
161
        schema_name: &str,
16✔
162
    ) -> Result<ServiceDesc, GenerationError> {
16✔
163
        fn get_field(
274✔
164
            message: &MessageDescriptor,
274✔
165
            field_name: &str,
274✔
166
        ) -> Result<FieldDescriptor, GenerationError> {
274✔
167
            message
274✔
168
                .get_field_by_name(field_name)
274✔
169
                .ok_or_else(|| GenerationError::FieldNotFound {
274✔
170
                    message_name: message.name().to_string(),
×
171
                    field_name: field_name.to_string(),
×
172
                })
274✔
173
        }
274✔
174

16✔
175
        let record_desc_from_message =
16✔
176
            |message: MessageDescriptor| -> Result<RecordDesc, GenerationError> {
24✔
177
                let version_field = get_field(&message, "__dozer_record_version")?;
24✔
178

×
179
                if let Some(point_values) = descriptor.get_message_by_name(POINT_TYPE_CLASS) {
24✔
180
                    let pv = point_values;
24✔
181
                    if let Some(decimal_values) = descriptor.get_message_by_name(DECIMAL_TYPE_CLASS)
24✔
182
                    {
×
183
                        let dv = decimal_values;
24✔
184
                        Ok(RecordDesc {
24✔
185
                            message,
24✔
186
                            version_field,
24✔
187
                            point_field: PointDesc {
24✔
188
                                message: pv.clone(),
24✔
189
                                x: get_field(&pv, "x")?,
24✔
190
                                y: get_field(&pv, "y")?,
24✔
191
                            },
×
192
                            decimal_field: DecimalDesc {
×
193
                                message: dv.clone(),
24✔
194
                                flags: get_field(&dv, "flags")?,
24✔
195
                                lo: get_field(&dv, "lo")?,
24✔
196
                                mid: get_field(&dv, "mid")?,
24✔
197
                                hi: get_field(&dv, "hi")?,
24✔
198
                            },
×
199
                        })
×
200
                    } else {
×
201
                        Err(ServiceNotFound(DECIMAL_TYPE_CLASS.to_string()))
×
202
                    }
203
                } else {
×
204
                    Err(ServiceNotFound(POINT_TYPE_CLASS.to_string()))
×
205
                }
×
206
            };
24✔
207

×
208
        let names = Names::new(schema_name, &Schema::empty());
16✔
209
        let service_name = format!("{}.{}", &names.package_name, &names.plural_pascal_name);
16✔
210
        let service = descriptor
16✔
211
            .get_service_by_name(&service_name)
16✔
212
            .ok_or(GenerationError::ServiceNotFound(service_name))?;
16✔
213

×
214
        let mut count = None;
16✔
215
        let mut query = None;
16✔
216
        let mut on_event = None;
16✔
217
        let mut token = None;
16✔
218
        for method in service.methods() {
50✔
219
            match method.name() {
50✔
220
                "count" => {
50✔
221
                    let message = method.output();
16✔
222
                    let count_field = get_field(&message, "count")?;
16✔
223
                    count = Some(CountMethodDesc {
16✔
224
                        method,
16✔
225
                        response_desc: CountResponseDesc {
16✔
226
                            message,
16✔
227
                            count_field,
16✔
228
                        },
16✔
229
                    });
16✔
230
                }
×
231
                "query" => {
34✔
232
                    let message = method.output();
16✔
233
                    let records_field = get_field(&message, "records")?;
16✔
234
                    let records_filed_kind = records_field.kind();
16✔
235
                    let Kind::Message(record_with_id_message) = records_filed_kind else {
16✔
236
                        return Err(GenerationError::ExpectedMessageField {
×
237
                            filed_name: records_field.full_name().to_string(),
×
238
                            actual: records_filed_kind
×
239
                        });
×
240
                    };
×
241
                    let id_field = get_field(&record_with_id_message, "id")?;
16✔
242
                    let record_field = get_field(&record_with_id_message, "record")?;
16✔
243
                    let record_field_kind = record_field.kind();
16✔
244
                    let Kind::Message(record_message) = record_field_kind else {
16✔
245
                        return Err(GenerationError::ExpectedMessageField {
×
246
                            filed_name: record_field.full_name().to_string(),
×
247
                            actual: record_field_kind
×
248
                        });
×
249
                    };
250
                    query = Some(QueryMethodDesc {
×
251
                        method,
16✔
252
                        response_desc: QueryResponseDesc {
16✔
253
                            message,
16✔
254
                            records_field,
16✔
255
                            record_with_id_desc: RecordWithIdDesc {
16✔
256
                                message: record_with_id_message,
16✔
257
                                id_field,
16✔
258
                                record_field,
16✔
259
                                record_desc: record_desc_from_message(record_message)?,
16✔
260
                            },
261
                        },
×
262
                    });
×
263
                }
×
264
                "on_event" => {
18✔
265
                    let message = method.output();
8✔
266
                    let typ_field = get_field(&message, "typ")?;
8✔
267
                    let old_field = get_field(&message, "old")?;
8✔
268
                    let new_field = get_field(&message, "new")?;
8✔
269
                    let new_id_field = get_field(&message, "new_id")?;
8✔
270
                    let old_field_kind = old_field.kind();
8✔
271
                    let Kind::Message(record_message) = old_field_kind else {
8✔
272
                        return Err(GenerationError::ExpectedMessageField {
×
273
                            filed_name: old_field.full_name().to_string(),
×
274
                            actual: old_field_kind
×
275
                        });
×
276
                    };
277
                    on_event = Some(OnEventMethodDesc {
278
                        method,
8✔
279
                        response_desc: EventDesc {
8✔
280
                            message,
8✔
281
                            typ_field,
8✔
282
                            old_field,
8✔
283
                            new_field,
8✔
284
                            new_id_field,
8✔
285
                            record_desc: record_desc_from_message(record_message)?,
8✔
286
                        },
287
                    });
×
288
                }
×
289
                "token" => {
10✔
290
                    let message = method.output();
10✔
291
                    let token_field = get_field(&message, "token")?;
10✔
292
                    token = Some(TokenMethodDesc {
10✔
293
                        method,
10✔
294
                        response_desc: TokenResponseDesc {
10✔
295
                            message,
10✔
296
                            token_field,
10✔
297
                        },
10✔
298
                    });
10✔
299
                }
300
                _ => {
301
                    return Err(GenerationError::UnexpectedMethod(
×
302
                        method.full_name().to_string(),
×
303
                    ))
×
304
                }
305
            }
306
        }
307

×
308
        let Some(count) = count else {
16✔
309
            return Err(GenerationError::MissingCountMethod(service.full_name().to_string()));
×
310
        };
×
311
        let Some(query) = query else {
16✔
312
            return Err(GenerationError::MissingQueryMethod(service.full_name().to_string()));
×
313
        };
×
314

×
315
        Ok(ServiceDesc {
16✔
316
            service,
16✔
317
            count,
16✔
318
            query,
16✔
319
            on_event,
16✔
320
            token,
16✔
321
        })
16✔
322
    }
16✔
323
}
×
324

×
325
/// Identifiers derived from an endpoint name and its schema, already
/// sanitized for use in generated proto source.
struct Names {
    /// File name of the generated proto, `<lower_name>.proto`.
    proto_file_name: String,
    /// Proto package, `dozer.generated.<sanitized endpoint name>`.
    package_name: String,
    /// Sanitized endpoint name, lower-cased.
    lower_name: String,
    /// PascalCase, pluralized endpoint name.
    plural_pascal_name: String,
    /// PascalCase, singularized endpoint name.
    pascal_name: String,
    /// Sanitized schema field names, in schema order.
    record_field_names: Vec<String>,
}
×
333

×
334
impl Names {
×
335
    fn new(schema_name: &str, schema: &Schema) -> Self {
39✔
336
        if schema_name.contains('-') {
39✔
337
            error!("Name of the endpoint should not contain `-`.");
×
338
        }
39✔
339
        let schema_name = schema_name.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
195✔
340

39✔
341
        let package_name = format!("dozer.generated.{schema_name}");
39✔
342
        let lower_name = schema_name.to_lowercase();
39✔
343
        let plural_pascal_name = schema_name.to_pascal_case().to_plural();
39✔
344
        let pascal_name = schema_name.to_pascal_case().to_singular();
39✔
345
        let record_field_names = schema
39✔
346
            .fields
39✔
347
            .iter()
39✔
348
            .map(|field| {
75✔
349
                if field.name.contains('-') {
75✔
350
                    error!("Name of the field should not contain `-`.");
×
351
                }
75✔
352
                field
75✔
353
                    .name
75✔
354
                    .replace(|c: char| !c.is_ascii_alphanumeric(), "_")
513✔
355
            })
75✔
356
            .collect::<Vec<_>>();
39✔
357
        Self {
39✔
358
            proto_file_name: format!("{lower_name}.proto"),
39✔
359
            package_name,
39✔
360
            lower_name,
39✔
361
            plural_pascal_name,
39✔
362
            pascal_name,
39✔
363
            record_field_names,
39✔
364
        }
39✔
365
    }
39✔
366
}
367

368
fn convert_dozer_type_to_proto_type(field_type: FieldType) -> Result<String, GenerationError> {
150✔
369
    match field_type {
150✔
370
        FieldType::UInt => Ok("uint64".to_owned()),
52✔
371
        FieldType::Int => Ok("int64".to_owned()),
40✔
372
        FieldType::Float => Ok("double".to_owned()),
6✔
373
        FieldType::Boolean => Ok("bool".to_owned()),
×
374
        FieldType::String => Ok("string".to_owned()),
46✔
375
        FieldType::Text => Ok("string".to_owned()),
×
376
        FieldType::Binary => Ok("bytes".to_owned()),
×
377
        FieldType::Decimal => Ok(DECIMAL_TYPE_CLASS.to_owned()),
×
378
        FieldType::Timestamp => Ok(TIMESTAMP_TYPE_CLASS.to_owned()),
6✔
379
        FieldType::Date => Ok("string".to_owned()),
×
380
        FieldType::Bson => Ok("bytes".to_owned()),
×
381
        FieldType::Point => Ok(POINT_TYPE_CLASS.to_owned()),
×
382
    }
383
}
150✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc