• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / pg_graphql / 14885589616

07 May 2025 02:12PM UTC coverage: 94.213% (-1.0%) from 95.228%
14885589616

Pull #590

github

web-flow
Merge 7daffe775 into 49fadbec8
Pull Request #590: Add support for single record queries by primary key

217 of 314 new or added lines in 4 files covered. (69.11%)

224 existing lines in 4 files now uncovered.

7684 of 8156 relevant lines covered (94.21%)

1195.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.29
/src/sql_types.rs
1
use bimap::BiBTreeMap;
2
use cached::proc_macro::cached;
3
use cached::SizedCache;
4
use lazy_static::lazy_static;
5
use pgrx::*;
6
use serde::{Deserialize, Serialize};
7
use std::cmp::Ordering;
8
use std::collections::hash_map::DefaultHasher;
9
use std::collections::{HashMap, HashSet};
10
use std::hash::{Hash, Hasher};
11
use std::sync::Arc;
12
use std::*;
13

14
/// Privilege flags for a single column, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ColumnPermissions {
    pub is_insertable: bool,
    pub is_selectable: bool,
    pub is_updatable: bool,
    // admin interface
    // alterable?
    // NOTE(review): the two notes above read like ideas for future flags — confirm or remove.
}

/// Optional per-column overrides, presumably read from a `@graphql({...})` comment
/// directive the same way `TableDirectives` are — verify against load_sql_context.sql.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ColumnDirectives {
    pub name: Option<String>,
    pub description: Option<String>,
}

/// Metadata for one column of a `Table`, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Column {
    pub name: String,
    /// OID of the column's type.
    pub type_oid: u32,
    /// Resolved type metadata; presumably filled in after deserialization — TODO confirm.
    pub type_: Option<Arc<Type>>,
    pub type_name: String,
    pub max_characters: Option<i32>,
    pub schema_oid: u32,
    pub is_not_null: bool,
    pub is_serial: bool,
    pub is_generated: bool,
    pub has_default: bool,
    /// Presumably the column's attribute number within its table — verify.
    pub attribute_num: i32,
    /// Checked by `Table::is_any_column_*` and `Context::fkey_is_selectable`.
    pub permissions: ColumnPermissions,
    pub comment: Option<String>,
    pub directives: ColumnDirectives,
}
46

47
/// Optional per-function overrides taken from a `@graphql({...})` comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FunctionDirectives {
    pub name: Option<String>,
    // @graphql({"description": "the address of ..." })
    pub description: Option<String>,
}

/// Privilege flag for a function; `Function::is_supported` requires `is_executable`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FunctionPermissions {
    pub is_executable: bool,
}

/// Function volatility, deserialized from the single-character codes `"v"`, `"s"` and `"i"`
/// (presumably PostgreSQL's `pg_proc.provolatile` values — confirm in the SQL query).
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum FunctionVolatility {
    #[serde(rename(deserialize = "v"))]
    Volatile,
    #[serde(rename(deserialize = "s"))]
    Stable,
    #[serde(rename(deserialize = "i"))]
    Immutable,
}

/// Metadata for a SQL function, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Function {
    pub oid: u32,
    pub name: String,
    pub schema_oid: u32,
    /// Checked against a fixed list of system schemas by `is_supported`.
    pub schema_name: String,
    /// OIDs of the input argument types, in declaration order.
    pub arg_types: Vec<u32>,
    /// Argument names; an empty string marks an unnamed argument. `None` when no names exist.
    pub arg_names: Option<Vec<String>>,
    /// Comma-separated default expressions for the trailing `num_default_args` arguments.
    pub arg_defaults: Option<String>,
    pub num_args: u32,
    /// Number of trailing arguments that have a default.
    pub num_default_args: u32,
    /// Type names parallel to `arg_types`.
    pub arg_type_names: Vec<String>,
    pub volatility: FunctionVolatility,
    /// OID of the return type.
    pub type_oid: u32,
    pub type_name: String,
    /// Presumably true for `SETOF` return types — verify.
    pub is_set_of: bool,
    pub comment: Option<String>,
    pub directives: FunctionDirectives,
    pub permissions: FunctionPermissions,
}
89

90
impl Function {
91
    pub fn args(&self) -> impl Iterator<Item = (u32, &str, Option<&str>, Option<DefaultValue>)> {
5,344✔
92
        ArgsIterator::new(
5,344✔
93
            &self.arg_types,
5,344✔
94
            &self.arg_type_names,
5,344✔
95
            &self.arg_names,
5,344✔
96
            &self.arg_defaults,
5,344✔
97
            self.num_default_args,
5,344✔
98
        )
5,344✔
99
    }
5,344✔
100

101
    pub fn function_names_to_count(all_functions: &[Arc<Function>]) -> HashMap<&String, u32> {
537✔
102
        let mut function_name_to_count = HashMap::new();
537✔
103
        for function_name in all_functions.iter().map(|f| &f.name) {
2,219✔
104
            let entry = function_name_to_count.entry(function_name).or_insert(0u32);
2,219✔
105
            *entry += 1;
2,219✔
106
        }
2,219✔
107
        function_name_to_count
537✔
108
    }
537✔
109

110
    pub fn is_supported(
2,219✔
111
        &self,
2,219✔
112
        context: &Context,
2,219✔
113
        function_name_to_count: &HashMap<&String, u32>,
2,219✔
114
    ) -> bool {
2,219✔
115
        let types = &context.types;
2,219✔
116
        self.return_type_is_supported(types)
2,219✔
117
            && self.arg_types_are_supported(types)
2,183✔
118
            && !self.is_function_overloaded(function_name_to_count)
2,086✔
119
            && !self.has_a_nameless_arg()
1,906✔
120
            && self.permissions.is_executable
1,694✔
121
            && !self.is_in_a_system_schema()
1,691✔
122
    }
2,219✔
123

124
    fn arg_types_are_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
2,183✔
125
        self.args().all(|(arg_type, _, _, _)| {
3,393✔
126
            if let Some(arg_type) = types.get(&arg_type) {
3,393✔
127
                let array_element_type_is_supported = self.array_element_type_is_supported(
3,393✔
128
                    &arg_type.category,
3,393✔
129
                    arg_type.array_element_type_oid,
3,393✔
130
                    types,
3,393✔
131
                );
3,393✔
132
                arg_type.category == TypeCategory::Other
3,393✔
133
                    || (arg_type.category == TypeCategory::Array && array_element_type_is_supported)
776✔
134
            } else {
135
                false
×
136
            }
137
        })
3,393✔
138
    }
2,183✔
139

140
    fn return_type_is_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
2,219✔
141
        if let Some(return_type) = types.get(&self.type_oid) {
2,219✔
142
            let array_element_type_is_supported = self.array_element_type_is_supported(
2,219✔
143
                &return_type.category,
2,219✔
144
                return_type.array_element_type_oid,
2,219✔
145
                types,
2,219✔
146
            );
2,219✔
147
            return_type.category != TypeCategory::Pseudo
2,219✔
148
                && return_type.category != TypeCategory::Enum
2,219✔
149
                && return_type.name != "record"
2,214✔
150
                && return_type.name != "trigger"
2,199✔
151
                && return_type.name != "event_trigger"
2,195✔
152
                && array_element_type_is_supported
2,188✔
153
        } else {
154
            false
×
155
        }
156
    }
2,219✔
157

158
    fn array_element_type_is_supported(
5,612✔
159
        &self,
5,612✔
160
        type_category: &TypeCategory,
5,612✔
161
        array_element_type_oid: Option<u32>,
5,612✔
162
        types: &HashMap<u32, Arc<Type>>,
5,612✔
163
    ) -> bool {
5,612✔
164
        if *type_category == TypeCategory::Array {
5,612✔
165
            if let Some(array_element_type_oid) = array_element_type_oid {
999✔
166
                if let Some(array_element_type) = types.get(&array_element_type_oid) {
999✔
167
                    array_element_type.category == TypeCategory::Other
999✔
168
                } else {
169
                    false
×
170
                }
171
            } else {
172
                false
×
173
            }
174
        } else {
175
            true
4,613✔
176
        }
177
    }
5,612✔
178

179
    fn is_function_overloaded(&self, function_name_to_count: &HashMap<&String, u32>) -> bool {
2,086✔
180
        if let Some(&count) = function_name_to_count.get(&self.name) {
2,086✔
181
            count > 1
2,086✔
182
        } else {
183
            false
×
184
        }
185
    }
2,086✔
186

187
    fn has_a_nameless_arg(&self) -> bool {
1,906✔
188
        self.args().any(|(_, _, arg_name, _)| arg_name.is_none())
2,798✔
189
    }
1,906✔
190

191
    fn is_in_a_system_schema(&self) -> bool {
1,691✔
192
        // These are default schemas in supabase configuration
1,691✔
193
        let system_schemas = &["graphql", "graphql_public", "auth", "extensions"];
1,691✔
194
        system_schemas.contains(&self.schema_name.as_str())
1,691✔
195
    }
1,691✔
196
}
197

198
/// Iterator over a function's input arguments; built by `Function::args`.
struct ArgsIterator<'a> {
    /// Position of the next argument to yield.
    index: usize,
    arg_types: &'a [u32],
    /// Expected to be parallel to `arg_types` (debug-asserted in `next`).
    arg_type_names: &'a Vec<String>,
    /// Optional argument names; an empty string marks an unnamed argument.
    arg_names: &'a Option<Vec<String>>,
    /// One entry per argument, pre-computed by `ArgsIterator::defaults`.
    arg_defaults: Vec<Option<DefaultValue>>,
}

/// A parsed argument default: `NonNull` holds the rendered literal text, while `Null`
/// marks a `NULL::...` default or a default expression that could not be parsed.
#[derive(Clone)]
pub enum DefaultValue {
    NonNull(String),
    Null,
}
211

212
impl<'a> ArgsIterator<'a> {
213
    fn new(
5,344✔
214
        arg_types: &'a [u32],
5,344✔
215
        arg_type_names: &'a Vec<String>,
5,344✔
216
        arg_names: &'a Option<Vec<String>>,
5,344✔
217
        arg_defaults: &'a Option<String>,
5,344✔
218
        num_default_args: u32,
5,344✔
219
    ) -> ArgsIterator<'a> {
5,344✔
220
        ArgsIterator {
5,344✔
221
            index: 0,
5,344✔
222
            arg_types,
5,344✔
223
            arg_type_names,
5,344✔
224
            arg_names,
5,344✔
225
            arg_defaults: Self::defaults(
5,344✔
226
                arg_types,
5,344✔
227
                arg_defaults,
5,344✔
228
                num_default_args as usize,
5,344✔
229
                arg_types.len(),
5,344✔
230
            ),
5,344✔
231
        }
5,344✔
232
    }
5,344✔
233

234
    fn defaults(
5,344✔
235
        arg_types: &'a [u32],
5,344✔
236
        arg_defaults: &'a Option<String>,
5,344✔
237
        num_default_args: usize,
5,344✔
238
        num_total_args: usize,
5,344✔
239
    ) -> Vec<Option<DefaultValue>> {
5,344✔
240
        let mut defaults = vec![None; num_total_args];
5,344✔
241
        let Some(arg_defaults) = arg_defaults else {
5,344✔
242
            return defaults;
5,219✔
243
        };
244

245
        if num_default_args == 0 {
125✔
246
            return defaults;
×
247
        }
125✔
248

125✔
249
        let default_strs: Vec<&str> = arg_defaults.split(',').collect();
125✔
250

125✔
251
        if default_strs.len() != num_default_args {
125✔
252
            return defaults;
13✔
253
        }
112✔
254

112✔
255
        debug_assert!(num_default_args <= num_total_args);
112✔
256
        let start_idx = num_total_args - num_default_args;
112✔
257
        for i in start_idx..num_total_args {
471✔
258
            defaults[i] = Self::sql_to_graphql_default(default_strs[i - start_idx], arg_types[i])
471✔
259
        }
260

261
        defaults
112✔
262
    }
5,344✔
263

264
    fn sql_to_graphql_default(default_str: &str, type_oid: u32) -> Option<DefaultValue> {
471✔
265
        let trimmed = default_str.trim();
471✔
266

471✔
267
        if trimmed.starts_with("NULL::") {
471✔
268
            return Some(DefaultValue::Null);
318✔
269
        }
153✔
270

271
        let res = match type_oid {
153✔
272
            21 | 23 => trimmed
72✔
273
                .parse::<i32>()
72✔
274
                .ok()
72✔
275
                .map(|i| DefaultValue::NonNull(i.to_string())),
72✔
276
            16 => trimmed
11✔
277
                .parse::<bool>()
11✔
278
                .ok()
11✔
279
                .map(|i| DefaultValue::NonNull(i.to_string())),
11✔
280
            700 | 701 => trimmed
22✔
281
                .parse::<f64>()
22✔
282
                .ok()
22✔
283
                .map(|i| DefaultValue::NonNull(i.to_string())),
22✔
284
            25 => trimmed.strip_suffix("::text").map(|i| {
26✔
285
                DefaultValue::NonNull(format!("\"{}\"", i.trim_matches(',').trim_matches('\'')))
16✔
286
            }),
26✔
287
            _ => None,
22✔
288
        };
289

290
        // return the non-parsed value as default if for whatever reason the default value can't
291
        // be parsed into a value of the required type. This fixes problems where the default
292
        // is a complex expression like a function call etc. See test/sql/issue_533.sql for
293
        // a test case for this scenario.
294
        if res.is_some() {
153✔
295
            res
101✔
296
        } else {
297
            Some(DefaultValue::Null)
52✔
298
        }
299
    }
471✔
300
}
301

302
lazy_static! {
    // Type name reported by `ArgsIterator::next` in place of `character` argument types.
    static ref TEXT_TYPE: String = "text".to_string();
}
305

306
impl<'a> Iterator for ArgsIterator<'a> {
307
    type Item = (u32, &'a str, Option<&'a str>, Option<DefaultValue>);
308

309
    fn next(&mut self) -> Option<Self::Item> {
13,205✔
310
        if self.index < self.arg_types.len() {
13,205✔
311
            debug_assert!(self.arg_types.len() == self.arg_type_names.len());
8,170✔
312
            let arg_name = if let Some(arg_names) = self.arg_names {
8,170✔
313
                debug_assert!(arg_names.len() >= self.arg_types.len());
7,220✔
314
                let arg_name = arg_names[self.index].as_str();
7,220✔
315
                if !arg_name.is_empty() {
7,220✔
316
                    Some(arg_name)
7,220✔
317
                } else {
318
                    None
×
319
                }
320
            } else {
321
                None
950✔
322
            };
323
            let arg_type = self.arg_types[self.index];
8,170✔
324
            let mut arg_type_name = &self.arg_type_names[self.index];
8,170✔
325
            if arg_type_name == "character" {
8,170✔
326
                arg_type_name = &TEXT_TYPE;
86✔
327
            }
8,084✔
328
            let arg_default = self.arg_defaults[self.index].clone();
8,170✔
329
            self.index += 1;
8,170✔
330
            Some((arg_type, arg_type_name, arg_name, arg_default))
8,170✔
331
        } else {
332
            None
5,035✔
333
        }
334
    }
13,205✔
335
}
336

337
/// Table-level privilege flags, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TablePermissions {
    pub is_insertable: bool,
    pub is_selectable: bool,
    pub is_updatable: bool,
    pub is_deletable: bool,
}

/// Type-level privilege flag, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TypePermissions {
    pub is_usable: bool,
}
349

350
/// Broad classification of a type.
///
/// `Function::is_supported` accepts only `Other` (or arrays of `Other`) as argument types
/// and rejects `Pseudo` and `Enum` return types.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeCategory {
    Enum,
    Composite,
    Table,
    Array,
    Pseudo,
    Other,
}

/// Metadata for a type, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Type {
    pub oid: u32,
    pub schema_oid: u32,
    pub name: String,
    pub category: TypeCategory,
    /// For array types, the OID of the element type.
    pub array_element_type_oid: Option<u32>,
    /// Presumably set for a table's row type — verify.
    pub table_oid: Option<u32>,
    pub comment: Option<String>,
    pub permissions: TypePermissions,
    /// Cross-referenced details; populated in a separate pass (see `TypeDetails`).
    pub details: Option<TypeDetails>,
}

// `TypeDetails` derives `Deserialize` but is not expected to come
// from the SQL context. Instead, it is populated in a separate pass.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeDetails {
    Enum(Arc<Enum>),
    Composite(Arc<Composite>),
    Table(Arc<Table>),
    Element(Arc<Type>),
}
382

383
/// One value (label) of an enum type.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct EnumValue {
    pub oid: u32,
    pub name: String,
    pub sort_order: i32,
}

/// Metadata for an enum type, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Enum {
    pub oid: u32,
    pub schema_oid: u32,
    pub name: String,
    pub values: Vec<EnumValue>,
    pub array_element_type_oid: Option<u32>,
    pub comment: Option<String>,
    pub permissions: TypePermissions,
    pub directives: EnumDirectives,
}

/// Optional enum overrides from comment directives.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct EnumDirectives {
    pub name: Option<String>,
    /// Bidirectional string mapping; presumably database value <-> exposed value — verify.
    pub mappings: Option<BiBTreeMap<String, String>>,
}

/// Identifies a composite type; `Context::is_composite` looks types up by `oid`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Composite {
    pub oid: u32,
    pub schema_oid: u32,
}
413

414
/// An index on a table. `Table::primary_key` prefers the index flagged `is_primary_key`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Index {
    pub table_oid: u32,
    pub column_names: Vec<String>,
    pub is_unique: bool,
    pub is_primary_key: bool,
}

/// One side (local or referenced) of a foreign key.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKeyTableInfo {
    pub oid: u32,
    // The table's actual name
    pub name: String,
    pub schema: String,
    pub is_rls_enabled: bool,
    /// Columns on this side of the key.
    pub column_names: Vec<String>,
}

/// Optional names for the relationship in each direction.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKeyDirectives {
    pub local_name: Option<String>,
    pub foreign_name: Option<String>,
}

/// A foreign key, either read from the SQL context or synthesized from a table's
/// `foreign_keys` comment directive (see `Context::foreign_keys`).
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKey {
    pub directives: ForeignKeyDirectives,
    pub local_table_meta: ForeignKeyTableInfo,
    pub referenced_table_meta: ForeignKeyTableInfo,
}
444

445
/// The `totalCount` table directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectiveTotalCount {
    pub enabled: bool,
}

/// A foreign key declared through a table comment directive rather than a real constraint.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectiveForeignKey {
    // Equivalent to ForeignKeyDirectives.local_name
    pub local_name: Option<String>,
    pub local_columns: Vec<String>,

    // Equivalent to ForeignKeyDirectives.foreign_name
    pub foreign_name: Option<String>,
    pub foreign_schema: String,
    pub foreign_table: String,
    pub foreign_columns: Vec<String>,
}

/// Table-level overrides taken from `@graphql({...})` comment directives.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectives {
    // @graphql({"name": "Foo" })
    pub name: Option<String>,

    // @graphql({"description": "the address of ..." })
    pub description: Option<String>,

    // @graphql({"totalCount": { "enabled": true } })
    pub total_count: Option<TableDirectiveTotalCount>,

    // @graphql({"primary_key_columns": ["id"]})
    // Used by `Table::primary_key` when the table has no real primary key.
    pub primary_key_columns: Option<Vec<String>>,

    /*
    @graphql(
      {
        "foreign_keys": [
          {
            <REQUIRED>
            "local_columns": ["account_id"],
            "foreign_schema": "public",
            "foreign_table": "account",
            "foreign_columns": ["id"],

            <OPTIONAL>
            "local_name": "foo",
            "foreign_name": "bar"
          }
        ]
      }
    )
    */
    pub foreign_keys: Option<Vec<TableDirectiveForeignKey>>,
}

/// Metadata for a table-like relation, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Table {
    pub oid: u32,
    pub name: String,
    pub schema_oid: u32,
    pub schema: String,
    pub columns: Vec<Arc<Column>>,
    pub comment: Option<String>,
    pub is_rls_enabled: bool,
    pub relkind: String, // r = table, v = view, m = mat view, f = foreign table
    pub reltype: u32,
    pub permissions: TablePermissions,
    pub indexes: Vec<Index>,
    // Defaults to an empty list when absent from the SQL context.
    #[serde(default)]
    pub functions: Vec<Arc<Function>>,
    pub directives: TableDirectives,
}
516

517
impl Table {
518
    pub fn primary_key(&self) -> Option<Index> {
3,574✔
519
        let real_pkey = self.indexes.iter().find(|x| x.is_primary_key);
3,574✔
520

3,574✔
521
        if real_pkey.is_some() {
3,574✔
522
            return real_pkey.cloned();
3,431✔
523
        }
143✔
524

525
        // Check for a primary key definition in comment directives
526
        if let Some(column_names) = &self.directives.primary_key_columns {
143✔
527
            // validate that columns exist on the table
528
            let mut valid_column_names: Vec<&String> = vec![];
122✔
529
            for column_name in column_names {
244✔
530
                for column in &self.columns {
450✔
531
                    if column_name == &column.name {
328✔
532
                        valid_column_names.push(&column.name);
122✔
533
                    }
206✔
534
                }
535
            }
536
            if valid_column_names.len() != column_names.len() {
122✔
537
                // At least one of the column names didn't exist on the table
538
                // so the primary key directive is not valid
539
                // Ideally we'd throw an error here instead
540
                None
×
541
            } else {
542
                Some(Index {
122✔
543
                    table_oid: self.oid,
122✔
544
                    column_names: column_names.clone(),
122✔
545
                    is_unique: true,
122✔
546
                    is_primary_key: true,
122✔
547
                })
122✔
548
            }
549
        } else {
550
            None
21✔
551
        }
552
    }
3,574✔
553

554
    pub fn primary_key_columns(&self) -> Vec<&Arc<Column>> {
889✔
555
        self.primary_key()
889✔
556
            .map(|x| x.column_names)
889✔
557
            .unwrap_or_default()
889✔
558
            .iter()
889✔
559
            .map(|col_name| {
898✔
560
                self.columns
898✔
561
                    .iter()
898✔
562
                    .find(|col| &col.name == col_name)
907✔
563
                    .expect("Failed to unwrap pkey by column names")
898✔
564
            })
898✔
565
            .collect::<Vec<&Arc<Column>>>()
889✔
566
    }
889✔
567

568
    pub fn is_any_column_selectable(&self) -> bool {
1,731✔
569
        self.columns.iter().any(|x| x.permissions.is_selectable)
1,751✔
570
    }
1,731✔
571
    pub fn is_any_column_insertable(&self) -> bool {
406✔
572
        self.columns.iter().any(|x| x.permissions.is_insertable)
415✔
573
    }
406✔
574

575
    pub fn is_any_column_updatable(&self) -> bool {
406✔
576
        self.columns.iter().any(|x| x.permissions.is_updatable)
417✔
577
    }
406✔
578
}
579

580
/// Schema-level settings taken from `@graphql({...})` comment directives.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct SchemaDirectives {
    // @graphql({"inflect_names": true})
    pub inflect_names: bool,
    // @graphql({"max_rows": 20})
    pub max_rows: u64,
}

/// Metadata for a schema, as deserialized from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Schema {
    pub oid: u32,
    pub name: String,
    pub comment: Option<String>,
    pub directives: SchemaDirectives,
}

/// Session configuration loaded by `load_sql_config`; its hash keys the
/// `load_sql_context` cache.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Config {
    pub search_path: Vec<String>,
    pub role: String,
    pub schema_version: i32,
}

/// Everything loaded from the database that is needed to build the API.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct Context {
    pub config: Config,
    pub schemas: HashMap<u32, Schema>,
    pub tables: HashMap<u32, Arc<Table>>,
    // Foreign keys from the SQL context only; `Context::foreign_keys()` also adds
    // directive-defined keys and filters by selectability.
    foreign_keys: Vec<Arc<ForeignKey>>,
    pub types: HashMap<u32, Arc<Type>>,
    pub enums: HashMap<u32, Arc<Enum>>,
    pub composites: Vec<Arc<Composite>>,
    pub functions: Vec<Arc<Function>>,
}
614

615
impl Hash for Context {
616
    fn hash<H: Hasher>(&self, state: &mut H) {
2,127✔
617
        // Only the config is needed to has ha Context
2,127✔
618
        self.config.hash(state);
2,127✔
619
    }
2,127✔
620
}
621

622
impl Context {
623
    /// Collect all foreign keys referencing (inbound or outbound) a table
624
    pub fn foreign_keys(&self) -> Vec<Arc<ForeignKey>> {
362✔
625
        let mut fkeys: Vec<Arc<ForeignKey>> = self.foreign_keys.clone();
362✔
626

627
        // Add foreign keys defined in comment directives
628
        for table in self.tables.values() {
600✔
629
            let directive_fkeys: Vec<TableDirectiveForeignKey> =
600✔
630
                match &table.directives.foreign_keys {
600✔
631
                    Some(keys) => keys.clone(),
28✔
632
                    None => vec![],
572✔
633
                };
634

635
            for directive_fkey in directive_fkeys.iter() {
600✔
636
                let referenced_t = match self.get_table_by_name(
28✔
637
                    &directive_fkey.foreign_schema,
28✔
638
                    &directive_fkey.foreign_table,
28✔
639
                ) {
28✔
640
                    Some(t) => t,
28✔
641
                    None => {
642
                        // No table found with requested name. Skip.
643
                        continue;
×
644
                    }
645
                };
646

647
                let referenced_t_column_names: HashSet<&String> =
28✔
648
                    referenced_t.columns.iter().map(|x| &x.name).collect();
80✔
649

28✔
650
                // Verify all foreign column references are valid
28✔
651
                if !directive_fkey
28✔
652
                    .foreign_columns
28✔
653
                    .iter()
28✔
654
                    .all(|col| referenced_t_column_names.contains(col))
28✔
655
                {
656
                    // Skip if invalid references exist
UNCOV
657
                    continue;
×
658
                }
28✔
659

28✔
660
                let fk = ForeignKey {
28✔
661
                    local_table_meta: ForeignKeyTableInfo {
28✔
662
                        oid: table.oid,
28✔
663
                        name: table.name.clone(),
28✔
664
                        schema: table.schema.clone(),
28✔
665
                        is_rls_enabled: table.is_rls_enabled,
28✔
666
                        column_names: directive_fkey.local_columns.clone(),
28✔
667
                    },
28✔
668
                    referenced_table_meta: ForeignKeyTableInfo {
28✔
669
                        oid: referenced_t.oid,
28✔
670
                        name: referenced_t.name.clone(),
28✔
671
                        schema: referenced_t.schema.clone(),
28✔
672
                        is_rls_enabled: table.is_rls_enabled,
28✔
673
                        column_names: directive_fkey.foreign_columns.clone(),
28✔
674
                    },
28✔
675
                    directives: ForeignKeyDirectives {
28✔
676
                        local_name: directive_fkey.local_name.clone(),
28✔
677
                        foreign_name: directive_fkey.foreign_name.clone(),
28✔
678
                    },
28✔
679
                };
28✔
680

28✔
681
                fkeys.push(Arc::new(fk));
28✔
682
            }
683
        }
684

685
        fkeys
362✔
686
            .into_iter()
362✔
687
            .filter(|fk| self.fkey_is_selectable(fk))
362✔
688
            .collect()
362✔
689
    }
362✔
690

691
    /// Check if a type is a composite type
692
    pub fn is_composite(&self, type_oid: u32) -> bool {
2,640✔
693
        self.composites.iter().any(|x| x.oid == type_oid)
2,640✔
694
    }
2,640✔
695

696
    pub fn get_table_by_name(
28✔
697
        &self,
28✔
698
        schema_name: &String,
28✔
699
        table_name: &String,
28✔
700
    ) -> Option<&Arc<Table>> {
28✔
701
        self.tables
28✔
702
            .values()
28✔
703
            .find(|x| &x.schema == schema_name && &x.name == table_name)
84✔
704
    }
28✔
705

706
    pub fn get_table_by_oid(&self, oid: u32) -> Option<&Arc<Table>> {
774✔
707
        self.tables.get(&oid)
774✔
708
    }
774✔
709

710
    /// Check if the local side of a foreign key is comprised of unique columns
711
    pub fn fkey_is_locally_unique(&self, fkey: &ForeignKey) -> bool {
124✔
712
        let table: &Arc<Table> = match self.get_table_by_oid(fkey.local_table_meta.oid) {
124✔
713
            Some(table) => table,
124✔
714
            None => {
UNCOV
715
                return false;
×
716
            }
717
        };
718

719
        let fkey_columns: HashSet<&String> = fkey.local_table_meta.column_names.iter().collect();
124✔
720

721
        for index in table.indexes.iter().filter(|x| x.is_unique) {
132✔
722
            let index_column_names: HashSet<&String> = index.column_names.iter().collect();
130✔
723

130✔
724
            if index_column_names
130✔
725
                .iter()
130✔
726
                .all(|col_name| fkey_columns.contains(col_name))
132✔
727
            {
728
                return true;
10✔
729
            }
120✔
730
        }
731
        false
114✔
732
    }
124✔
733

734
    /// Are both sides of the foreign key composed of selectable columns
735
    pub fn fkey_is_selectable(&self, fkey: &ForeignKey) -> bool {
210✔
736
        let local_table: &Arc<Table> = match self.get_table_by_oid(fkey.local_table_meta.oid) {
210✔
737
            Some(table) => table,
210✔
738
            None => {
UNCOV
739
                return false;
×
740
            }
741
        };
742

743
        let referenced_table: &Arc<Table> =
210✔
744
            match self.get_table_by_oid(fkey.referenced_table_meta.oid) {
210✔
745
                Some(table) => table,
210✔
746
                None => {
UNCOV
747
                    return false;
×
748
                }
749
            };
750

751
        let fkey_local_columns = &fkey.local_table_meta.column_names;
210✔
752
        let fkey_referenced_columns = &fkey.referenced_table_meta.column_names;
210✔
753

210✔
754
        let local_columns_selectable: HashSet<&String> = local_table
210✔
755
            .columns
210✔
756
            .iter()
210✔
757
            .filter(|x| x.permissions.is_selectable)
666✔
758
            .map(|col| &col.name)
660✔
759
            .collect();
210✔
760

210✔
761
        let referenced_columns_selectable: HashSet<&String> = referenced_table
210✔
762
            .columns
210✔
763
            .iter()
210✔
764
            .filter(|x| x.permissions.is_selectable)
1,064✔
765
            .map(|col| &col.name)
1,058✔
766
            .collect();
210✔
767

210✔
768
        fkey_local_columns
210✔
769
            .iter()
210✔
770
            .all(|col| local_columns_selectable.contains(col))
210✔
771
            && fkey_referenced_columns
210✔
772
                .iter()
210✔
773
                .all(|col| referenced_columns_selectable.contains(col))
210✔
774
    }
210✔
775
}
776

777
/// This method is similar to `Spi::get_one` with the only difference
778
/// being that it calls `client.select` instead of `client.update`.
779
/// The `client.update` method generates a new transaction id so
780
/// calling `Spi::get_one` is not possible when postgres is in
781
/// recovery mode.
782
pub(crate) fn get_one_readonly<A: FromDatum + IntoDatum>(
922✔
783
    query: &str,
922✔
784
) -> std::result::Result<Option<A>, pgrx::spi::Error> {
922✔
785
    Spi::connect(|client| client.select(query, Some(1), None)?.first().get_one())
922✔
786
}
922✔
787

788
pub fn load_sql_config() -> Config {
631✔
789
    let query = include_str!("../sql/load_sql_config.sql");
631✔
790
    let sql_result: serde_json::Value = get_one_readonly::<JsonB>(query)
631✔
791
        .expect("failed to read sql config")
631✔
792
        .expect("sql config is missing")
631✔
793
        .0;
631✔
794
    let config: Config =
631✔
795
        serde_json::from_value(sql_result).expect("failed to convert sql config into json");
631✔
796
    config
631✔
797
}
631✔
798

799
/// Hashes `t` with a freshly created standard-library [`DefaultHasher`] and
/// returns the resulting 64-bit digest.
///
/// `load_sql_context` uses this to derive its cache key from a `Config`.
pub fn calculate_hash<T: Hash>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    T::hash(t, &mut hasher);
    hasher.finish()
}
804

805
#[cached(
116✔
806
    type = "SizedCache<u64, Result<Arc<Context>, String>>",
116✔
807
    create = "{ SizedCache::with_size(250) }",
116✔
808
    convert = r#"{ calculate_hash(_config) }"#
116✔
809
)]
116✔
810
pub fn load_sql_context(_config: &Config) -> Result<Arc<Context>, String> {
291✔
811
    // cache value for next query
291✔
812
    let query = include_str!("../sql/load_sql_context.sql");
291✔
813
    let sql_result: serde_json::Value = get_one_readonly::<JsonB>(query)
291✔
814
        .expect("failed to read sql context")
291✔
815
        .expect("sql context is missing")
291✔
816
        .0;
291✔
817
    let context: Result<Context, serde_json::Error> = serde_json::from_value(sql_result);
291✔
818

819
    /// This pass cross-reference types with its details
820
    fn type_details(mut context: Context) -> Context {
289✔
821
        let mut array_types = HashMap::new();
289✔
822
        // We process types to cross-reference their details
823
        for (oid, type_) in context.types.iter_mut() {
177,677✔
824
            if let Some(mtype) = Arc::get_mut(type_) {
177,677✔
825
                // It should be possible to get a mutable reference to type at this point
826
                // as there are no other references to this Arc at this point.
827

828
                // Depending on type's category, locate the appropriate piece of information about it
829
                mtype.details = match mtype.category {
177,677✔
830
                    TypeCategory::Enum => context.enums.get(oid).cloned().map(TypeDetails::Enum),
26✔
831
                    TypeCategory::Composite => context
1✔
832
                        .composites
1✔
833
                        .iter()
1✔
834
                        .find(|c| c.oid == *oid)
1✔
835
                        .cloned()
1✔
836
                        .map(TypeDetails::Composite),
1✔
837
                    TypeCategory::Table => context.tables.get(oid).cloned().map(TypeDetails::Table),
60,632✔
838
                    TypeCategory::Array => {
839
                        // We can't cross-reference with `context.types` here as it is already mutably borrowed,
840
                        // so we instead memorize the reference to process later
841
                        if let Some(element_oid) = mtype.array_element_type_oid {
84,937✔
842
                            array_types.insert(*oid, element_oid);
84,937✔
843
                        }
84,937✔
844
                        None
84,937✔
845
                    }
846
                    _ => None,
32,081✔
847
                };
UNCOV
848
            }
×
849
        }
850

851
        // Ensure the types are ordered so that we don't run into a situation where we can't
852
        // update the type anymore as it has been referenced but the type details weren't completed yet
853
        let referenced_types = array_types.values().copied().collect::<Vec<_>>();
289✔
854
        let mut ordered_types = array_types
289✔
855
            .iter()
289✔
856
            .map(|(k, v)| (*k, *v))
84,937✔
857
            .collect::<Vec<_>>();
289✔
858
        // We sort them by their presence in referencing. If the type has been referenced,
289✔
859
        // it should be at the top.
289✔
860
        ordered_types.sort_by(|(k1, _), (k2, _)| {
191,071✔
861
            if referenced_types.contains(k1) && referenced_types.contains(k2) {
191,071✔
862
                Ordering::Equal
420✔
863
            } else if referenced_types.contains(k1) {
190,651✔
864
                Ordering::Less
38,006✔
865
            } else {
866
                Ordering::Greater
152,645✔
867
            }
868
        });
191,071✔
869

870
        // Now we're ready to process array types
871
        for (array_oid, element_oid) in ordered_types {
85,226✔
872
            // We remove the element type from the map to ensure there is no mutability conflict for when
873
            // we get a mutable reference to the array type. We will put it back after we're done with it,
874
            // a few lines below.
875
            if let Some(element_t) = context.types.remove(&element_oid) {
84,937✔
876
                if let Some(array_t) = context.types.get_mut(&array_oid) {
84,937✔
877
                    if let Some(array) = Arc::get_mut(array_t) {
84,937✔
878
                        // It should be possible to get a mutable reference to type at this point
84,937✔
879
                        // as there are no other references to this Arc at this point.
84,937✔
880
                        array.details = Some(TypeDetails::Element(element_t.clone()));
84,937✔
881
                    } else {
84,937✔
882
                        // For some reason, we weren't able to get it. It means something have changed
883
                        // in our logic and we're presenting an assertion violation. Let's report it.
884
                        // It's a bug.
UNCOV
885
                        pgrx::warning!(
×
UNCOV
886
                            "Assertion violation: array type with OID {} is already referenced",
×
UNCOV
887
                            array_oid
×
UNCOV
888
                        );
×
UNCOV
889
                        continue;
×
890
                    }
891
                    // Put the element type back. NB: Very important to keep this line! It'll be used
892
                    // further down the loop. There is a check at the end of each loop's iteration that
893
                    // we actually did this. Being defensive.
894
                    context.types.insert(element_oid, element_t);
84,937✔
895
                } else {
896
                    // We weren't able to find the OID of the array, which is odd because we just got
897
                    // it from the context. This means we messed something up and it is a bug. Report it.
UNCOV
898
                    pgrx::warning!(
×
UNCOV
899
                        "Assertion violation: array type with OID {} is not found",
×
UNCOV
900
                        array_oid
×
UNCOV
901
                    );
×
UNCOV
902
                    continue;
×
903
                }
904
            } else {
905
                // We weren't able to find the OID of the element type, which is also odd because we just got
906
                // it from the context. This means it's a bug as well. Report it.
UNCOV
907
                pgrx::warning!(
×
UNCOV
908
                        "Assertion violation: referenced element type with OID {} of array type with OID {} is not found",
×
UNCOV
909
                        element_oid, array_oid);
×
UNCOV
910
                continue;
×
911
            }
912

913
            // Here we are asserting that we did in fact return the element type back to the list. Part of being
914
            // defensive here.
915
            if !context.types.contains_key(&element_oid) {
84,937✔
UNCOV
916
                pgrx::warning!("Assertion violation: referenced element type with OID {} was not returned to the list of types", element_oid );
×
UNCOV
917
                continue;
×
918
            }
84,937✔
919
        }
920

921
        context
289✔
922
    }
289✔
923

924
    /// This pass cross-reference column types
925
    fn column_types(mut context: Context) -> Context {
289✔
926
        // We process tables to cross-reference their columns' types
927
        for (_oid, table) in context.tables.iter_mut() {
289✔
928
            if let Some(mtable) = Arc::get_mut(table) {
227✔
929
                // It should be possible to get a mutable reference to table at this point
930
                // as there are no other references to this Arc at this point.
931

932
                // We will now iterate over columns
933
                for column in mtable.columns.iter_mut() {
792✔
934
                    if let Some(mcolumn) = Arc::get_mut(column) {
792✔
935
                        // It should be possible to get a mutable reference to column at this point
792✔
936
                        // as there are no other references to this Arc at this point.
792✔
937

792✔
938
                        // Find a matching type
792✔
939
                        mcolumn.type_ = context.types.get(&mcolumn.type_oid).cloned();
792✔
940
                    }
792✔
941
                }
UNCOV
942
            }
×
943
        }
944
        context
289✔
945
    }
289✔
946

947
    /// This pass populates functions for tables
948
    fn populate_table_functions(mut context: Context) -> Context {
289✔
949
        let mut arg_type_to_func: HashMap<u32, Vec<&Arc<Function>>> = HashMap::new();
289✔
950
        for function in context.functions.iter().filter(|f| f.num_args == 1) {
1,446✔
951
            let functions = arg_type_to_func.entry(function.arg_types[0]).or_default();
235✔
952
            functions.push(function);
235✔
953
        }
235✔
954
        for table in &mut context.tables.values_mut() {
289✔
955
            if let Some(table) = Arc::get_mut(table) {
227✔
956
                if let Some(functions) = arg_type_to_func.get(&table.reltype) {
227✔
957
                    for function in functions {
71✔
958
                        table.functions.push(Arc::clone(function));
36✔
959
                    }
36✔
960
                }
192✔
UNCOV
961
            }
×
962
        }
963
        context
289✔
964
    }
289✔
965

966
    context
291✔
967
        .map(type_details)
291✔
968
        .map(column_types)
291✔
969
        .map(populate_table_functions)
291✔
970
        .map(Arc::new)
291✔
971
        .map_err(|e| {
291✔
972
            format!(
2✔
973
                "Error while loading schema, check comment directives. {}",
2✔
974
                e
2✔
975
            )
2✔
976
        })
291✔
977
}
291✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc