• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / pg_graphql / 21189676270

20 Jan 2026 10:30PM UTC coverage: 91.364% (+0.004%) from 91.36%
21189676270

Pull #590

github

web-flow
Merge b9b381aee into 0d7a66bfd
Pull Request #590: Add support for single record queries by primary key

258 of 288 new or added lines in 5 files covered. (89.58%)

9 existing lines in 1 file now uncovered.

7797 of 8534 relevant lines covered (91.36%)

1127.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.83
/src/sql_types.rs
1
use bimap::BiBTreeMap;
2
use cached::proc_macro::cached;
3
use cached::SizedCache;
4
use lazy_static::lazy_static;
5
use pgrx::*;
6
use serde::{Deserialize, Serialize};
7

8
use crate::error::GraphQLResult;
9
use std::cmp::Ordering;
10
use std::collections::hash_map::DefaultHasher;
11
use std::collections::{HashMap, HashSet};
12
use std::hash::{Hash, Hasher};
13
use std::sync::Arc;
14
use std::*;
15

16
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
17
pub struct ColumnPermissions {
18
    pub is_insertable: bool,
19
    pub is_selectable: bool,
20
    pub is_updatable: bool,
21
    // admin interface
22
    // alterable?
23
}
24

25
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
26
pub struct ColumnDirectives {
27
    pub name: Option<String>,
28
    pub description: Option<String>,
29
}
30

31
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
32
pub struct Column {
33
    pub name: String,
34
    pub type_oid: u32,
35
    pub type_: Option<Arc<Type>>,
36
    pub type_name: String,
37
    pub max_characters: Option<i32>,
38
    pub schema_oid: u32,
39
    pub is_not_null: bool,
40
    pub is_serial: bool,
41
    pub is_generated: bool,
42
    pub has_default: bool,
43
    pub attribute_num: i32,
44
    pub permissions: ColumnPermissions,
45
    pub comment: Option<String>,
46
    pub directives: ColumnDirectives,
47
}
48

49
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
50
pub struct FunctionDirectives {
51
    pub name: Option<String>,
52
    // @graphql({"description": "the address of ..." })
53
    pub description: Option<String>,
54
}
55

56
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
57
pub struct FunctionPermissions {
58
    pub is_executable: bool,
59
}
60

61
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
62
pub enum FunctionVolatility {
63
    #[serde(rename(deserialize = "v"))]
64
    Volatile,
65
    #[serde(rename(deserialize = "s"))]
66
    Stable,
67
    #[serde(rename(deserialize = "i"))]
68
    Immutable,
69
}
70

71
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
72
pub struct Function {
73
    pub oid: u32,
74
    pub name: String,
75
    pub schema_oid: u32,
76
    pub schema_name: String,
77
    pub arg_types: Vec<u32>,
78
    pub arg_names: Option<Vec<String>>,
79
    pub arg_defaults: Option<String>,
80
    pub num_args: u32,
81
    pub num_default_args: u32,
82
    pub arg_type_names: Vec<String>,
83
    pub volatility: FunctionVolatility,
84
    pub type_oid: u32,
85
    pub type_name: String,
86
    pub is_set_of: bool,
87
    pub comment: Option<String>,
88
    pub directives: FunctionDirectives,
89
    pub permissions: FunctionPermissions,
90
}
91

92
impl Function {
93
    pub fn args(&self) -> impl Iterator<Item = (u32, &str, Option<&str>, Option<DefaultValue>)> {
5,358✔
94
        ArgsIterator::new(
5,358✔
95
            &self.arg_types,
5,358✔
96
            &self.arg_type_names,
5,358✔
97
            &self.arg_names,
5,358✔
98
            &self.arg_defaults,
5,358✔
99
            self.num_default_args,
5,358✔
100
        )
101
    }
5,358✔
102

103
    pub fn function_names_to_count(all_functions: &[Arc<Function>]) -> HashMap<&String, u32> {
547✔
104
        let mut function_name_to_count = HashMap::new();
547✔
105
        for function_name in all_functions.iter().map(|f| &f.name) {
2,233✔
106
            let entry = function_name_to_count.entry(function_name).or_insert(0u32);
2,233✔
107
            *entry += 1;
2,233✔
108
        }
2,233✔
109
        function_name_to_count
547✔
110
    }
547✔
111

112
    pub fn is_supported(
2,233✔
113
        &self,
2,233✔
114
        context: &Context,
2,233✔
115
        function_name_to_count: &HashMap<&String, u32>,
2,233✔
116
    ) -> bool {
2,233✔
117
        let types = &context.types;
2,233✔
118
        self.return_type_is_supported(types)
2,233✔
119
            && self.arg_types_are_supported(types)
2,197✔
120
            && !self.is_function_overloaded(function_name_to_count)
2,086✔
121
            && !self.has_a_nameless_arg()
1,906✔
122
            && self.permissions.is_executable
1,694✔
123
            && !self.is_in_a_system_schema()
1,691✔
124
    }
2,233✔
125

126
    fn arg_types_are_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
2,197✔
127
        self.args().all(|(arg_type, _, _, _)| {
3,407✔
128
            if let Some(arg_type) = types.get(&arg_type) {
3,407✔
129
                let array_element_type_is_supported = self.array_element_type_is_supported(
3,407✔
130
                    &arg_type.category,
3,407✔
131
                    arg_type.array_element_type_oid,
3,407✔
132
                    types,
3,407✔
133
                );
134
                arg_type.category == TypeCategory::Other
3,407✔
135
                    || (arg_type.category == TypeCategory::Array && array_element_type_is_supported)
790✔
136
            } else {
137
                false
×
138
            }
139
        })
3,407✔
140
    }
2,197✔
141

142
    fn return_type_is_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
2,233✔
143
        if let Some(return_type) = types.get(&self.type_oid) {
2,233✔
144
            let array_element_type_is_supported = self.array_element_type_is_supported(
2,233✔
145
                &return_type.category,
2,233✔
146
                return_type.array_element_type_oid,
2,233✔
147
                types,
2,233✔
148
            );
149
            return_type.category != TypeCategory::Pseudo
2,233✔
150
                && return_type.category != TypeCategory::Enum
2,233✔
151
                && return_type.name != "record"
2,228✔
152
                && return_type.name != "trigger"
2,213✔
153
                && return_type.name != "event_trigger"
2,209✔
154
                && array_element_type_is_supported
2,202✔
155
        } else {
156
            false
×
157
        }
158
    }
2,233✔
159

160
    fn array_element_type_is_supported(
5,640✔
161
        &self,
5,640✔
162
        type_category: &TypeCategory,
5,640✔
163
        array_element_type_oid: Option<u32>,
5,640✔
164
        types: &HashMap<u32, Arc<Type>>,
5,640✔
165
    ) -> bool {
5,640✔
166
        if *type_category == TypeCategory::Array {
5,640✔
167
            if let Some(array_element_type_oid) = array_element_type_oid {
1,002✔
168
                if let Some(array_element_type) = types.get(&array_element_type_oid) {
1,002✔
169
                    array_element_type.category == TypeCategory::Other
1,002✔
170
                } else {
171
                    false
×
172
                }
173
            } else {
174
                false
×
175
            }
176
        } else {
177
            true
4,638✔
178
        }
179
    }
5,640✔
180

181
    fn is_function_overloaded(&self, function_name_to_count: &HashMap<&String, u32>) -> bool {
2,086✔
182
        if let Some(&count) = function_name_to_count.get(&self.name) {
2,086✔
183
            count > 1
2,086✔
184
        } else {
185
            false
×
186
        }
187
    }
2,086✔
188

189
    fn has_a_nameless_arg(&self) -> bool {
1,906✔
190
        self.args().any(|(_, _, arg_name, _)| arg_name.is_none())
2,798✔
191
    }
1,906✔
192

193
    fn is_in_a_system_schema(&self) -> bool {
1,691✔
194
        // These are default schemas in supabase configuration
195
        let system_schemas = &["graphql", "graphql_public", "auth", "extensions"];
1,691✔
196
        system_schemas.contains(&self.schema_name.as_str())
1,691✔
197
    }
1,691✔
198
}
199

200
struct ArgsIterator<'a> {
201
    index: usize,
202
    arg_types: &'a [u32],
203
    arg_type_names: &'a Vec<String>,
204
    arg_names: &'a Option<Vec<String>>,
205
    arg_defaults: Vec<Option<DefaultValue>>,
206
}
207

208
/// Default value of a function argument, as presented to GraphQL.
#[derive(Clone)]
pub enum DefaultValue {
    /// A rendered literal such as `5`, `true` or `"abc"`.
    NonNull(String),
    /// A null default; also used when the SQL default could not be rendered.
    Null,
}
213

214
impl<'a> ArgsIterator<'a> {
215
    fn new(
5,358✔
216
        arg_types: &'a [u32],
5,358✔
217
        arg_type_names: &'a Vec<String>,
5,358✔
218
        arg_names: &'a Option<Vec<String>>,
5,358✔
219
        arg_defaults: &'a Option<String>,
5,358✔
220
        num_default_args: u32,
5,358✔
221
    ) -> ArgsIterator<'a> {
5,358✔
222
        ArgsIterator {
5,358✔
223
            index: 0,
5,358✔
224
            arg_types,
5,358✔
225
            arg_type_names,
5,358✔
226
            arg_names,
5,358✔
227
            arg_defaults: Self::defaults(
5,358✔
228
                arg_types,
5,358✔
229
                arg_defaults,
5,358✔
230
                num_default_args as usize,
5,358✔
231
                arg_types.len(),
5,358✔
232
            ),
5,358✔
233
        }
5,358✔
234
    }
5,358✔
235

236
    fn defaults(
5,358✔
237
        arg_types: &'a [u32],
5,358✔
238
        arg_defaults: &'a Option<String>,
5,358✔
239
        num_default_args: usize,
5,358✔
240
        num_total_args: usize,
5,358✔
241
    ) -> Vec<Option<DefaultValue>> {
5,358✔
242
        let mut defaults = vec![None; num_total_args];
5,358✔
243
        let Some(arg_defaults) = arg_defaults else {
5,358✔
244
            return defaults;
5,233✔
245
        };
246

247
        if num_default_args == 0 {
125✔
248
            return defaults;
×
249
        }
125✔
250

251
        let default_strs: Vec<&str> = arg_defaults.split(',').collect();
125✔
252

253
        if default_strs.len() != num_default_args {
125✔
254
            return defaults;
13✔
255
        }
112✔
256

257
        debug_assert!(num_default_args <= num_total_args);
112✔
258
        let start_idx = num_total_args - num_default_args;
112✔
259
        for i in start_idx..num_total_args {
471✔
260
            defaults[i] = Self::sql_to_graphql_default(default_strs[i - start_idx], arg_types[i])
471✔
261
        }
262

263
        defaults
112✔
264
    }
5,358✔
265

266
    fn sql_to_graphql_default(default_str: &str, type_oid: u32) -> Option<DefaultValue> {
471✔
267
        let trimmed = default_str.trim();
471✔
268

269
        if trimmed.starts_with("NULL::") {
471✔
270
            return Some(DefaultValue::Null);
318✔
271
        }
153✔
272

273
        let res = match type_oid {
153✔
274
            21 | 23 => trimmed
72✔
275
                .parse::<i32>()
72✔
276
                .ok()
72✔
277
                .map(|i| DefaultValue::NonNull(i.to_string())),
72✔
278
            16 => trimmed
11✔
279
                .parse::<bool>()
11✔
280
                .ok()
11✔
281
                .map(|i| DefaultValue::NonNull(i.to_string())),
11✔
282
            700 | 701 => trimmed
22✔
283
                .parse::<f64>()
22✔
284
                .ok()
22✔
285
                .map(|i| DefaultValue::NonNull(i.to_string())),
22✔
286
            25 => trimmed.strip_suffix("::text").map(|i| {
26✔
287
                DefaultValue::NonNull(format!("\"{}\"", i.trim_matches(',').trim_matches('\'')))
16✔
288
            }),
16✔
289
            _ => None,
22✔
290
        };
291

292
        // return the non-parsed value as default if for whatever reason the default value can't
293
        // be parsed into a value of the required type. This fixes problems where the default
294
        // is a complex expression like a function call etc. See test/sql/issue_533.sql for
295
        // a test case for this scenario.
296
        if res.is_some() {
153✔
297
            res
101✔
298
        } else {
299
            Some(DefaultValue::Null)
52✔
300
        }
301
    }
471✔
302
}
303

304
// Replacement type name reported by `ArgsIterator` for `character` arguments.
lazy_static! {
    static ref TEXT_TYPE: String = String::from("text");
}
307

308
impl<'a> Iterator for ArgsIterator<'a> {
309
    type Item = (u32, &'a str, Option<&'a str>, Option<DefaultValue>);
310

311
    fn next(&mut self) -> Option<Self::Item> {
13,219✔
312
        if self.index < self.arg_types.len() {
13,219✔
313
            debug_assert!(self.arg_types.len() == self.arg_type_names.len());
8,184✔
314
            let arg_name = if let Some(arg_names) = self.arg_names {
8,184✔
315
                debug_assert!(arg_names.len() >= self.arg_types.len());
7,423✔
316
                let arg_name = arg_names[self.index].as_str();
7,423✔
317
                if !arg_name.is_empty() {
7,423✔
318
                    Some(arg_name)
7,423✔
319
                } else {
320
                    None
×
321
                }
322
            } else {
323
                None
761✔
324
            };
325
            let arg_type = self.arg_types[self.index];
8,184✔
326
            let mut arg_type_name = &self.arg_type_names[self.index];
8,184✔
327
            if arg_type_name == "character" {
8,184✔
328
                arg_type_name = &TEXT_TYPE;
86✔
329
            }
8,098✔
330
            let arg_default = self.arg_defaults[self.index].clone();
8,184✔
331
            self.index += 1;
8,184✔
332
            Some((arg_type, arg_type_name, arg_name, arg_default))
8,184✔
333
        } else {
334
            None
5,035✔
335
        }
336
    }
13,219✔
337
}
338

339
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
340
pub struct TablePermissions {
341
    pub is_insertable: bool,
342
    pub is_selectable: bool,
343
    pub is_updatable: bool,
344
    pub is_deletable: bool,
345
}
346

347
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
348
pub struct TypePermissions {
349
    pub is_usable: bool,
350
}
351

352
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
353
pub enum TypeCategory {
354
    Enum,
355
    Composite,
356
    Table,
357
    Array,
358
    Pseudo,
359
    Other,
360
}
361

362
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
363
pub struct Type {
364
    pub oid: u32,
365
    pub schema_oid: u32,
366
    pub name: String,
367
    pub category: TypeCategory,
368
    pub array_element_type_oid: Option<u32>,
369
    pub table_oid: Option<u32>,
370
    pub comment: Option<String>,
371
    pub permissions: TypePermissions,
372
    pub details: Option<TypeDetails>,
373
}
374

375
// `TypeDetails` derives `Deserialized` but is not expected to come
376
// from the SQL context. Instead, it is populated in a separate pass.
377
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
378
pub enum TypeDetails {
379
    Enum(Arc<Enum>),
380
    Composite(Arc<Composite>),
381
    Table(Arc<Table>),
382
    Element(Arc<Type>),
383
}
384

385
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
386
pub struct EnumValue {
387
    pub oid: u32,
388
    pub name: String,
389
    pub sort_order: i32,
390
}
391

392
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
393
pub struct Enum {
394
    pub oid: u32,
395
    pub schema_oid: u32,
396
    pub name: String,
397
    pub values: Vec<EnumValue>,
398
    pub array_element_type_oid: Option<u32>,
399
    pub comment: Option<String>,
400
    pub permissions: TypePermissions,
401
    pub directives: EnumDirectives,
402
}
403

404
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
405
pub struct EnumDirectives {
406
    pub name: Option<String>,
407
    pub mappings: Option<BiBTreeMap<String, String>>,
408
}
409

410
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
411
pub struct Composite {
412
    pub oid: u32,
413
    pub schema_oid: u32,
414
}
415

416
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
417
pub struct Index {
418
    pub table_oid: u32,
419
    pub column_names: Vec<String>,
420
    pub is_unique: bool,
421
    pub is_primary_key: bool,
422
}
423

424
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
425
pub struct ForeignKeyTableInfo {
426
    pub oid: u32,
427
    // The table's actual name
428
    pub name: String,
429
    pub schema: String,
430
    pub is_rls_enabled: bool,
431
    pub column_names: Vec<String>,
432
}
433

434
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
435
pub struct ForeignKeyDirectives {
436
    pub local_name: Option<String>,
437
    pub foreign_name: Option<String>,
438
}
439

440
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
441
pub struct ForeignKey {
442
    pub directives: ForeignKeyDirectives,
443
    pub local_table_meta: ForeignKeyTableInfo,
444
    pub referenced_table_meta: ForeignKeyTableInfo,
445
}
446

447
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
448
pub struct TableDirectiveTotalCount {
449
    pub enabled: bool,
450
}
451

452
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
453
pub struct TableDirectiveAggregate {
454
    pub enabled: bool,
455
}
456

457
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
458
pub struct TableDirectiveForeignKey {
459
    // Equivalent to ForeignKeyDirectives.local_name
460
    pub local_name: Option<String>,
461
    pub local_columns: Vec<String>,
462

463
    // Equivalent to ForeignKeyDirectives.foreign_name
464
    pub foreign_name: Option<String>,
465
    pub foreign_schema: String,
466
    pub foreign_table: String,
467
    pub foreign_columns: Vec<String>,
468
}
469

470
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
471
pub struct TableDirectives {
472
    // @graphql({"name": "Foo" })
473
    pub name: Option<String>,
474

475
    // @graphql({"description": "the address of ..." })
476
    pub description: Option<String>,
477

478
    // @graphql({"totalCount": { "enabled": true } })
479
    pub total_count: Option<TableDirectiveTotalCount>,
480

481
    // @graphql({"aggregate": { "enabled": true } })
482
    pub aggregate: Option<TableDirectiveAggregate>,
483

484
    // @graphql({"primary_key_columns": ["id"]})
485
    pub primary_key_columns: Option<Vec<String>>,
486

487
    // @graphql({"max_rows": 20})
488
    pub max_rows: Option<u64>,
489

490
    /*
491
    @graphql(
492
      {
493
        "foreign_keys": [
494
          {
495
            <REQUIRED>
496
            "local_columns": ["account_id"],
497
            "foriegn_schema": "public",
498
            "foriegn_table": "account",
499
            "foriegn_columns": ["id"],
500

501
            <OPTIONAL>
502
            "local_name": "foo",
503
            "foreign_name": "bar",
504
          },
505
        ]
506
      }
507
    )
508
    */
509
    pub foreign_keys: Option<Vec<TableDirectiveForeignKey>>,
510
}
511

512
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
513
pub struct Table {
514
    pub oid: u32,
515
    pub name: String,
516
    pub schema_oid: u32,
517
    pub schema: String,
518
    pub columns: Vec<Arc<Column>>,
519
    pub comment: Option<String>,
520
    pub is_rls_enabled: bool,
521
    pub relkind: String, // r = table, v = view, m = mat view, f = foreign table
522
    pub reltype: u32,
523
    pub permissions: TablePermissions,
524
    pub indexes: Vec<Index>,
525
    #[serde(default)]
526
    pub functions: Vec<Arc<Function>>,
527
    pub directives: TableDirectives,
528
}
529

530
impl Table {
531
    pub fn primary_key(&self) -> Option<Index> {
4,537✔
532
        let real_pkey = self.indexes.iter().find(|x| x.is_primary_key);
4,537✔
533

534
        if real_pkey.is_some() {
4,537✔
535
            return real_pkey.cloned();
4,277✔
536
        }
260✔
537

538
        // Check for a primary key definition in comment directives
539
        if let Some(column_names) = &self.directives.primary_key_columns {
260✔
540
            // validate that columns exist on the table
541
            let mut valid_column_names: Vec<&String> = vec![];
239✔
542
            for column_name in column_names {
478✔
543
                for column in &self.columns {
713✔
544
                    if column_name == &column.name {
474✔
545
                        valid_column_names.push(&column.name);
239✔
546
                    }
239✔
547
                }
548
            }
549
            if valid_column_names.len() != column_names.len() {
239✔
550
                // At least one of the column names didn't exist on the table
551
                // so the primary key directive is not valid
552
                // Ideally we'd throw an error here instead
553
                None
×
554
            } else {
555
                Some(Index {
239✔
556
                    table_oid: self.oid,
239✔
557
                    column_names: column_names.clone(),
239✔
558
                    is_unique: true,
239✔
559
                    is_primary_key: true,
239✔
560
                })
239✔
561
            }
562
        } else {
563
            None
21✔
564
        }
565
    }
4,537✔
566

567
    pub fn primary_key_columns(&self) -> Vec<&Arc<Column>> {
1,526✔
568
        self.primary_key()
1,526✔
569
            .map(|x| x.column_names)
1,526✔
570
            .unwrap_or_default()
1,526✔
571
            .iter()
1,526✔
572
            .map(|col_name| {
1,544✔
573
                self.columns
1,544✔
574
                    .iter()
1,544✔
575
                    .find(|col| &col.name == col_name)
1,562✔
576
                    .expect("Failed to unwrap pkey by column names")
1,544✔
577
            })
1,544✔
578
            .collect::<Vec<&Arc<Column>>>()
1,526✔
579
    }
1,526✔
580

581
    pub fn has_supported_pk_types_for_by_pk(&self) -> bool {
261✔
582
        let pk_columns = self.primary_key_columns();
261✔
583
        if pk_columns.is_empty() {
261✔
NEW
584
            return false;
×
585
        }
261✔
586

587
        // Check that all primary key columns have supported types
588
        pk_columns.iter().all(|col| {
267✔
589
            SupportedPrimaryKeyType::from_type_name(&col.type_name).is_some()
267✔
590
        })
267✔
591
    }
261✔
592

593
    pub fn is_any_column_selectable(&self) -> bool {
1,899✔
594
        self.columns.iter().any(|x| x.permissions.is_selectable)
1,919✔
595
    }
1,899✔
596

597
    pub fn is_any_column_insertable(&self) -> bool {
406✔
598
        self.columns.iter().any(|x| x.permissions.is_insertable)
415✔
599
    }
406✔
600

601
    pub fn is_any_column_updatable(&self) -> bool {
406✔
602
        self.columns.iter().any(|x| x.permissions.is_updatable)
417✔
603
    }
406✔
604

605
    /// Get the effective max_rows value for this table.
606
    /// If table-specific max_rows is set, use that.
607
    /// Otherwise, fall back to schema-level max_rows.
608
    /// If neither is set, use the global default(set in load_sql_context.sql)
609
    pub fn max_rows(&self, schema: &Schema) -> u64 {
309✔
610
        self.directives
309✔
611
            .max_rows
309✔
612
            .unwrap_or(schema.directives.max_rows)
309✔
613
    }
309✔
614
}
615

616
/// Column types accepted as primary keys for single-record lookups.
///
/// The accepted type-name spellings for each variant are listed in
/// `SupportedPrimaryKeyType::from_type_name`.
#[derive(PartialEq, Debug)]
pub enum SupportedPrimaryKeyType {
    // Integer types
    Int,
    BigInt,
    SmallInt,
    // String types
    Text,
    VarChar,
    Char,
    CiText,
    // UUID
    UUID,
}
630

631
impl SupportedPrimaryKeyType {
632
    fn from_type_name(type_name: &str) -> Option<Self> {
267✔
633
        match type_name {
267✔
634
            // Integer types
635
            "int" | "int4" | "integer" => Some(Self::Int),
267✔
636
            "bigint" | "int8" => Some(Self::BigInt),
19✔
637
            "smallint" | "int2" => Some(Self::SmallInt),
17✔
638
            // String types
639
            "text" => Some(Self::Text),
17✔
640
            "varchar" => Some(Self::VarChar),
12✔
641
            "char" | "bpchar" => Some(Self::Char),
12✔
642
            "citext" => Some(Self::CiText),
12✔
643
            // UUID
644
            "uuid" => Some(Self::UUID),
12✔
645
            // Any other type is not supported
NEW
646
            _ => None,
×
647
        }
648
    }
267✔
649
}
650

651
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
652
pub struct SchemaDirectives {
653
    // @graphql({"inflect_names": true})
654
    pub inflect_names: bool,
655
    // @graphql({"max_rows": 20})
656
    pub max_rows: u64,
657
}
658

659
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
660
pub struct Schema {
661
    pub oid: u32,
662
    pub name: String,
663
    pub comment: Option<String>,
664
    pub directives: SchemaDirectives,
665
}
666

667
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
668
pub struct Config {
669
    pub search_path: Vec<String>,
670
    pub role: String,
671
    pub schema_version: i32,
672
}
673

674
#[derive(Deserialize, Debug, Eq, PartialEq)]
675
pub struct Context {
676
    pub config: Config,
677
    pub schemas: HashMap<u32, Schema>,
678
    pub tables: HashMap<u32, Arc<Table>>,
679
    foreign_keys: Vec<Arc<ForeignKey>>,
680
    pub types: HashMap<u32, Arc<Type>>,
681
    pub enums: HashMap<u32, Arc<Enum>>,
682
    pub composites: Vec<Arc<Composite>>,
683
    pub functions: Vec<Arc<Function>>,
684
}
685

686
impl Hash for Context {
687
    fn hash<H: Hasher>(&self, state: &mut H) {
2,440✔
688
        // Only the config is needed to has ha Context
689
        self.config.hash(state);
2,440✔
690
    }
2,440✔
691
}
692

693
impl Context {
694
    /// Collect all foreign keys referencing (inbound or outbound) a table
695
    pub fn foreign_keys(&self) -> Vec<Arc<ForeignKey>> {
402✔
696
        let mut fkeys: Vec<Arc<ForeignKey>> = self.foreign_keys.clone();
402✔
697

698
        // Add foreign keys defined in comment directives
699
        for table in self.tables.values() {
768✔
700
            let directive_fkeys: Vec<TableDirectiveForeignKey> =
768✔
701
                match &table.directives.foreign_keys {
768✔
702
                    Some(keys) => keys.clone(),
28✔
703
                    None => vec![],
740✔
704
                };
705

706
            for directive_fkey in directive_fkeys.iter() {
768✔
707
                let referenced_t = match self.get_table_by_name(
28✔
708
                    &directive_fkey.foreign_schema,
28✔
709
                    &directive_fkey.foreign_table,
28✔
710
                ) {
28✔
711
                    Some(t) => t,
28✔
712
                    None => {
713
                        // No table found with requested name. Skip.
714
                        continue;
×
715
                    }
716
                };
717

718
                let referenced_t_column_names: HashSet<&String> =
28✔
719
                    referenced_t.columns.iter().map(|x| &x.name).collect();
80✔
720

721
                // Verify all foreign column references are valid
722
                if !directive_fkey
28✔
723
                    .foreign_columns
28✔
724
                    .iter()
28✔
725
                    .all(|col| referenced_t_column_names.contains(col))
28✔
726
                {
727
                    // Skip if invalid references exist
728
                    continue;
×
729
                }
28✔
730

731
                let fk = ForeignKey {
28✔
732
                    local_table_meta: ForeignKeyTableInfo {
28✔
733
                        oid: table.oid,
28✔
734
                        name: table.name.clone(),
28✔
735
                        schema: table.schema.clone(),
28✔
736
                        is_rls_enabled: table.is_rls_enabled,
28✔
737
                        column_names: directive_fkey.local_columns.clone(),
28✔
738
                    },
28✔
739
                    referenced_table_meta: ForeignKeyTableInfo {
28✔
740
                        oid: referenced_t.oid,
28✔
741
                        name: referenced_t.name.clone(),
28✔
742
                        schema: referenced_t.schema.clone(),
28✔
743
                        is_rls_enabled: table.is_rls_enabled,
28✔
744
                        column_names: directive_fkey.foreign_columns.clone(),
28✔
745
                    },
28✔
746
                    directives: ForeignKeyDirectives {
28✔
747
                        local_name: directive_fkey.local_name.clone(),
28✔
748
                        foreign_name: directive_fkey.foreign_name.clone(),
28✔
749
                    },
28✔
750
                };
28✔
751

752
                fkeys.push(Arc::new(fk));
28✔
753
            }
754
        }
755

756
        fkeys
402✔
757
            .into_iter()
402✔
758
            .filter(|fk| self.fkey_is_selectable(fk))
402✔
759
            .collect()
402✔
760
    }
402✔
761

762
    /// Check if a type is a composite type
763
    pub fn is_composite(&self, type_oid: u32) -> bool {
2,860✔
764
        self.composites.iter().any(|x| x.oid == type_oid)
2,860✔
765
    }
2,860✔
766

767
    pub fn get_table_by_name(
28✔
768
        &self,
28✔
769
        schema_name: &String,
28✔
770
        table_name: &String,
28✔
771
    ) -> Option<&Arc<Table>> {
28✔
772
        self.tables
28✔
773
            .values()
28✔
774
            .find(|x| &x.schema == schema_name && &x.name == table_name)
48✔
775
    }
28✔
776

777
    pub fn get_table_by_oid(&self, oid: u32) -> Option<&Arc<Table>> {
866✔
778
        self.tables.get(&oid)
866✔
779
    }
866✔
780

781
    /// Check if the local side of a foreign key is comprised of unique columns
782
    pub fn fkey_is_locally_unique(&self, fkey: &ForeignKey) -> bool {
140✔
783
        let table: &Arc<Table> = match self.get_table_by_oid(fkey.local_table_meta.oid) {
140✔
784
            Some(table) => table,
140✔
785
            None => {
786
                return false;
×
787
            }
788
        };
789

790
        let fkey_columns: HashSet<&String> = fkey.local_table_meta.column_names.iter().collect();
140✔
791

792
        for index in table.indexes.iter().filter(|x| x.is_unique) {
146✔
793
            let index_column_names: HashSet<&String> = index.column_names.iter().collect();
146✔
794

795
            if index_column_names
146✔
796
                .iter()
146✔
797
                .all(|col_name| fkey_columns.contains(col_name))
147✔
798
            {
799
                return true;
10✔
800
            }
136✔
801
        }
802
        false
130✔
803
    }
140✔
804

805
    /// Are both sides of the foreign key composed of selectable columns
806
    pub fn fkey_is_selectable(&self, fkey: &ForeignKey) -> bool {
236✔
807
        let local_table: &Arc<Table> = match self.get_table_by_oid(fkey.local_table_meta.oid) {
236✔
808
            Some(table) => table,
236✔
809
            None => {
810
                return false;
×
811
            }
812
        };
813

814
        let referenced_table: &Arc<Table> =
236✔
815
            match self.get_table_by_oid(fkey.referenced_table_meta.oid) {
236✔
816
                Some(table) => table,
236✔
817
                None => {
818
                    return false;
×
819
                }
820
            };
821

822
        let fkey_local_columns = &fkey.local_table_meta.column_names;
236✔
823
        let fkey_referenced_columns = &fkey.referenced_table_meta.column_names;
236✔
824

825
        let local_columns_selectable: HashSet<&String> = local_table
236✔
826
            .columns
236✔
827
            .iter()
236✔
828
            .filter(|x| x.permissions.is_selectable)
768✔
829
            .map(|col| &col.name)
762✔
830
            .collect();
236✔
831

832
        let referenced_columns_selectable: HashSet<&String> = referenced_table
236✔
833
            .columns
236✔
834
            .iter()
236✔
835
            .filter(|x| x.permissions.is_selectable)
1,132✔
836
            .map(|col| &col.name)
1,126✔
837
            .collect();
236✔
838

839
        fkey_local_columns
236✔
840
            .iter()
236✔
841
            .all(|col| local_columns_selectable.contains(col))
236✔
842
            && fkey_referenced_columns
236✔
843
                .iter()
236✔
844
                .all(|col| referenced_columns_selectable.contains(col))
236✔
845
    }
236✔
846
}
847

848
/// This method is similar to `Spi::get_one` with the only difference
849
/// being that it calls `client.select` instead of `client.update`.
850
/// The `client.update` method generates a new transaction id so
851
/// calling `Spi::get_one` is not possible when postgres is in
852
/// recovery mode.
853
pub(crate) fn get_one_readonly<A: FromDatum + IntoDatum>(
982✔
854
    query: &str,
982✔
855
) -> std::result::Result<Option<A>, pgrx::spi::Error> {
982✔
856
    Spi::connect(|client| client.select(query, Some(1), &[])?.first().get_one())
982✔
857
}
982✔
858

859
pub fn load_sql_config() -> Config {
681✔
860
    let query = include_str!("../sql/load_sql_config.sql");
681✔
861
    let sql_result: serde_json::Value = get_one_readonly::<JsonB>(query)
681✔
862
        .expect("failed to read sql config")
681✔
863
        .expect("sql config is missing")
681✔
864
        .0;
681✔
865
    let config: Config =
681✔
866
        serde_json::from_value(sql_result).expect("failed to convert sql config into json");
681✔
867
    config
681✔
868
}
681✔
869

870
/// Hashes `t` with [`DefaultHasher`] and returns the resulting 64-bit digest.
///
/// A freshly constructed `DefaultHasher` always starts from the same state, so
/// equal values produce equal digests within a single build. The standard
/// library does not promise the algorithm stays the same across Rust releases.
pub fn calculate_hash<T: Hash>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::default();
    t.hash(&mut hasher);
    hasher.finish()
}
875

876
#[cached(
877
    type = "SizedCache<u64, GraphQLResult<Arc<Context>>>",
878
    create = "{ SizedCache::with_size(250) }",
879
    convert = r#"{ calculate_hash(_config) }"#
880
)]
881
pub fn load_sql_context(_config: &Config) -> GraphQLResult<Arc<Context>> {
301✔
882
    // cache value for next query
883
    let query = include_str!("../sql/load_sql_context.sql");
301✔
884
    let sql_result: serde_json::Value = get_one_readonly::<JsonB>(query)
301✔
885
        .expect("failed to read sql context")
301✔
886
        .expect("sql context is missing")
301✔
887
        .0;
301✔
888
    let context: Result<Context, serde_json::Error> = serde_json::from_value(sql_result);
301✔
889

890
    /// This pass cross-reference types with its details
891
    fn type_details(mut context: Context) -> Context {
299✔
892
        let mut array_types = HashMap::new();
299✔
893
        // We process types to cross-reference their details
894
        for (oid, type_) in context.types.iter_mut() {
186,281✔
895
            if let Some(mtype) = Arc::get_mut(type_) {
186,281✔
896
                // It should be possible to get a mutable reference to type at this point
897
                // as there are no other references to this Arc at this point.
898

899
                // Depending on type's category, locate the appropriate piece of information about it
900
                mtype.details = match mtype.category {
186,281✔
901
                    TypeCategory::Enum => context.enums.get(oid).cloned().map(TypeDetails::Enum),
27✔
902
                    TypeCategory::Composite => context
1✔
903
                        .composites
1✔
904
                        .iter()
1✔
905
                        .find(|c| c.oid == *oid)
1✔
906
                        .cloned()
1✔
907
                        .map(TypeDetails::Composite),
1✔
908
                    TypeCategory::Table => context.tables.get(oid).cloned().map(TypeDetails::Table),
63,958✔
909
                    TypeCategory::Array => {
910
                        // We can't cross-reference with `context.types` here as it is already mutably borrowed,
911
                        // so we instead memorize the reference to process later
912
                        if let Some(element_oid) = mtype.array_element_type_oid {
89,104✔
913
                            array_types.insert(*oid, element_oid);
89,104✔
914
                        }
89,104✔
915
                        None
89,104✔
916
                    }
917
                    _ => None,
33,191✔
918
                };
919
            }
×
920
        }
921

922
        // Ensure the types are ordered so that we don't run into a situation where we can't
923
        // update the type anymore as it has been referenced but the type details weren't completed yet
924
        let referenced_types = array_types.values().copied().collect::<Vec<_>>();
299✔
925
        let mut ordered_types = array_types
299✔
926
            .iter()
299✔
927
            .map(|(k, v)| (*k, *v))
89,104✔
928
            .collect::<Vec<_>>();
299✔
929
        // We sort them by their presence in referencing. If the type has been referenced,
930
        // it should be at the top.
931
        ordered_types.sort_by(|(k1, _), (k2, _)| {
202,398✔
932
            if referenced_types.contains(k1) && referenced_types.contains(k2) {
202,398✔
933
                Ordering::Equal
429✔
934
            } else if referenced_types.contains(k1) {
201,969✔
935
                Ordering::Less
44,557✔
936
            } else {
937
                Ordering::Greater
157,412✔
938
            }
939
        });
202,398✔
940

941
        // Now we're ready to process array types
942
        for (array_oid, element_oid) in ordered_types {
89,403✔
943
            // We remove the element type from the map to ensure there is no mutability conflict for when
944
            // we get a mutable reference to the array type. We will put it back after we're done with it,
945
            // a few lines below.
946
            if let Some(element_t) = context.types.remove(&element_oid) {
89,104✔
947
                if let Some(array_t) = context.types.get_mut(&array_oid) {
89,104✔
948
                    if let Some(array) = Arc::get_mut(array_t) {
89,104✔
949
                        // It should be possible to get a mutable reference to type at this point
89,104✔
950
                        // as there are no other references to this Arc at this point.
89,104✔
951
                        array.details = Some(TypeDetails::Element(element_t.clone()));
89,104✔
952
                    } else {
89,104✔
953
                        // For some reason, we weren't able to get it. It means something have changed
954
                        // in our logic and we're presenting an assertion violation. Let's report it.
955
                        // It's a bug.
956
                        pgrx::warning!(
×
957
                            "Assertion violation: array type with OID {} is already referenced",
×
958
                            array_oid
959
                        );
960
                        continue;
×
961
                    }
962
                    // Put the element type back. NB: Very important to keep this line! It'll be used
963
                    // further down the loop. There is a check at the end of each loop's iteration that
964
                    // we actually did this. Being defensive.
965
                    context.types.insert(element_oid, element_t);
89,104✔
966
                } else {
967
                    // We weren't able to find the OID of the array, which is odd because we just got
968
                    // it from the context. This means we messed something up and it is a bug. Report it.
969
                    pgrx::warning!(
×
970
                        "Assertion violation: array type with OID {} is not found",
×
971
                        array_oid
972
                    );
973
                    continue;
×
974
                }
975
            } else {
976
                // We weren't able to find the OID of the element type, which is also odd because we just got
977
                // it from the context. This means it's a bug as well. Report it.
978
                pgrx::warning!(
×
979
                        "Assertion violation: referenced element type with OID {} of array type with OID {} is not found",
×
980
                        element_oid, array_oid);
981
                continue;
×
982
            }
983

984
            // Here we are asserting that we did in fact return the element type back to the list. Part of being
985
            // defensive here.
986
            if !context.types.contains_key(&element_oid) {
89,104✔
987
                pgrx::warning!("Assertion violation: referenced element type with OID {} was not returned to the list of types", element_oid );
×
988
                continue;
×
989
            }
89,104✔
990
        }
991

992
        context
299✔
993
    }
299✔
994

995
    /// This pass cross-reference column types
996
    fn column_types(mut context: Context) -> Context {
299✔
997
        // We process tables to cross-reference their columns' types
998
        for (_oid, table) in context.tables.iter_mut() {
299✔
999
            if let Some(mtable) = Arc::get_mut(table) {
267✔
1000
                // It should be possible to get a mutable reference to table at this point
1001
                // as there are no other references to this Arc at this point.
1002

1003
                // We will now iterate over columns
1004
                for column in mtable.columns.iter_mut() {
890✔
1005
                    if let Some(mcolumn) = Arc::get_mut(column) {
890✔
1006
                        // It should be possible to get a mutable reference to column at this point
890✔
1007
                        // as there are no other references to this Arc at this point.
890✔
1008

890✔
1009
                        // Find a matching type
890✔
1010
                        mcolumn.type_ = context.types.get(&mcolumn.type_oid).cloned();
890✔
1011
                    }
890✔
1012
                }
1013
            }
×
1014
        }
1015
        context
299✔
1016
    }
299✔
1017

1018
    /// This pass populates functions for tables
1019
    fn populate_table_functions(mut context: Context) -> Context {
299✔
1020
        let mut arg_type_to_func: HashMap<u32, Vec<&Arc<Function>>> = HashMap::new();
299✔
1021
        for function in context.functions.iter().filter(|f| f.num_args == 1) {
1,460✔
1022
            let functions = arg_type_to_func.entry(function.arg_types[0]).or_default();
249✔
1023
            functions.push(function);
249✔
1024
        }
249✔
1025
        for table in &mut context.tables.values_mut() {
299✔
1026
            if let Some(table) = Arc::get_mut(table) {
267✔
1027
                if let Some(functions) = arg_type_to_func.get(&table.reltype) {
267✔
1028
                    for function in functions {
89✔
1029
                        table.functions.push(Arc::clone(function));
50✔
1030
                    }
50✔
1031
                }
228✔
1032
            }
×
1033
        }
1034
        context
299✔
1035
    }
299✔
1036

1037
    context
301✔
1038
        .map(type_details)
301✔
1039
        .map(column_types)
301✔
1040
        .map(populate_table_functions)
301✔
1041
        .map(Arc::new)
301✔
1042
        .map_err(|e| {
301✔
1043
            crate::error::GraphQLError::schema(format!(
2✔
1044
                "Error while loading schema, check comment directives. {}",
2✔
1045
                e
1046
            ))
1047
        })
2✔
1048
}
301✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc