• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / pg_graphql / 21279141151

23 Jan 2026 08:07AM UTC coverage: 91.329% (-0.03%) from 91.36%
21279141151

Pull #620

github

web-flow
Merge 151f0b46a into 0d7a66bfd
Pull Request #620: chore: bump Rust edition to 2024

126 of 213 new or added lines in 7 files covered. (59.15%)

4 existing lines in 2 files now uncovered.

7499 of 8211 relevant lines covered (91.33%)

1140.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.82
/src/sql_types.rs
1
use bimap::BiBTreeMap;
2
use cached::SizedCache;
3
use cached::proc_macro::cached;
4
use lazy_static::lazy_static;
5
use pgrx::*;
6
use serde::{Deserialize, Serialize};
7

8
use crate::error::GraphQLResult;
9
use std::cmp::Ordering;
10
use std::collections::hash_map::DefaultHasher;
11
use std::collections::{HashMap, HashSet};
12
use std::hash::{Hash, Hasher};
13
use std::sync::Arc;
14
use std::*;
15

16
/// Per-column privileges, as reported by the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ColumnPermissions {
    pub is_insertable: bool,
    pub is_selectable: bool,
    pub is_updatable: bool,
    // Possible future flags, not modeled yet:
    // admin interface
    // alterable?
}
24

25
/// Overrides a column may declare through its comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ColumnDirectives {
    /// Replacement name for the column, if any.
    pub name: Option<String>,
    /// Replacement description for the column, if any.
    pub description: Option<String>,
}
30

31
/// A table column as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Column {
    pub name: String,
    /// OID of the column's type.
    pub type_oid: u32,
    /// Resolved type for `type_oid`, when available.
    /// NOTE(review): presumably populated in a later pass rather than from SQL; verify.
    pub type_: Option<Arc<Type>>,
    pub type_name: String,
    /// NOTE(review): presumably the declared length limit for character types; verify.
    pub max_characters: Option<i32>,
    pub schema_oid: u32,
    pub is_not_null: bool,
    pub is_serial: bool,
    pub is_generated: bool,
    pub has_default: bool,
    /// NOTE(review): presumably the attribute number (`attnum`) of the column; verify.
    pub attribute_num: i32,
    pub permissions: ColumnPermissions,
    pub comment: Option<String>,
    pub directives: ColumnDirectives,
}
48

49
/// Overrides a function may declare through its comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FunctionDirectives {
    /// Replacement name for the function, if any.
    pub name: Option<String>,
    // @graphql({"description": "the address of ..." })
    pub description: Option<String>,
}
55

56
/// Privileges on a function; `Function::is_supported` requires `is_executable`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FunctionPermissions {
    pub is_executable: bool,
}
60

61
/// Function volatility, deserialized from single-letter codes
/// (`v`/`s`/`i`, matching `pg_proc.provolatile`).
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum FunctionVolatility {
    #[serde(rename(deserialize = "v"))]
    Volatile,
    #[serde(rename(deserialize = "s"))]
    Stable,
    #[serde(rename(deserialize = "i"))]
    Immutable,
}
70

71
/// A database function as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Function {
    pub oid: u32,
    pub name: String,
    pub schema_oid: u32,
    pub schema_name: String,
    /// Argument type OIDs; `Function::is_supported` looks them up in `Context::types`.
    pub arg_types: Vec<u32>,
    /// Argument names, if known. `ArgsIterator` treats an empty entry as a
    /// nameless argument.
    pub arg_names: Option<Vec<String>>,
    /// Raw SQL default expressions as one comma-separated string, parsed by
    /// `ArgsIterator::defaults`.
    pub arg_defaults: Option<String>,
    pub num_args: u32,
    /// How many of the trailing arguments have a default.
    pub num_default_args: u32,
    /// Argument type names, parallel to `arg_types`.
    pub arg_type_names: Vec<String>,
    pub volatility: FunctionVolatility,
    /// OID of the return type.
    pub type_oid: u32,
    pub type_name: String,
    /// NOTE(review): presumably true for `SETOF` return types; verify.
    pub is_set_of: bool,
    pub comment: Option<String>,
    pub directives: FunctionDirectives,
    pub permissions: FunctionPermissions,
}
91

92
impl Function {
93
    pub fn args(&self) -> impl Iterator<Item = (u32, &str, Option<&str>, Option<DefaultValue>)> {
5,344✔
94
        ArgsIterator::new(
5,344✔
95
            &self.arg_types,
5,344✔
96
            &self.arg_type_names,
5,344✔
97
            &self.arg_names,
5,344✔
98
            &self.arg_defaults,
5,344✔
99
            self.num_default_args,
5,344✔
100
        )
101
    }
5,344✔
102

103
    pub fn function_names_to_count(all_functions: &[Arc<Function>]) -> HashMap<&String, u32> {
542✔
104
        let mut function_name_to_count = HashMap::new();
542✔
105
        for function_name in all_functions.iter().map(|f| &f.name) {
2,219✔
106
            let entry = function_name_to_count.entry(function_name).or_insert(0u32);
2,219✔
107
            *entry += 1;
2,219✔
108
        }
2,219✔
109
        function_name_to_count
542✔
110
    }
542✔
111

112
    pub fn is_supported(
2,219✔
113
        &self,
2,219✔
114
        context: &Context,
2,219✔
115
        function_name_to_count: &HashMap<&String, u32>,
2,219✔
116
    ) -> bool {
2,219✔
117
        let types = &context.types;
2,219✔
118
        self.return_type_is_supported(types)
2,219✔
119
            && self.arg_types_are_supported(types)
2,183✔
120
            && !self.is_function_overloaded(function_name_to_count)
2,086✔
121
            && !self.has_a_nameless_arg()
1,906✔
122
            && self.permissions.is_executable
1,694✔
123
            && !self.is_in_a_system_schema()
1,691✔
124
    }
2,219✔
125

126
    fn arg_types_are_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
2,183✔
127
        self.args().all(|(arg_type, _, _, _)| {
3,393✔
128
            if let Some(arg_type) = types.get(&arg_type) {
3,393✔
129
                let array_element_type_is_supported = self.array_element_type_is_supported(
3,393✔
130
                    &arg_type.category,
3,393✔
131
                    arg_type.array_element_type_oid,
3,393✔
132
                    types,
3,393✔
133
                );
134
                arg_type.category == TypeCategory::Other
3,393✔
135
                    || (arg_type.category == TypeCategory::Array && array_element_type_is_supported)
776✔
136
            } else {
137
                false
×
138
            }
139
        })
3,393✔
140
    }
2,183✔
141

142
    fn return_type_is_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
2,219✔
143
        if let Some(return_type) = types.get(&self.type_oid) {
2,219✔
144
            let array_element_type_is_supported = self.array_element_type_is_supported(
2,219✔
145
                &return_type.category,
2,219✔
146
                return_type.array_element_type_oid,
2,219✔
147
                types,
2,219✔
148
            );
149
            return_type.category != TypeCategory::Pseudo
2,219✔
150
                && return_type.category != TypeCategory::Enum
2,219✔
151
                && return_type.name != "record"
2,214✔
152
                && return_type.name != "trigger"
2,199✔
153
                && return_type.name != "event_trigger"
2,195✔
154
                && array_element_type_is_supported
2,188✔
155
        } else {
156
            false
×
157
        }
158
    }
2,219✔
159

160
    fn array_element_type_is_supported(
5,612✔
161
        &self,
5,612✔
162
        type_category: &TypeCategory,
5,612✔
163
        array_element_type_oid: Option<u32>,
5,612✔
164
        types: &HashMap<u32, Arc<Type>>,
5,612✔
165
    ) -> bool {
5,612✔
166
        if *type_category == TypeCategory::Array {
5,612✔
167
            if let Some(array_element_type_oid) = array_element_type_oid {
999✔
168
                if let Some(array_element_type) = types.get(&array_element_type_oid) {
999✔
169
                    array_element_type.category == TypeCategory::Other
999✔
170
                } else {
171
                    false
×
172
                }
173
            } else {
174
                false
×
175
            }
176
        } else {
177
            true
4,613✔
178
        }
179
    }
5,612✔
180

181
    fn is_function_overloaded(&self, function_name_to_count: &HashMap<&String, u32>) -> bool {
2,086✔
182
        if let Some(&count) = function_name_to_count.get(&self.name) {
2,086✔
183
            count > 1
2,086✔
184
        } else {
185
            false
×
186
        }
187
    }
2,086✔
188

189
    fn has_a_nameless_arg(&self) -> bool {
1,906✔
190
        self.args().any(|(_, _, arg_name, _)| arg_name.is_none())
2,798✔
191
    }
1,906✔
192

193
    fn is_in_a_system_schema(&self) -> bool {
1,691✔
194
        // These are default schemas in supabase configuration
195
        let system_schemas = &["graphql", "graphql_public", "auth", "extensions"];
1,691✔
196
        system_schemas.contains(&self.schema_name.as_str())
1,691✔
197
    }
1,691✔
198
}
199

200
/// Iterator over a function's arguments, built by `Function::args`.
struct ArgsIterator<'a> {
    /// Position of the next argument to yield.
    index: usize,
    arg_types: &'a [u32],
    /// Parallel to `arg_types`.
    arg_type_names: &'a Vec<String>,
    arg_names: &'a Option<Vec<String>>,
    /// One parsed default per argument, computed up front by `ArgsIterator::defaults`.
    arg_defaults: Vec<Option<DefaultValue>>,
}
207

208
/// Parsed default for a function argument.
#[derive(Clone)]
pub enum DefaultValue {
    /// A rendered literal (e.g. `42`, `true`, or a quoted string).
    NonNull(String),
    /// An explicit `NULL::…` default, or a default that could not be parsed.
    Null,
}
213

214
impl<'a> ArgsIterator<'a> {
215
    fn new(
5,344✔
216
        arg_types: &'a [u32],
5,344✔
217
        arg_type_names: &'a Vec<String>,
5,344✔
218
        arg_names: &'a Option<Vec<String>>,
5,344✔
219
        arg_defaults: &'a Option<String>,
5,344✔
220
        num_default_args: u32,
5,344✔
221
    ) -> ArgsIterator<'a> {
5,344✔
222
        ArgsIterator {
5,344✔
223
            index: 0,
5,344✔
224
            arg_types,
5,344✔
225
            arg_type_names,
5,344✔
226
            arg_names,
5,344✔
227
            arg_defaults: Self::defaults(
5,344✔
228
                arg_types,
5,344✔
229
                arg_defaults,
5,344✔
230
                num_default_args as usize,
5,344✔
231
                arg_types.len(),
5,344✔
232
            ),
5,344✔
233
        }
5,344✔
234
    }
5,344✔
235

236
    fn defaults(
5,344✔
237
        arg_types: &'a [u32],
5,344✔
238
        arg_defaults: &'a Option<String>,
5,344✔
239
        num_default_args: usize,
5,344✔
240
        num_total_args: usize,
5,344✔
241
    ) -> Vec<Option<DefaultValue>> {
5,344✔
242
        let mut defaults = vec![None; num_total_args];
5,344✔
243
        let Some(arg_defaults) = arg_defaults else {
5,344✔
244
            return defaults;
5,219✔
245
        };
246

247
        if num_default_args == 0 {
125✔
248
            return defaults;
×
249
        }
125✔
250

251
        let default_strs: Vec<&str> = arg_defaults.split(',').collect();
125✔
252

253
        if default_strs.len() != num_default_args {
125✔
254
            return defaults;
13✔
255
        }
112✔
256

257
        debug_assert!(num_default_args <= num_total_args);
112✔
258
        let start_idx = num_total_args - num_default_args;
112✔
259
        for i in start_idx..num_total_args {
471✔
260
            defaults[i] = Self::sql_to_graphql_default(default_strs[i - start_idx], arg_types[i])
471✔
261
        }
262

263
        defaults
112✔
264
    }
5,344✔
265

266
    fn sql_to_graphql_default(default_str: &str, type_oid: u32) -> Option<DefaultValue> {
471✔
267
        let trimmed = default_str.trim();
471✔
268

269
        if trimmed.starts_with("NULL::") {
471✔
270
            return Some(DefaultValue::Null);
318✔
271
        }
153✔
272

273
        let res = match type_oid {
153✔
274
            21 | 23 => trimmed
72✔
275
                .parse::<i32>()
72✔
276
                .ok()
72✔
277
                .map(|i| DefaultValue::NonNull(i.to_string())),
72✔
278
            16 => trimmed
11✔
279
                .parse::<bool>()
11✔
280
                .ok()
11✔
281
                .map(|i| DefaultValue::NonNull(i.to_string())),
11✔
282
            700 | 701 => trimmed
22✔
283
                .parse::<f64>()
22✔
284
                .ok()
22✔
285
                .map(|i| DefaultValue::NonNull(i.to_string())),
22✔
286
            25 => trimmed.strip_suffix("::text").map(|i| {
26✔
287
                DefaultValue::NonNull(format!("\"{}\"", i.trim_matches(',').trim_matches('\'')))
16✔
288
            }),
16✔
289
            _ => None,
22✔
290
        };
291

292
        // return the non-parsed value as default if for whatever reason the default value can't
293
        // be parsed into a value of the required type. This fixes problems where the default
294
        // is a complex expression like a function call etc. See test/sql/issue_533.sql for
295
        // a test case for this scenario.
296
        if res.is_some() {
153✔
297
            res
101✔
298
        } else {
299
            Some(DefaultValue::Null)
52✔
300
        }
301
    }
471✔
302
}
303

304
lazy_static! {
    /// Type name reported instead of `character` for function arguments
    /// (substituted in `ArgsIterator::next`).
    static ref TEXT_TYPE: String = "text".to_string();
}
307

308
impl<'a> Iterator for ArgsIterator<'a> {
309
    type Item = (u32, &'a str, Option<&'a str>, Option<DefaultValue>);
310

311
    fn next(&mut self) -> Option<Self::Item> {
13,205✔
312
        if self.index < self.arg_types.len() {
13,205✔
313
            debug_assert!(self.arg_types.len() == self.arg_type_names.len());
8,170✔
314
            let arg_name = if let Some(arg_names) = self.arg_names {
8,170✔
315
                debug_assert!(arg_names.len() >= self.arg_types.len());
7,409✔
316
                let arg_name = arg_names[self.index].as_str();
7,409✔
317
                if !arg_name.is_empty() {
7,409✔
318
                    Some(arg_name)
7,409✔
319
                } else {
320
                    None
×
321
                }
322
            } else {
323
                None
761✔
324
            };
325
            let arg_type = self.arg_types[self.index];
8,170✔
326
            let mut arg_type_name = &self.arg_type_names[self.index];
8,170✔
327
            if arg_type_name == "character" {
8,170✔
328
                arg_type_name = &TEXT_TYPE;
86✔
329
            }
8,084✔
330
            let arg_default = self.arg_defaults[self.index].clone();
8,170✔
331
            self.index += 1;
8,170✔
332
            Some((arg_type, arg_type_name, arg_name, arg_default))
8,170✔
333
        } else {
334
            None
5,035✔
335
        }
336
    }
13,205✔
337
}
338

339
/// Table-level privileges, as reported by the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TablePermissions {
    pub is_insertable: bool,
    pub is_selectable: bool,
    pub is_updatable: bool,
    pub is_deletable: bool,
}
346

347
/// Privileges on a type, as reported by the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TypePermissions {
    pub is_usable: bool,
}
351

352
/// Coarse classification of a type.
///
/// `Function::is_supported` accepts `Other` types and arrays of `Other` types.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeCategory {
    Enum,
    Composite,
    Table,
    Array,
    Pseudo,
    Other,
}
361

362
/// A type as loaded from the SQL context; `Context::types` is keyed by its OID.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Type {
    pub oid: u32,
    pub schema_oid: u32,
    pub name: String,
    pub category: TypeCategory,
    /// Element type OID; consulted when `category` is `Array`.
    pub array_element_type_oid: Option<u32>,
    pub table_oid: Option<u32>,
    pub comment: Option<String>,
    pub permissions: TypePermissions,
    /// Cross-referenced details; see `TypeDetails`.
    pub details: Option<TypeDetails>,
}
374

375
// `TypeDetails` derives `Deserialize` but is not expected to come
// from the SQL context. Instead, it is populated in a separate pass.
/// Extra information linked to a `Type`, depending on its category.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeDetails {
    Enum(Arc<Enum>),
    Composite(Arc<Composite>),
    Table(Arc<Table>),
    /// Element type of an array type.
    Element(Arc<Type>),
}
384

385
/// One label of an enum type.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct EnumValue {
    pub oid: u32,
    pub name: String,
    pub sort_order: i32,
}
391

392
/// An enum type together with its labels.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Enum {
    pub oid: u32,
    pub schema_oid: u32,
    pub name: String,
    pub values: Vec<EnumValue>,
    pub array_element_type_oid: Option<u32>,
    pub comment: Option<String>,
    pub permissions: TypePermissions,
    pub directives: EnumDirectives,
}
403

404
/// Overrides an enum may declare through its comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct EnumDirectives {
    /// Replacement name for the enum, if any.
    pub name: Option<String>,
    /// Bidirectional value-name mapping.
    /// NOTE(review): presumably database label <-> exposed name; confirm direction.
    pub mappings: Option<BiBTreeMap<String, String>>,
}
409

410
/// A composite type; `Context::is_composite` matches type OIDs against `oid`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Composite {
    pub oid: u32,
    pub schema_oid: u32,
}
415

416
/// An index on a table.
///
/// Also used to represent a primary key declared by a comment directive
/// (see `Table::primary_key`).
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Index {
    pub table_oid: u32,
    pub column_names: Vec<String>,
    pub is_unique: bool,
    pub is_primary_key: bool,
}
423

424
/// Describes one side (local or referenced) of a foreign key.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKeyTableInfo {
    pub oid: u32,
    // The table's actual name
    pub name: String,
    pub schema: String,
    pub is_rls_enabled: bool,
    /// Columns on this side of the key.
    pub column_names: Vec<String>,
}
433

434
/// Optional relationship names for a foreign key.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKeyDirectives {
    pub local_name: Option<String>,
    pub foreign_name: Option<String>,
}
439

440
/// A foreign key between two tables.
///
/// It is either loaded from the database or synthesized from a table directive
/// (see `Context::foreign_keys`).
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKey {
    pub directives: ForeignKeyDirectives,
    pub local_table_meta: ForeignKeyTableInfo,
    pub referenced_table_meta: ForeignKeyTableInfo,
}
446

447
/// The `totalCount` table directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectiveTotalCount {
    pub enabled: bool,
}
451

452
/// The `aggregate` table directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectiveAggregate {
    pub enabled: bool,
}
456

457
/// A foreign key declared in a table's comment directive.
///
/// `Context::foreign_keys` turns each of these into a `ForeignKey`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectiveForeignKey {
    // Equivalent to ForeignKeyDirectives.local_name
    pub local_name: Option<String>,
    pub local_columns: Vec<String>,

    // Equivalent to ForeignKeyDirectives.foreign_name
    pub foreign_name: Option<String>,
    pub foreign_schema: String,
    pub foreign_table: String,
    pub foreign_columns: Vec<String>,
}
469

470
/// Options a table may declare through its comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectives {
    // @graphql({"name": "Foo" })
    pub name: Option<String>,

    // @graphql({"description": "the address of ..." })
    pub description: Option<String>,

    // @graphql({"totalCount": { "enabled": true } })
    pub total_count: Option<TableDirectiveTotalCount>,

    // @graphql({"aggregate": { "enabled": true } })
    pub aggregate: Option<TableDirectiveAggregate>,

    // @graphql({"primary_key_columns": ["id"]})
    pub primary_key_columns: Option<Vec<String>>,

    // @graphql({"max_rows": 20})
    pub max_rows: Option<u64>,

    /*
    @graphql(
      {
        "foreign_keys": [
          {
            <REQUIRED>
            "local_columns": ["account_id"],
            "foreign_schema": "public",
            "foreign_table": "account",
            "foreign_columns": ["id"],

            <OPTIONAL>
            "local_name": "foo",
            "foreign_name": "bar"
          }
        ]
      }
    )
    */
    pub foreign_keys: Option<Vec<TableDirectiveForeignKey>>,
}
511

512
/// A table-like relation as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Table {
    pub oid: u32,
    pub name: String,
    pub schema_oid: u32,
    /// Schema name.
    pub schema: String,
    pub columns: Vec<Arc<Column>>,
    pub comment: Option<String>,
    pub is_rls_enabled: bool,
    pub relkind: String, // r = table, v = view, m = mat view, f = foreign table
    pub reltype: u32,
    pub permissions: TablePermissions,
    pub indexes: Vec<Index>,
    /// Defaults to empty when absent from the input.
    #[serde(default)]
    pub functions: Vec<Arc<Function>>,
    pub directives: TableDirectives,
}
529

530
impl Table {
531
    pub fn primary_key(&self) -> Option<Index> {
3,877✔
532
        let real_pkey = self.indexes.iter().find(|x| x.is_primary_key);
3,877✔
533

534
        if real_pkey.is_some() {
3,877✔
535
            return real_pkey.cloned();
3,657✔
536
        }
220✔
537

538
        // Check for a primary key definition in comment directives
539
        if let Some(column_names) = &self.directives.primary_key_columns {
220✔
540
            // validate that columns exist on the table
541
            let mut valid_column_names: Vec<&String> = vec![];
199✔
542
            for column_name in column_names {
199✔
543
                for column in &self.columns {
406✔
544
                    if column_name == &column.name {
406✔
545
                        valid_column_names.push(&column.name);
199✔
546
                    }
207✔
547
                }
548
            }
549
            if valid_column_names.len() != column_names.len() {
199✔
550
                // At least one of the column names didn't exist on the table
551
                // so the primary key directive is not valid
552
                // Ideally we'd throw an error here instead
553
                None
×
554
            } else {
555
                Some(Index {
199✔
556
                    table_oid: self.oid,
199✔
557
                    column_names: column_names.clone(),
199✔
558
                    is_unique: true,
199✔
559
                    is_primary_key: true,
199✔
560
                })
199✔
561
            }
562
        } else {
563
            None
21✔
564
        }
565
    }
3,877✔
566

567
    pub fn primary_key_columns(&self) -> Vec<&Arc<Column>> {
1,228✔
568
        self.primary_key()
1,228✔
569
            .map(|x| x.column_names)
1,228✔
570
            .unwrap_or_default()
1,228✔
571
            .iter()
1,228✔
572
            .map(|col_name| {
1,240✔
573
                self.columns
1,240✔
574
                    .iter()
1,240✔
575
                    .find(|col| &col.name == col_name)
1,252✔
576
                    .expect("Failed to unwrap pkey by column names")
1,240✔
577
            })
1,240✔
578
            .collect::<Vec<&Arc<Column>>>()
1,228✔
579
    }
1,228✔
580

581
    pub fn is_any_column_selectable(&self) -> bool {
1,867✔
582
        self.columns.iter().any(|x| x.permissions.is_selectable)
1,887✔
583
    }
1,867✔
584

585
    pub fn is_any_column_insertable(&self) -> bool {
406✔
586
        self.columns.iter().any(|x| x.permissions.is_insertable)
415✔
587
    }
406✔
588

589
    pub fn is_any_column_updatable(&self) -> bool {
406✔
590
        self.columns.iter().any(|x| x.permissions.is_updatable)
417✔
591
    }
406✔
592

593
    /// Get the effective max_rows value for this table.
594
    /// If table-specific max_rows is set, use that.
595
    /// Otherwise, fall back to schema-level max_rows.
596
    /// If neither is set, use the global default(set in load_sql_context.sql)
597
    pub fn max_rows(&self, schema: &Schema) -> u64 {
298✔
598
        self.directives
298✔
599
            .max_rows
298✔
600
            .unwrap_or(schema.directives.max_rows)
298✔
601
    }
298✔
602
}
603

604
/// Options a schema may declare through its comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct SchemaDirectives {
    // @graphql({"inflect_names": true})
    pub inflect_names: bool,
    // @graphql({"max_rows": 20})
    // Fallback for tables without their own max_rows (see `Table::max_rows`).
    pub max_rows: u64,
}
611

612
/// A schema as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Schema {
    pub oid: u32,
    pub name: String,
    pub comment: Option<String>,
    pub directives: SchemaDirectives,
}
619

620
/// Session configuration read by `load_sql_config`.
///
/// Its hash keys the `load_sql_context` cache (via `calculate_hash`). It is also
/// the only input to `Context`'s `Hash` impl.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Config {
    pub search_path: Vec<String>,
    pub role: String,
    pub schema_version: i32,
}
626

627
/// Everything loaded from the database that schema building needs.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct Context {
    pub config: Config,
    /// Keyed by schema OID.
    pub schemas: HashMap<u32, Schema>,
    /// Keyed by table OID.
    pub tables: HashMap<u32, Arc<Table>>,
    /// Foreign keys loaded from the database. Private: use `Context::foreign_keys()`,
    /// which also adds directive-defined keys and filters by selectability.
    foreign_keys: Vec<Arc<ForeignKey>>,
    /// Keyed by type OID.
    pub types: HashMap<u32, Arc<Type>>,
    pub enums: HashMap<u32, Arc<Enum>>,
    pub composites: Vec<Arc<Composite>>,
    pub functions: Vec<Arc<Function>>,
}
638

639
impl Hash for Context {
640
    fn hash<H: Hasher>(&self, state: &mut H) {
2,319✔
641
        // Only the config is needed to has ha Context
642
        self.config.hash(state);
2,319✔
643
    }
2,319✔
644
}
645

646
impl Context {
647
    /// Collect all foreign keys referencing (inbound or outbound) a table
648
    pub fn foreign_keys(&self) -> Vec<Arc<ForeignKey>> {
378✔
649
        let mut fkeys: Vec<Arc<ForeignKey>> = self.foreign_keys.clone();
378✔
650

651
        // Add foreign keys defined in comment directives
652
        for table in self.tables.values() {
660✔
653
            let directive_fkeys: Vec<TableDirectiveForeignKey> =
660✔
654
                match &table.directives.foreign_keys {
660✔
655
                    Some(keys) => keys.clone(),
28✔
656
                    None => vec![],
632✔
657
                };
658

659
            for directive_fkey in directive_fkeys.iter() {
660✔
660
                let referenced_t = match self.get_table_by_name(
28✔
661
                    &directive_fkey.foreign_schema,
28✔
662
                    &directive_fkey.foreign_table,
28✔
663
                ) {
28✔
664
                    Some(t) => t,
28✔
665
                    None => {
666
                        // No table found with requested name. Skip.
667
                        continue;
×
668
                    }
669
                };
670

671
                let referenced_t_column_names: HashSet<&String> =
28✔
672
                    referenced_t.columns.iter().map(|x| &x.name).collect();
80✔
673

674
                // Verify all foreign column references are valid
675
                if !directive_fkey
28✔
676
                    .foreign_columns
28✔
677
                    .iter()
28✔
678
                    .all(|col| referenced_t_column_names.contains(col))
28✔
679
                {
680
                    // Skip if invalid references exist
681
                    continue;
×
682
                }
28✔
683

684
                let fk = ForeignKey {
28✔
685
                    local_table_meta: ForeignKeyTableInfo {
28✔
686
                        oid: table.oid,
28✔
687
                        name: table.name.clone(),
28✔
688
                        schema: table.schema.clone(),
28✔
689
                        is_rls_enabled: table.is_rls_enabled,
28✔
690
                        column_names: directive_fkey.local_columns.clone(),
28✔
691
                    },
28✔
692
                    referenced_table_meta: ForeignKeyTableInfo {
28✔
693
                        oid: referenced_t.oid,
28✔
694
                        name: referenced_t.name.clone(),
28✔
695
                        schema: referenced_t.schema.clone(),
28✔
696
                        is_rls_enabled: table.is_rls_enabled,
28✔
697
                        column_names: directive_fkey.foreign_columns.clone(),
28✔
698
                    },
28✔
699
                    directives: ForeignKeyDirectives {
28✔
700
                        local_name: directive_fkey.local_name.clone(),
28✔
701
                        foreign_name: directive_fkey.foreign_name.clone(),
28✔
702
                    },
28✔
703
                };
28✔
704

705
                fkeys.push(Arc::new(fk));
28✔
706
            }
707
        }
708

709
        fkeys
378✔
710
            .into_iter()
378✔
711
            .filter(|fk| self.fkey_is_selectable(fk))
378✔
712
            .collect()
378✔
713
    }
378✔
714

715
    /// Check if a type is a composite type
716
    pub fn is_composite(&self, type_oid: u32) -> bool {
2,808✔
717
        self.composites.iter().any(|x| x.oid == type_oid)
2,808✔
718
    }
2,808✔
719

720
    pub fn get_table_by_name(
28✔
721
        &self,
28✔
722
        schema_name: &String,
28✔
723
        table_name: &String,
28✔
724
    ) -> Option<&Arc<Table>> {
28✔
725
        self.tables
28✔
726
            .values()
28✔
727
            .find(|x| &x.schema == schema_name && &x.name == table_name)
62✔
728
    }
28✔
729

730
    pub fn get_table_by_oid(&self, oid: u32) -> Option<&Arc<Table>> {
800✔
731
        self.tables.get(&oid)
800✔
732
    }
800✔
733

734
    /// Check if the local side of a foreign key is comprised of unique columns
735
    pub fn fkey_is_locally_unique(&self, fkey: &ForeignKey) -> bool {
128✔
736
        let table: &Arc<Table> = match self.get_table_by_oid(fkey.local_table_meta.oid) {
128✔
737
            Some(table) => table,
128✔
738
            None => {
739
                return false;
×
740
            }
741
        };
742

743
        let fkey_columns: HashSet<&String> = fkey.local_table_meta.column_names.iter().collect();
128✔
744

745
        for index in table.indexes.iter().filter(|x| x.is_unique) {
134✔
746
            let index_column_names: HashSet<&String> = index.column_names.iter().collect();
134✔
747

748
            if index_column_names
134✔
749
                .iter()
134✔
750
                .all(|col_name| fkey_columns.contains(col_name))
135✔
751
            {
752
                return true;
10✔
753
            }
124✔
754
        }
755
        false
118✔
756
    }
128✔
757

758
    /// Are both sides of the foreign key composed of selectable columns
759
    pub fn fkey_is_selectable(&self, fkey: &ForeignKey) -> bool {
218✔
760
        let local_table: &Arc<Table> = match self.get_table_by_oid(fkey.local_table_meta.oid) {
218✔
761
            Some(table) => table,
218✔
762
            None => {
763
                return false;
×
764
            }
765
        };
766

767
        let referenced_table: &Arc<Table> =
218✔
768
            match self.get_table_by_oid(fkey.referenced_table_meta.oid) {
218✔
769
                Some(table) => table,
218✔
770
                None => {
771
                    return false;
×
772
                }
773
            };
774

775
        let fkey_local_columns = &fkey.local_table_meta.column_names;
218✔
776
        let fkey_referenced_columns = &fkey.referenced_table_meta.column_names;
218✔
777

778
        let local_columns_selectable: HashSet<&String> = local_table
218✔
779
            .columns
218✔
780
            .iter()
218✔
781
            .filter(|x| x.permissions.is_selectable)
714✔
782
            .map(|col| &col.name)
708✔
783
            .collect();
218✔
784

785
        let referenced_columns_selectable: HashSet<&String> = referenced_table
218✔
786
            .columns
218✔
787
            .iter()
218✔
788
            .filter(|x| x.permissions.is_selectable)
1,096✔
789
            .map(|col| &col.name)
1,090✔
790
            .collect();
218✔
791

792
        fkey_local_columns
218✔
793
            .iter()
218✔
794
            .all(|col| local_columns_selectable.contains(col))
218✔
795
            && fkey_referenced_columns
218✔
796
                .iter()
218✔
797
                .all(|col| referenced_columns_selectable.contains(col))
218✔
798
    }
218✔
799
}
800

801
/// This method is similar to `Spi::get_one` with the only difference
/// being that it calls `client.select` instead of `client.update`.
/// The `client.update` method generates a new transaction id so
/// calling `Spi::get_one` is not possible when postgres is in
/// recovery mode.
///
/// Runs `query` and reads a single value from its first row. Returns `Ok(None)`
/// when there is no value; SPI failures are propagated as `pgrx::spi::Error`.
pub(crate) fn get_one_readonly<A: FromDatum + IntoDatum>(
    query: &str,
) -> std::result::Result<Option<A>, pgrx::spi::Error> {
    // `Some(1)` limits the result to a single row.
    Spi::connect(|client| client.select(query, Some(1), &[])?.first().get_one())
}
947✔
811

812
pub fn load_sql_config() -> Config {
651✔
813
    let query = include_str!("../sql/load_sql_config.sql");
651✔
814
    let sql_result: serde_json::Value = get_one_readonly::<JsonB>(query)
651✔
815
        .expect("failed to read sql config")
651✔
816
        .expect("sql config is missing")
651✔
817
        .0;
651✔
818
    let config: Config =
651✔
819
        serde_json::from_value(sql_result).expect("failed to convert sql config into json");
651✔
820
    config
651✔
821
}
651✔
822

823
pub fn calculate_hash<T: Hash>(t: &T) -> u64 {
5,521✔
824
    let mut s = DefaultHasher::new();
5,521✔
825
    t.hash(&mut s);
5,521✔
826
    s.finish()
5,521✔
827
}
5,521✔
828

829
#[cached(
830
    type = "SizedCache<u64, GraphQLResult<Arc<Context>>>",
831
    create = "{ SizedCache::with_size(250) }",
832
    convert = r#"{ calculate_hash(_config) }"#
833
)]
834
pub fn load_sql_context(_config: &Config) -> GraphQLResult<Arc<Context>> {
296✔
835
    // cache value for next query
836
    let query = include_str!("../sql/load_sql_context.sql");
296✔
837
    let sql_result: serde_json::Value = get_one_readonly::<JsonB>(query)
296✔
838
        .expect("failed to read sql context")
296✔
839
        .expect("sql context is missing")
296✔
840
        .0;
296✔
841
    let context: Result<Context, serde_json::Error> = serde_json::from_value(sql_result);
296✔
842

843
    /// This pass cross-reference types with its details
844
    fn type_details(mut context: Context) -> Context {
294✔
845
        let mut array_types = HashMap::new();
294✔
846
        // We process types to cross-reference their details
847
        for (oid, type_) in context.types.iter_mut() {
183,130✔
848
            if let Some(mtype) = Arc::get_mut(type_) {
183,130✔
849
                // It should be possible to get a mutable reference to type at this point
850
                // as there are no other references to this Arc at this point.
851

852
                // Depending on type's category, locate the appropriate piece of information about it
853
                mtype.details = match mtype.category {
183,130✔
854
                    TypeCategory::Enum => context.enums.get(oid).cloned().map(TypeDetails::Enum),
27✔
855
                    TypeCategory::Composite => context
1✔
856
                        .composites
1✔
857
                        .iter()
1✔
858
                        .find(|c| c.oid == *oid)
1✔
859
                        .cloned()
1✔
860
                        .map(TypeDetails::Composite),
1✔
861
                    TypeCategory::Table => context.tables.get(oid).cloned().map(TypeDetails::Table),
62,870✔
862
                    TypeCategory::Array => {
863
                        // We can't cross-reference with `context.types` here as it is already mutably borrowed,
864
                        // so we instead memorize the reference to process later
865
                        if let Some(element_oid) = mtype.array_element_type_oid {
87,596✔
866
                            array_types.insert(*oid, element_oid);
87,596✔
867
                        }
87,596✔
868
                        None
87,596✔
869
                    }
870
                    _ => None,
32,636✔
871
                };
872
            }
×
873
        }
874

875
        // Ensure the types are ordered so that we don't run into a situation where we can't
876
        // update the type anymore as it has been referenced but the type details weren't completed yet
877
        let referenced_types = array_types.values().copied().collect::<Vec<_>>();
294✔
878
        let mut ordered_types = array_types
294✔
879
            .iter()
294✔
880
            .map(|(k, v)| (*k, *v))
87,596✔
881
            .collect::<Vec<_>>();
294✔
882
        // We sort them by their presence in referencing. If the type has been referenced,
883
        // it should be at the top.
884
        ordered_types.sort_by(|(k1, _), (k2, _)| {
199,293✔
885
            if referenced_types.contains(k1) && referenced_types.contains(k2) {
199,293✔
886
                Ordering::Equal
420✔
887
            } else if referenced_types.contains(k1) {
198,873✔
888
                Ordering::Less
42,766✔
889
            } else {
890
                Ordering::Greater
156,107✔
891
            }
892
        });
199,293✔
893

894
        // Now we're ready to process array types
895
        for (array_oid, element_oid) in ordered_types {
87,596✔
896
            // We remove the element type from the map to ensure there is no mutability conflict for when
897
            // we get a mutable reference to the array type. We will put it back after we're done with it,
898
            // a few lines below.
899
            if let Some(element_t) = context.types.remove(&element_oid) {
87,596✔
900
                if let Some(array_t) = context.types.get_mut(&array_oid) {
87,596✔
901
                    if let Some(array) = Arc::get_mut(array_t) {
87,596✔
902
                        // It should be possible to get a mutable reference to type at this point
87,596✔
903
                        // as there are no other references to this Arc at this point.
87,596✔
904
                        array.details = Some(TypeDetails::Element(element_t.clone()));
87,596✔
905
                    } else {
87,596✔
906
                        // For some reason, we weren't able to get it. It means something have changed
907
                        // in our logic and we're presenting an assertion violation. Let's report it.
908
                        // It's a bug.
909
                        pgrx::warning!(
×
910
                            "Assertion violation: array type with OID {} is already referenced",
911
                            array_oid
912
                        );
913
                        continue;
×
914
                    }
915
                    // Put the element type back. NB: Very important to keep this line! It'll be used
916
                    // further down the loop. There is a check at the end of each loop's iteration that
917
                    // we actually did this. Being defensive.
918
                    context.types.insert(element_oid, element_t);
87,596✔
919
                } else {
920
                    // We weren't able to find the OID of the array, which is odd because we just got
921
                    // it from the context. This means we messed something up and it is a bug. Report it.
922
                    pgrx::warning!(
×
923
                        "Assertion violation: array type with OID {} is not found",
924
                        array_oid
925
                    );
926
                    continue;
×
927
                }
928
            } else {
929
                // We weren't able to find the OID of the element type, which is also odd because we just got
930
                // it from the context. This means it's a bug as well. Report it.
931
                pgrx::warning!(
×
932
                    "Assertion violation: referenced element type with OID {} of array type with OID {} is not found",
933
                    element_oid,
934
                    array_oid
935
                );
UNCOV
936
                continue;
×
937
            }
938

939
            // Here we are asserting that we did in fact return the element type back to the list. Part of being
940
            // defensive here.
941
            if !context.types.contains_key(&element_oid) {
87,596✔
NEW
942
                pgrx::warning!(
×
943
                    "Assertion violation: referenced element type with OID {} was not returned to the list of types",
944
                    element_oid
945
                );
UNCOV
946
                continue;
×
947
            }
87,596✔
948
        }
949

950
        context
294✔
951
    }
294✔
952

953
    /// This pass cross-reference column types
954
    fn column_types(mut context: Context) -> Context {
294✔
955
        // We process tables to cross-reference their columns' types
956
        for (_oid, table) in context.tables.iter_mut() {
294✔
957
            if let Some(mtable) = Arc::get_mut(table) {
244✔
958
                // It should be possible to get a mutable reference to table at this point
959
                // as there are no other references to this Arc at this point.
960

961
                // We will now iterate over columns
962
                for column in mtable.columns.iter_mut() {
820✔
963
                    if let Some(mcolumn) = Arc::get_mut(column) {
820✔
964
                        // It should be possible to get a mutable reference to column at this point
820✔
965
                        // as there are no other references to this Arc at this point.
820✔
966

820✔
967
                        // Find a matching type
820✔
968
                        mcolumn.type_ = context.types.get(&mcolumn.type_oid).cloned();
820✔
969
                    }
820✔
970
                }
971
            }
×
972
        }
973
        context
294✔
974
    }
294✔
975

976
    /// This pass populates functions for tables
977
    fn populate_table_functions(mut context: Context) -> Context {
294✔
978
        let mut arg_type_to_func: HashMap<u32, Vec<&Arc<Function>>> = HashMap::new();
294✔
979
        for function in context.functions.iter().filter(|f| f.num_args == 1) {
1,446✔
980
            let functions = arg_type_to_func.entry(function.arg_types[0]).or_default();
235✔
981
            functions.push(function);
235✔
982
        }
235✔
983
        for table in &mut context.tables.values_mut() {
294✔
984
            if let Some(table) = Arc::get_mut(table)
244✔
985
                && let Some(functions) = arg_type_to_func.get(&table.reltype)
244✔
986
            {
987
                for function in functions {
36✔
988
                    table.functions.push(Arc::clone(function));
36✔
989
                }
36✔
990
            }
209✔
991
        }
992
        context
294✔
993
    }
294✔
994

995
    context
296✔
996
        .map(type_details)
296✔
997
        .map(column_types)
296✔
998
        .map(populate_table_functions)
296✔
999
        .map(Arc::new)
296✔
1000
        .map_err(|e| {
296✔
1001
            crate::error::GraphQLError::schema(format!(
2✔
1002
                "Error while loading schema, check comment directives. {}",
1003
                e
1004
            ))
1005
        })
2✔
1006
}
296✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc