getdozer / dozer, build 3978628498 (pending completion)

GitHub: Merge 8775fcda7 into e2f9ad287
Pull Request #705: chore: support for generic schema context in `Sink`, `Processor` and `Source` factories

572 of 572 new or added lines in 35 files covered. (100.0%)

22294 of 34850 relevant lines covered (63.97%)

40332.28 hits per line

Source File

/dozer-sql/src/pipeline/builder.rs: 71.63% of lines covered
use crate::pipeline::aggregation::factory::AggregationProcessorFactory;
use crate::pipeline::builder::PipelineError::InvalidQuery;
use crate::pipeline::selection::factory::SelectionProcessorFactory;
use crate::pipeline::{errors::PipelineError, product::factory::ProductProcessorFactory};
use dozer_core::dag::app::AppPipeline;
use dozer_core::dag::app::PipelineEntryPoint;
use dozer_core::dag::appsource::AppSourceId;
use dozer_core::dag::dag::DEFAULT_PORT_HANDLE;
use dozer_core::dag::node::PortHandle;
use sqlparser::ast::{Join, TableFactor, TableWithJoins};
use sqlparser::{
    ast::{Query, Select, SetExpr, Statement},
    dialect::AnsiDialect,
    parser::Parser,
};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use super::errors::UnsupportedSqlError;
use super::expression::builder::{fullname_from_ident, normalize_ident};

#[derive(Debug, Clone)]
pub struct SchemaSQLContext {}

/// The struct contains some contexts during query to pipeline.
#[derive(Debug, Clone, Default)]
pub struct QueryContext {
    pub cte_names: HashSet<String>,
}

#[derive(Debug, Clone)]
pub struct IndexedTabelWithJoins {
    pub relation: (NameOrAlias, TableFactor),
    pub joins: Vec<(NameOrAlias, Join)>,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct NameOrAlias(pub String, pub Option<String>);

pub fn statement_to_pipeline(
    sql: &str,
) -> Result<(AppPipeline<SchemaSQLContext>, (String, PortHandle)), PipelineError> {
    let dialect = AnsiDialect {};
    let mut ctx = QueryContext::default();

    let ast = Parser::parse_sql(&dialect, sql).unwrap();
    let query_name = NameOrAlias(format!("query_{}", uuid::Uuid::new_v4()), None);
    let statement = ast.get(0).expect("First statement is missing").to_owned();

    let mut pipeline = AppPipeline::new();
    let mut pipeline_map = HashMap::new();
    if let Statement::Query(query) = statement {
        query_to_pipeline(
            &query_name,
            &query,
            &mut pipeline,
            &mut pipeline_map,
            &mut ctx,
            false,
        )?;
    };
    let node = pipeline_map
        .get(&query_name.0)
        .expect("query should have been initialized")
        .to_owned();
    Ok((pipeline, node))
}

fn query_to_pipeline(
    processor_name: &NameOrAlias,
    query: &Query,
    pipeline: &mut AppPipeline<SchemaSQLContext>,
    pipeline_map: &mut HashMap<String, (String, PortHandle)>,
    query_ctx: &mut QueryContext,
    stateful: bool,
) -> Result<(), PipelineError> {
    // Attach the first pipeline if there is with clause
    if let Some(with) = &query.with {
        if with.recursive {
            return Err(PipelineError::UnsupportedSqlError(
                UnsupportedSqlError::Recursive,
            ));
        }

        for table in &with.cte_tables {
            if table.from.is_some() {
                return Err(PipelineError::UnsupportedSqlError(
                    UnsupportedSqlError::CteFromError,
                ));
            }
            let table_name = table.alias.name.to_string();
            if query_ctx.cte_names.contains(&table_name) {
                return Err(InvalidQuery(format!(
                    "WITH query name {table_name:?} specified more than once"
                )));
            }
            query_ctx.cte_names.insert(table_name.clone());
            query_to_pipeline(
                &NameOrAlias(table_name.clone(), Some(table_name)),
                &table.query,
                pipeline,
                pipeline_map,
                query_ctx,
                false,
            )?;
        }
    };

    match *query.body.clone() {
        SetExpr::Select(select) => {
            select_to_pipeline(processor_name, *select, pipeline, pipeline_map, stateful)?;
        }
        SetExpr::Query(query) => {
            let query_name = format!("subquery_{}", uuid::Uuid::new_v4());
            let mut ctx = QueryContext::default();
            query_to_pipeline(
                &NameOrAlias(query_name, None),
                &query,
                pipeline,
                pipeline_map,
                &mut ctx,
                stateful,
            )?
        }
        _ => {
            return Err(PipelineError::UnsupportedSqlError(
                UnsupportedSqlError::SelectOnlyError,
            ))
        }
    };
    Ok(())
}

fn select_to_pipeline(
    processor_name: &NameOrAlias,
    select: Select,
    pipeline: &mut AppPipeline<SchemaSQLContext>,
    pipeline_map: &mut HashMap<String, (String, PortHandle)>,
    stateful: bool,
) -> Result<(), PipelineError> {
    // FROM clause
    if select.from.len() != 1 {
        return Err(InvalidQuery(
            "FROM clause doesn't support \"Comma Syntax\"".to_string(),
        ));
    }

    let input_tables = get_input_tables(&select.from[0], pipeline, pipeline_map)?;

    let product = ProductProcessorFactory::new(input_tables.clone());

    let input_endpoints = get_entry_points(&input_tables, pipeline_map)?;

    let gen_product_name = format!("product_{}", uuid::Uuid::new_v4());
    let gen_agg_name = format!("agg_{}", uuid::Uuid::new_v4());
    let gen_selection_name = format!("select_{}", uuid::Uuid::new_v4());
    pipeline.add_processor(Arc::new(product), &gen_product_name, input_endpoints);

    let input_names = get_input_names(&input_tables);
    for (port_index, table_name) in input_names.iter().enumerate() {
        if let Some((processor_name, processor_port)) = pipeline_map.get(&table_name.0) {
            pipeline.connect_nodes(
                processor_name,
                Some(*processor_port),
                &gen_product_name,
                Some(port_index as PortHandle),
            )?;
        }
    }

    let aggregation =
        AggregationProcessorFactory::new(select.projection.clone(), select.group_by, stateful);

    pipeline.add_processor(Arc::new(aggregation), &gen_agg_name, vec![]);

    // Where clause
    if let Some(selection) = select.selection {
        let selection = SelectionProcessorFactory::new(selection);
        // first_node_name = String::from("selection");

        pipeline.add_processor(Arc::new(selection), &gen_selection_name, vec![]);

        pipeline.connect_nodes(
            &gen_product_name,
            Some(DEFAULT_PORT_HANDLE),
            &gen_selection_name,
            Some(DEFAULT_PORT_HANDLE),
        )?;

        pipeline.connect_nodes(
            &gen_selection_name,
            Some(DEFAULT_PORT_HANDLE),
            &gen_agg_name,
            Some(DEFAULT_PORT_HANDLE),
        )?;
    } else {
        pipeline.connect_nodes(
            &gen_product_name,
            Some(DEFAULT_PORT_HANDLE),
            &gen_agg_name,
            Some(DEFAULT_PORT_HANDLE),
        )?;
    }

    pipeline_map.insert(
        processor_name.0.clone(),
        (gen_agg_name, DEFAULT_PORT_HANDLE),
    );

    Ok(())
}

/// Returns a vector of input port handles and relative table name
///
/// # Errors
///
/// This function will return an error if it's not possible to get an input name.
pub fn get_input_tables(
    from: &TableWithJoins,
    pipeline: &mut AppPipeline<SchemaSQLContext>,
    pipeline_map: &mut HashMap<String, (String, PortHandle)>,
) -> Result<IndexedTabelWithJoins, PipelineError> {
    let mut input_tables = vec![];

    let name = get_from_source(&from.relation, pipeline, pipeline_map)?;
    input_tables.insert(0, name.clone());
    let mut joins = vec![];

    for (index, join) in from.joins.iter().enumerate() {
        let input_name = get_from_source(&join.relation, pipeline, pipeline_map)?;
        joins.push((input_name.clone(), join.clone()));
        input_tables.insert(index + 1, input_name);
    }

    Ok(IndexedTabelWithJoins {
        relation: (name, from.relation.clone()),
        joins,
    })
}

pub fn get_input_names(input_tables: &IndexedTabelWithJoins) -> Vec<NameOrAlias> {
    let mut input_names = vec![];
    input_names.push(input_tables.relation.0.clone());

    for join in &input_tables.joins {
        input_names.push(join.0.clone());
    }
    input_names
}
pub fn get_entry_points(
    input_tables: &IndexedTabelWithJoins,
    pipeline_map: &mut HashMap<String, (String, PortHandle)>,
) -> Result<Vec<PipelineEntryPoint>, PipelineError> {
    let mut endpoints = vec![];

    let input_names = get_input_names(input_tables);

    for (input_port, table) in input_names.iter().enumerate() {
        let name = table.0.clone();
        if !pipeline_map.contains_key(&name) {
            endpoints.push(PipelineEntryPoint::new(
                AppSourceId::new(name, None),
                input_port as PortHandle,
            ));
        }
    }

    Ok(endpoints)
}

pub fn get_from_source(
    relation: &TableFactor,
    pipeline: &mut AppPipeline<SchemaSQLContext>,
    pipeline_map: &mut HashMap<String, (String, PortHandle)>,
) -> Result<NameOrAlias, PipelineError> {
    match relation {
        TableFactor::Table { name, alias, .. } => {
            let input_name = name
                .0
                .iter()
                .map(normalize_ident)
                .collect::<Vec<String>>()
                .join(".");
            let alias_name = alias
                .as_ref()
                .map(|a| fullname_from_ident(&[a.name.clone()]));

            Ok(NameOrAlias(input_name, alias_name))
        }
        TableFactor::Derived {
            lateral: _,
            subquery,
            alias,
        } => {
            let name = format!("derived_{}", uuid::Uuid::new_v4());
            let alias_name = alias
                .as_ref()
                .map(|alias_ident| fullname_from_ident(&[alias_ident.name.clone()]));

            let name_or = NameOrAlias(name, alias_name);
            let mut ctx = QueryContext::default();
            query_to_pipeline(&name_or, subquery, pipeline, pipeline_map, &mut ctx, false)?;

            Ok(name_or)
        }
        _ => Err(PipelineError::UnsupportedSqlError(
            UnsupportedSqlError::JoinTable,
        )),
    }
}

#[cfg(test)]
mod tests {
    use super::statement_to_pipeline;

    #[test]
    fn sql_logic_test_1() {
        let statements: Vec<&str> = vec![
            r#"
            SELECT
            a.name as "Genre",
                SUM(amount) as "Gross Revenue(in $)"
            FROM
            (
                SELECT
                c.name, f.title, p.amount
            FROM film f
            LEFT JOIN film_category fc
            ON fc.film_id = f.film_id
            LEFT JOIN category c
            ON fc.category_id = c.category_id
            LEFT JOIN inventory i
            ON i.film_id = f.film_id
            LEFT JOIN rental r
            ON r.inventory_id = i.inventory_id
            LEFT JOIN payment p
            ON p.rental_id = r.rental_id
            WHERE p.amount IS NOT NULL
            ) a

            GROUP BY name
            ORDER BY sum(amount) desc
            LIMIT 5;
            "#,
            r#"
                SELECT
                c.name, f.title, p.amount
            FROM film f
            LEFT JOIN film_category fc
            "#,
            r#"
            WITH tbl as (select id from a)
            select id from tbl
            "#,
            r#"
            WITH tbl as (select id from  a),
            tbl2 as (select id from tbl)
            select id from tbl2
            "#,
            r#"
            WITH cte_table1 as (select id_dt1 from (select id_t1 from table_1) as derived_table_1),
            cte_table2 as (select id_ct1 from cte_table1)
            select id_ct2 from cte_table2
            "#,
            r#"
                with tbl as (select id, ticker from stocks)
                select tbl.id from  stocks join tbl on tbl.id = stocks.id;
            "#,
        ];
        for sql in statements {
            let _pipeline = statement_to_pipeline(sql).unwrap();
        }
    }
}
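
For orientation, a minimal usage sketch of the `statement_to_pipeline` entry point above. The import path and the wrapper function are assumptions (the module is assumed to be exported as `dozer_sql::pipeline::builder`); the SQL string is taken from the test cases in this file.

// Sketch only: construct a pipeline from one SQL statement and read back the
// (processor name, port) pair identifying its output node. The import path
// below is an assumption; adjust it to wherever this module is exported.
use dozer_sql::pipeline::builder::statement_to_pipeline;

fn build_example_pipeline() {
    let sql = "WITH tbl as (select id from a) select id from tbl";
    let (_pipeline, (output_node, output_port)) =
        statement_to_pipeline(sql).expect("a single SELECT statement should build");
    // `_pipeline` now holds the generated product/aggregation (and optional
    // selection) processors; a sink would be connected to this node and port.
    println!("pipeline output node: {output_node}, port: {output_port}");
}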