• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

getdozer / dozer / 4071639102

pending completion
4071639102

Pull #780

github

GitHub
Merge 6694befe5 into 3a0622c99
Pull Request #780: fix: Clear PK from Projection output

4 of 4 new or added lines in 1 file covered. (100.0%)

24332 of 35774 relevant lines covered (68.02%)

35765.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.08
/dozer-sql/src/pipeline/aggregation/factory.rs
1
use std::collections::HashMap;
2

3
use dozer_core::dag::{
4
    dag::DEFAULT_PORT_HANDLE,
5
    errors::ExecutionError,
6
    node::{OutputPortDef, OutputPortType, PortHandle, Processor, ProcessorFactory},
7
};
8
use dozer_types::types::{FieldDefinition, Schema};
9
use sqlparser::ast::{Expr as SqlExpr, Expr, Ident, SelectItem};
10

11
use crate::pipeline::builder::SchemaSQLContext;
12
use crate::pipeline::{
13
    errors::PipelineError,
14
    expression::{
15
        aggregate::AggregateFunctionType,
16
        builder::{BuilderExpressionType, ExpressionBuilder},
17
        execution::{Expression, ExpressionExecutor},
18
    },
19
    projection::{factory::parse_sql_select_item, processor::ProjectionProcessor},
20
};
21

22
use super::{
23
    aggregator::Aggregator,
24
    processor::{AggregationProcessor, FieldRule},
25
};
26

27
#[derive(Debug)]
×
28
pub struct AggregationProcessorFactory {
29
    select: Vec<SelectItem>,
30
    groupby: Vec<SqlExpr>,
31
    stateful: bool,
32
}
33

34
impl AggregationProcessorFactory {
35
    /// Creates a new [`AggregationProcessorFactory`].
36
    pub fn new(select: Vec<SelectItem>, groupby: Vec<SqlExpr>, stateful: bool) -> Self {
111✔
37
        Self {
111✔
38
            select,
111✔
39
            groupby,
111✔
40
            stateful,
111✔
41
        }
111✔
42
    }
111✔
43
}
44

45
impl ProcessorFactory<SchemaSQLContext> for AggregationProcessorFactory {
46
    fn get_input_ports(&self) -> Vec<PortHandle> {
290✔
47
        vec![DEFAULT_PORT_HANDLE]
290✔
48
    }
290✔
49

50
    fn get_output_ports(&self) -> Vec<OutputPortDef> {
387✔
51
        if self.stateful {
387✔
52
            vec![OutputPortDef::new(
×
53
                DEFAULT_PORT_HANDLE,
×
54
                OutputPortType::StatefulWithPrimaryKeyLookup {
×
55
                    retr_old_records_for_deletes: true,
×
56
                    retr_old_records_for_updates: true,
×
57
                },
×
58
            )]
×
59
        } else {
60
            vec![OutputPortDef::new(
387✔
61
                DEFAULT_PORT_HANDLE,
387✔
62
                OutputPortType::Stateless,
387✔
63
            )]
387✔
64
        }
65
    }
387✔
66

67
    fn get_output_schema(
193✔
68
        &self,
193✔
69
        _output_port: &PortHandle,
193✔
70
        input_schemas: &HashMap<PortHandle, (Schema, SchemaSQLContext)>,
193✔
71
    ) -> Result<(Schema, SchemaSQLContext), ExecutionError> {
193✔
72
        let (input_schema, ctx) = input_schemas
193✔
73
            .get(&DEFAULT_PORT_HANDLE)
193✔
74
            .ok_or(ExecutionError::InvalidPortHandle(DEFAULT_PORT_HANDLE))?;
193✔
75
        let output_field_rules =
193✔
76
            get_aggregation_rules(&self.select, &self.groupby, input_schema).unwrap();
193✔
77

193✔
78
        if is_aggregation(&self.groupby, &output_field_rules) {
193✔
79
            let output_schema = build_output_schema(input_schema, output_field_rules)?;
55✔
80
            return Ok((output_schema, ctx.clone()));
55✔
81
        }
138✔
82
        build_projection_schema(input_schema, ctx, &self.select)
138✔
83
    }
193✔
84

85
    fn build(
97✔
86
        &self,
97✔
87
        input_schemas: HashMap<PortHandle, Schema>,
97✔
88
        _output_schemas: HashMap<PortHandle, Schema>,
97✔
89
    ) -> Result<Box<dyn Processor>, ExecutionError> {
97✔
90
        let input_schema = input_schemas
97✔
91
            .get(&DEFAULT_PORT_HANDLE)
97✔
92
            .ok_or(ExecutionError::InvalidPortHandle(DEFAULT_PORT_HANDLE))?;
97✔
93
        let output_field_rules =
97✔
94
            get_aggregation_rules(&self.select, &self.groupby, input_schema).unwrap();
97✔
95

97✔
96
        if is_aggregation(&self.groupby, &output_field_rules) {
97✔
97
            return Ok(Box::new(AggregationProcessor::new(
28✔
98
                output_field_rules,
28✔
99
                input_schema.clone(),
28✔
100
            )));
28✔
101
        }
69✔
102

69✔
103
        let mut select_expr: Vec<(String, Expression)> = vec![];
69✔
104
        for s in self.select.iter() {
234✔
105
            match s {
234✔
106
                SelectItem::Wildcard(_) => {
107
                    let fields: Vec<SelectItem> = input_schema
×
108
                        .fields
×
109
                        .iter()
×
110
                        .map(|col| {
×
111
                            SelectItem::UnnamedExpr(Expr::Identifier(Ident::new(
×
112
                                col.to_owned().name,
×
113
                            )))
×
114
                        })
×
115
                        .collect();
×
116
                    for f in fields {
×
117
                        let res = parse_sql_select_item(&f, input_schema);
×
118
                        if let Ok(..) = res {
×
119
                            select_expr.push(res.unwrap())
×
120
                        }
×
121
                    }
122
                }
123
                _ => {
124
                    let res = parse_sql_select_item(s, input_schema);
234✔
125
                    if let Ok(..) = res {
234✔
126
                        select_expr.push(res.unwrap())
234✔
127
                    }
×
128
                }
129
            }
130
        }
131

132
        Ok(Box::new(ProjectionProcessor::new(
69✔
133
            input_schema.clone(),
69✔
134
            select_expr,
69✔
135
        )))
69✔
136
    }
97✔
137

138
    fn prepare(
×
139
        &self,
×
140
        _input_schemas: HashMap<PortHandle, (Schema, SchemaSQLContext)>,
×
141
        _output_schemas: HashMap<PortHandle, (Schema, SchemaSQLContext)>,
×
142
    ) -> Result<(), ExecutionError> {
×
143
        Ok(())
×
144
    }
×
145
}
146

147
fn is_aggregation(groupby: &[SqlExpr], output_field_rules: &[FieldRule]) -> bool {
290✔
148
    if !groupby.is_empty() {
290✔
149
        return true;
63✔
150
    }
227✔
151

227✔
152
    output_field_rules
227✔
153
        .iter()
227✔
154
        .any(|rule| matches!(rule, FieldRule::Measure(_, _, _)))
722✔
155
}
290✔
156

157
pub(crate) fn get_aggregation_rules(
336✔
158
    select: &[SelectItem],
336✔
159
    groupby: &[SqlExpr],
336✔
160
    schema: &Schema,
336✔
161
) -> Result<Vec<FieldRule>, PipelineError> {
336✔
162
    let mut select_rules = select
336✔
163
        .iter()
336✔
164
        .map(|item| parse_sql_aggregate_item(item, schema))
942✔
165
        .filter(|e| e.is_ok())
942✔
166
        .collect::<Result<Vec<FieldRule>, PipelineError>>()?;
336✔
167

168
    let mut groupby_rules = groupby
336✔
169
        .iter()
336✔
170
        .map(|expr| parse_sql_groupby_item(expr, schema))
336✔
171
        .collect::<Result<Vec<FieldRule>, PipelineError>>()?;
336✔
172

173
    select_rules.append(&mut groupby_rules);
336✔
174

336✔
175
    Ok(select_rules)
336✔
176
}
336✔
177

178
fn build_field_rule(
942✔
179
    sql_expr: &Expr,
942✔
180
    schema: &Schema,
942✔
181
    name: String,
942✔
182
) -> Result<FieldRule, PipelineError> {
942✔
183
    let builder = ExpressionBuilder {};
942✔
184
    let expression =
942✔
185
        builder.parse_sql_expression(&BuilderExpressionType::Aggregation, sql_expr, schema)?;
942✔
186

187
    match get_aggregator(expression.0.clone(), schema) {
942✔
188
        Ok(aggregator) => Ok(FieldRule::Measure(
129✔
189
            ExpressionBuilder {}
129✔
190
                .parse_sql_expression(&BuilderExpressionType::PreAggregation, sql_expr, schema)?
129✔
191
                .0,
192
            aggregator,
129✔
193
            name,
129✔
194
        )),
195
        Err(_) => Ok(FieldRule::Dimension(expression.0, true, name)),
813✔
196
    }
197
}
942✔
198

199
fn parse_sql_aggregate_item(
942✔
200
    item: &SelectItem,
942✔
201
    schema: &Schema,
942✔
202
) -> Result<FieldRule, PipelineError> {
942✔
203
    match item {
942✔
204
        SelectItem::UnnamedExpr(sql_expr) => {
888✔
205
            build_field_rule(sql_expr, schema, sql_expr.to_string())
888✔
206
        }
207
        SelectItem::ExprWithAlias { expr, alias } => {
54✔
208
            build_field_rule(expr, schema, alias.value.clone())
54✔
209
        }
210
        SelectItem::Wildcard(_) => Err(PipelineError::InvalidExpression(
×
211
            "Wildcard Operator is not supported".to_string(),
×
212
        )),
×
213
        SelectItem::QualifiedWildcard(..) => Err(PipelineError::InvalidExpression(
×
214
            "Qualified Wildcard Operator is not supported".to_string(),
×
215
        )),
×
216
    }
217
}
942✔
218

219
fn parse_sql_groupby_item(
109✔
220
    sql_expression: &SqlExpr,
109✔
221
    schema: &Schema,
109✔
222
) -> Result<FieldRule, PipelineError> {
109✔
223
    Ok(FieldRule::Dimension(
109✔
224
        ExpressionBuilder {}.build(
109✔
225
            &BuilderExpressionType::FullExpression,
109✔
226
            sql_expression,
109✔
227
            schema,
109✔
228
        )?,
109✔
229
        false,
230
        sql_expression.to_string(),
109✔
231
    ))
232
}
109✔
233

234
fn get_aggregator(
942✔
235
    expression: Box<Expression>,
942✔
236
    schema: &Schema,
942✔
237
) -> Result<Aggregator, PipelineError> {
942✔
238
    match *expression {
942✔
239
        Expression::AggregateFunction { fun, args } => {
129✔
240
            let arg_type = args[0].get_type(schema);
129✔
241
            match (&fun, arg_type) {
129✔
242
                (AggregateFunctionType::Avg, _) => Ok(Aggregator::Avg),
7✔
243
                (AggregateFunctionType::Count, _) => Ok(Aggregator::Count),
92✔
244
                (AggregateFunctionType::Max, _) => Ok(Aggregator::Max),
11✔
245
                (AggregateFunctionType::Min, _) => Ok(Aggregator::Min),
11✔
246
                (AggregateFunctionType::Sum, _) => Ok(Aggregator::Sum),
8✔
247
                _ => Err(PipelineError::InvalidExpression(format!(
×
248
                    "Not implemented Aggregation function: {fun:?}"
×
249
                ))),
×
250
            }
251
        }
252
        _ => Err(PipelineError::InvalidExpression(format!(
813✔
253
            "Not an Aggregation function: {expression:?}"
813✔
254
        ))),
813✔
255
    }
256
}
942✔
257

258
fn build_output_schema(
55✔
259
    input_schema: &Schema,
55✔
260
    output_field_rules: Vec<FieldRule>,
55✔
261
) -> Result<Schema, ExecutionError> {
55✔
262
    let mut output_schema = Schema::empty();
55✔
263
    for e in output_field_rules.iter().enumerate() {
140✔
264
        match e.1 {
140✔
265
            FieldRule::Measure(pre_aggr, aggr, name) => {
55✔
266
                let res = pre_aggr
55✔
267
                    .get_type(input_schema)
55✔
268
                    .map_err(|e| ExecutionError::InternalError(Box::new(e)))?;
55✔
269

270
                output_schema.fields.push(FieldDefinition::new(
55✔
271
                    name.clone(),
55✔
272
                    aggr.get_return_type(res.return_type),
55✔
273
                    res.nullable,
55✔
274
                    res.source,
55✔
275
                ));
55✔
276
            }
277

278
            FieldRule::Dimension(expression, is_value, name) => {
85✔
279
                if *is_value {
85✔
280
                    let res = expression
43✔
281
                        .get_type(input_schema)
43✔
282
                        .map_err(|e| ExecutionError::InternalError(Box::new(e)))?;
43✔
283

284
                    output_schema.fields.push(FieldDefinition::new(
43✔
285
                        name.clone(),
43✔
286
                        res.return_type,
43✔
287
                        res.nullable,
43✔
288
                        res.source,
43✔
289
                    ));
43✔
290
                    output_schema.primary_index.push(e.0);
43✔
291
                }
42✔
292
            }
293
        }
294
    }
295

296
    // remove primary index as already defined in the sink
297
    // the Planner will compute the primary index properly
298
    output_schema.primary_index = vec![];
55✔
299
    Ok(output_schema)
55✔
300
}
55✔
301

302
fn build_projection_schema(
138✔
303
    input_schema: &Schema,
138✔
304
    context: &SchemaSQLContext,
138✔
305
    select: &[SelectItem],
138✔
306
) -> Result<(Schema, SchemaSQLContext), ExecutionError> {
138✔
307
    let mut select_expr: Vec<(String, Expression)> = vec![];
138✔
308
    for s in select.iter() {
468✔
309
        match s {
468✔
310
            SelectItem::Wildcard(_) => {
311
                let fields: Vec<SelectItem> = input_schema
×
312
                    .fields
×
313
                    .iter()
×
314
                    .map(|col| {
×
315
                        SelectItem::UnnamedExpr(Expr::Identifier(Ident::new(col.to_owned().name)))
×
316
                    })
×
317
                    .collect();
×
318
                for f in fields {
×
319
                    let res = parse_sql_select_item(&f, input_schema);
×
320
                    if let Ok(..) = res {
×
321
                        select_expr.push(res.unwrap())
×
322
                    }
×
323
                }
324
            }
325
            _ => {
326
                let res = parse_sql_select_item(s, input_schema);
468✔
327
                if let Ok(..) = res {
468✔
328
                    select_expr.push(res.unwrap())
468✔
329
                }
×
330
            }
331
        }
332
    }
333

334
    let mut output_schema = input_schema.clone();
138✔
335
    let mut fields = vec![];
138✔
336
    for e in select_expr.iter() {
468✔
337
        let field_name = e.0.clone();
468✔
338
        let field_type =
468✔
339
            e.1.get_type(input_schema)
468✔
340
                .map_err(|e| ExecutionError::InternalError(Box::new(e)))?;
468✔
341

342
        fields.push(FieldDefinition::new(
468✔
343
            field_name,
468✔
344
            field_type.return_type,
468✔
345
            field_type.nullable,
468✔
346
            field_type.source,
468✔
347
        ));
468✔
348
    }
349
    output_schema.fields = fields;
138✔
350

138✔
351
    // remove primary index as already defined in the sink
138✔
352
    // the Planner will compute the primary index properly
138✔
353
    output_schema.primary_index = vec![];
138✔
354
    Ok((output_schema, context.clone()))
138✔
355
}
138✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc