• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / pg_graphql / 16431879287

22 Jul 2025 01:05AM UTC coverage: 92.335% (-1.7%) from 94.056%
16431879287

Pull #590

github

web-flow
Merge e434169cd into a899acda9
Pull Request #590: Add support for single record queries by primary key

219 of 317 new or added lines in 5 files covered. (69.09%)

43 existing lines in 4 files now uncovered.

7649 of 8284 relevant lines covered (92.33%)

1137.48 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.8
/src/sql_types.rs
1
use bimap::BiBTreeMap;
2
use cached::proc_macro::cached;
3
use cached::SizedCache;
4
use lazy_static::lazy_static;
5
use pgrx::*;
6
use serde::{Deserialize, Serialize};
7
use std::cmp::Ordering;
8
use std::collections::hash_map::DefaultHasher;
9
use std::collections::{HashMap, HashSet};
10
use std::hash::{Hash, Hasher};
11
use std::sync::Arc;
12
use std::*;
13

14
/// Column-level privileges as loaded from the SQL context
/// (NOTE(review): presumably evaluated for `Config::role` — confirm in load_sql_context.sql).
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ColumnPermissions {
    /// Whether values may be supplied for this column on insert.
    pub is_insertable: bool,
    /// Whether this column may be read; checked by `Context::fkey_is_selectable`.
    pub is_selectable: bool,
    /// Whether this column may be changed on update.
    pub is_updatable: bool,
    // NOTE(review): unexplained notes left by the original author for possible future
    // permissions: "admin interface", "alterable?"
}
22

23
/// Per-column options, presumably parsed from a `@graphql({...})` comment directive
/// like the ones documented on `TableDirectives` — confirm in load_sql_context.sql.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ColumnDirectives {
    /// Optional name override for the column.
    pub name: Option<String>,
    /// Optional description override for the column.
    pub description: Option<String>,
}
28

29
/// A table column as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Column {
    pub name: String,
    /// Oid of the column's type.
    pub type_oid: u32,
    /// Resolved type, if linked; not visible being populated here — confirm in the loader.
    pub type_: Option<Arc<Type>>,
    /// Type name as loaded from SQL; matched by `SupportedPrimaryKeyType::from_type_name`.
    pub type_name: String,
    /// Presumably the declared length limit for character types — confirm.
    pub max_characters: Option<i32>,
    pub schema_oid: u32,
    pub is_not_null: bool,
    pub is_serial: bool,
    pub is_generated: bool,
    pub has_default: bool,
    /// Presumably `pg_attribute.attnum` — confirm.
    pub attribute_num: i32,
    pub permissions: ColumnPermissions,
    pub comment: Option<String>,
    pub directives: ColumnDirectives,
}
46

47
/// Per-function options from a comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FunctionDirectives {
    /// Optional name override for the function.
    pub name: Option<String>,
    // @graphql({"description": "the address of ..." })
    pub description: Option<String>,
}
53

54
/// Function-level privileges as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FunctionPermissions {
    /// Functions without this permission are rejected by `Function::is_supported`.
    pub is_executable: bool,
}
58

59
/// Function volatility, deserialized from the one-letter codes `"v"`, `"s"` and `"i"`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum FunctionVolatility {
    /// Deserialized from `"v"`.
    #[serde(rename(deserialize = "v"))]
    Volatile,
    /// Deserialized from `"s"`.
    #[serde(rename(deserialize = "s"))]
    Stable,
    /// Deserialized from `"i"`.
    #[serde(rename(deserialize = "i"))]
    Immutable,
}
68

69
/// A SQL function as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Function {
    pub oid: u32,
    pub name: String,
    pub schema_oid: u32,
    /// Checked against a fixed list of system schemas by `is_in_a_system_schema`.
    pub schema_name: String,
    /// Type oid of each argument; drives `ArgsIterator`.
    pub arg_types: Vec<u32>,
    /// Argument names, read by index in parallel with `arg_types`; an empty string
    /// marks an unnamed argument.
    pub arg_names: Option<Vec<String>>,
    /// Comma-separated default expressions for the trailing `num_default_args` arguments.
    pub arg_defaults: Option<String>,
    pub num_args: u32,
    /// Number of trailing arguments that have a default value.
    pub num_default_args: u32,
    /// Type name of each argument, read by index in parallel with `arg_types`.
    pub arg_type_names: Vec<String>,
    pub volatility: FunctionVolatility,
    /// Oid of the return type.
    pub type_oid: u32,
    pub type_name: String,
    /// Presumably whether the function returns `SETOF` — confirm in load_sql_context.sql.
    pub is_set_of: bool,
    pub comment: Option<String>,
    pub directives: FunctionDirectives,
    pub permissions: FunctionPermissions,
}
89

90
impl Function {
91
    pub fn args(&self) -> impl Iterator<Item = (u32, &str, Option<&str>, Option<DefaultValue>)> {
5,344✔
92
        ArgsIterator::new(
5,344✔
93
            &self.arg_types,
5,344✔
94
            &self.arg_type_names,
5,344✔
95
            &self.arg_names,
5,344✔
96
            &self.arg_defaults,
5,344✔
97
            self.num_default_args,
5,344✔
98
        )
99
    }
5,344✔
100

101
    pub fn function_names_to_count(all_functions: &[Arc<Function>]) -> HashMap<&String, u32> {
543✔
102
        let mut function_name_to_count = HashMap::new();
543✔
103
        for function_name in all_functions.iter().map(|f| &f.name) {
2,219✔
104
            let entry = function_name_to_count.entry(function_name).or_insert(0u32);
2,219✔
105
            *entry += 1;
2,219✔
106
        }
2,219✔
107
        function_name_to_count
543✔
108
    }
543✔
109

110
    pub fn is_supported(
2,219✔
111
        &self,
2,219✔
112
        context: &Context,
2,219✔
113
        function_name_to_count: &HashMap<&String, u32>,
2,219✔
114
    ) -> bool {
2,219✔
115
        let types = &context.types;
2,219✔
116
        self.return_type_is_supported(types)
2,219✔
117
            && self.arg_types_are_supported(types)
2,183✔
118
            && !self.is_function_overloaded(function_name_to_count)
2,086✔
119
            && !self.has_a_nameless_arg()
1,906✔
120
            && self.permissions.is_executable
1,694✔
121
            && !self.is_in_a_system_schema()
1,691✔
122
    }
2,219✔
123

124
    fn arg_types_are_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
2,183✔
125
        self.args().all(|(arg_type, _, _, _)| {
3,393✔
126
            if let Some(arg_type) = types.get(&arg_type) {
3,393✔
127
                let array_element_type_is_supported = self.array_element_type_is_supported(
3,393✔
128
                    &arg_type.category,
3,393✔
129
                    arg_type.array_element_type_oid,
3,393✔
130
                    types,
3,393✔
131
                );
132
                arg_type.category == TypeCategory::Other
3,393✔
133
                    || (arg_type.category == TypeCategory::Array && array_element_type_is_supported)
776✔
134
            } else {
135
                false
×
136
            }
137
        })
3,393✔
138
    }
2,183✔
139

140
    fn return_type_is_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
2,219✔
141
        if let Some(return_type) = types.get(&self.type_oid) {
2,219✔
142
            let array_element_type_is_supported = self.array_element_type_is_supported(
2,219✔
143
                &return_type.category,
2,219✔
144
                return_type.array_element_type_oid,
2,219✔
145
                types,
2,219✔
146
            );
147
            return_type.category != TypeCategory::Pseudo
2,219✔
148
                && return_type.category != TypeCategory::Enum
2,219✔
149
                && return_type.name != "record"
2,214✔
150
                && return_type.name != "trigger"
2,199✔
151
                && return_type.name != "event_trigger"
2,195✔
152
                && array_element_type_is_supported
2,188✔
153
        } else {
154
            false
×
155
        }
156
    }
2,219✔
157

158
    fn array_element_type_is_supported(
5,612✔
159
        &self,
5,612✔
160
        type_category: &TypeCategory,
5,612✔
161
        array_element_type_oid: Option<u32>,
5,612✔
162
        types: &HashMap<u32, Arc<Type>>,
5,612✔
163
    ) -> bool {
5,612✔
164
        if *type_category == TypeCategory::Array {
5,612✔
165
            if let Some(array_element_type_oid) = array_element_type_oid {
999✔
166
                if let Some(array_element_type) = types.get(&array_element_type_oid) {
999✔
167
                    array_element_type.category == TypeCategory::Other
999✔
168
                } else {
169
                    false
×
170
                }
171
            } else {
172
                false
×
173
            }
174
        } else {
175
            true
4,613✔
176
        }
177
    }
5,612✔
178

179
    fn is_function_overloaded(&self, function_name_to_count: &HashMap<&String, u32>) -> bool {
2,086✔
180
        if let Some(&count) = function_name_to_count.get(&self.name) {
2,086✔
181
            count > 1
2,086✔
182
        } else {
183
            false
×
184
        }
185
    }
2,086✔
186

187
    fn has_a_nameless_arg(&self) -> bool {
1,906✔
188
        self.args().any(|(_, _, arg_name, _)| arg_name.is_none())
2,798✔
189
    }
1,906✔
190

191
    fn is_in_a_system_schema(&self) -> bool {
1,691✔
192
        // These are default schemas in supabase configuration
193
        let system_schemas = &["graphql", "graphql_public", "auth", "extensions"];
1,691✔
194
        system_schemas.contains(&self.schema_name.as_str())
1,691✔
195
    }
1,691✔
196
}
197

198
/// Iterator over a function's arguments, created by `Function::args`.
///
/// Yields `(type oid, type name, argument name, default value)` tuples.
struct ArgsIterator<'a> {
    /// Position of the next argument to yield.
    index: usize,
    arg_types: &'a [u32],
    arg_type_names: &'a Vec<String>,
    arg_names: &'a Option<Vec<String>>,
    /// One precomputed entry per argument; `None` means no usable default.
    arg_defaults: Vec<Option<DefaultValue>>,
}
205

206
/// Default value of a function argument, as produced by `ArgsIterator`.
#[derive(Clone)]
pub enum DefaultValue {
    /// A literal rendered as text, e.g. `1`, `true` or `"abc"`.
    NonNull(String),
    /// The argument has a default that is NULL or could not be parsed into a literal.
    Null,
}
211

212
impl<'a> ArgsIterator<'a> {
213
    fn new(
5,344✔
214
        arg_types: &'a [u32],
5,344✔
215
        arg_type_names: &'a Vec<String>,
5,344✔
216
        arg_names: &'a Option<Vec<String>>,
5,344✔
217
        arg_defaults: &'a Option<String>,
5,344✔
218
        num_default_args: u32,
5,344✔
219
    ) -> ArgsIterator<'a> {
5,344✔
220
        ArgsIterator {
5,344✔
221
            index: 0,
5,344✔
222
            arg_types,
5,344✔
223
            arg_type_names,
5,344✔
224
            arg_names,
5,344✔
225
            arg_defaults: Self::defaults(
5,344✔
226
                arg_types,
5,344✔
227
                arg_defaults,
5,344✔
228
                num_default_args as usize,
5,344✔
229
                arg_types.len(),
5,344✔
230
            ),
5,344✔
231
        }
5,344✔
232
    }
5,344✔
233

234
    fn defaults(
5,344✔
235
        arg_types: &'a [u32],
5,344✔
236
        arg_defaults: &'a Option<String>,
5,344✔
237
        num_default_args: usize,
5,344✔
238
        num_total_args: usize,
5,344✔
239
    ) -> Vec<Option<DefaultValue>> {
5,344✔
240
        let mut defaults = vec![None; num_total_args];
5,344✔
241
        let Some(arg_defaults) = arg_defaults else {
5,344✔
242
            return defaults;
5,219✔
243
        };
244

245
        if num_default_args == 0 {
125✔
246
            return defaults;
×
247
        }
125✔
248

249
        let default_strs: Vec<&str> = arg_defaults.split(',').collect();
125✔
250

251
        if default_strs.len() != num_default_args {
125✔
252
            return defaults;
13✔
253
        }
112✔
254

255
        debug_assert!(num_default_args <= num_total_args);
112✔
256
        let start_idx = num_total_args - num_default_args;
112✔
257
        for i in start_idx..num_total_args {
471✔
258
            defaults[i] = Self::sql_to_graphql_default(default_strs[i - start_idx], arg_types[i])
471✔
259
        }
260

261
        defaults
112✔
262
    }
5,344✔
263

264
    fn sql_to_graphql_default(default_str: &str, type_oid: u32) -> Option<DefaultValue> {
471✔
265
        let trimmed = default_str.trim();
471✔
266

267
        if trimmed.starts_with("NULL::") {
471✔
268
            return Some(DefaultValue::Null);
318✔
269
        }
153✔
270

271
        let res = match type_oid {
153✔
272
            21 | 23 => trimmed
72✔
273
                .parse::<i32>()
72✔
274
                .ok()
72✔
275
                .map(|i| DefaultValue::NonNull(i.to_string())),
72✔
276
            16 => trimmed
11✔
277
                .parse::<bool>()
11✔
278
                .ok()
11✔
279
                .map(|i| DefaultValue::NonNull(i.to_string())),
11✔
280
            700 | 701 => trimmed
22✔
281
                .parse::<f64>()
22✔
282
                .ok()
22✔
283
                .map(|i| DefaultValue::NonNull(i.to_string())),
22✔
284
            25 => trimmed.strip_suffix("::text").map(|i| {
26✔
285
                DefaultValue::NonNull(format!("\"{}\"", i.trim_matches(',').trim_matches('\'')))
16✔
286
            }),
16✔
287
            _ => None,
22✔
288
        };
289

290
        // return the non-parsed value as default if for whatever reason the default value can't
291
        // be parsed into a value of the required type. This fixes problems where the default
292
        // is a complex expression like a function call etc. See test/sql/issue_533.sql for
293
        // a test case for this scenario.
294
        if res.is_some() {
153✔
295
            res
101✔
296
        } else {
297
            Some(DefaultValue::Null)
52✔
298
        }
299
    }
471✔
300
}
301

302
// Shared `"text"` string. `ArgsIterator::next` reports it as the type name of any
// argument whose type name is `character`.
lazy_static! {
    static ref TEXT_TYPE: String = "text".to_string();
}
305

306
impl<'a> Iterator for ArgsIterator<'a> {
307
    type Item = (u32, &'a str, Option<&'a str>, Option<DefaultValue>);
308

309
    fn next(&mut self) -> Option<Self::Item> {
13,205✔
310
        if self.index < self.arg_types.len() {
13,205✔
311
            debug_assert!(self.arg_types.len() == self.arg_type_names.len());
8,170✔
312
            let arg_name = if let Some(arg_names) = self.arg_names {
8,170✔
313
                debug_assert!(arg_names.len() >= self.arg_types.len());
7,220✔
314
                let arg_name = arg_names[self.index].as_str();
7,220✔
315
                if !arg_name.is_empty() {
7,220✔
316
                    Some(arg_name)
7,220✔
317
                } else {
318
                    None
×
319
                }
320
            } else {
321
                None
950✔
322
            };
323
            let arg_type = self.arg_types[self.index];
8,170✔
324
            let mut arg_type_name = &self.arg_type_names[self.index];
8,170✔
325
            if arg_type_name == "character" {
8,170✔
326
                arg_type_name = &TEXT_TYPE;
86✔
327
            }
8,084✔
328
            let arg_default = self.arg_defaults[self.index].clone();
8,170✔
329
            self.index += 1;
8,170✔
330
            Some((arg_type, arg_type_name, arg_name, arg_default))
8,170✔
331
        } else {
332
            None
5,035✔
333
        }
334
    }
13,205✔
335
}
336

337
/// Table-level privileges as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TablePermissions {
    pub is_insertable: bool,
    pub is_selectable: bool,
    pub is_updatable: bool,
    pub is_deletable: bool,
}
344

345
/// Type-level privileges as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TypePermissions {
    pub is_usable: bool,
}
349

350
/// Coarse classification of a type.
///
/// `Function::is_supported` only accepts `Other` arguments, or `Array`s of `Other`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeCategory {
    Enum,
    Composite,
    Table,
    Array,
    Pseudo,
    Other,
}
359

360
/// A type as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Type {
    pub oid: u32,
    pub schema_oid: u32,
    pub name: String,
    pub category: TypeCategory,
    /// For array types, the oid of the element type.
    pub array_element_type_oid: Option<u32>,
    /// Presumably the oid of the table whose row type this is — confirm.
    pub table_oid: Option<u32>,
    pub comment: Option<String>,
    pub permissions: TypePermissions,
    /// Not expected from SQL; populated in a separate pass (see `TypeDetails`).
    pub details: Option<TypeDetails>,
}
372

373
// `TypeDetails` derives `Deserialize` but is not expected to come
// from the SQL context. Instead, it is populated in a separate pass.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeDetails {
    Enum(Arc<Enum>),
    Composite(Arc<Composite>),
    Table(Arc<Table>),
    /// Presumably the element type of an array type — confirm in the populating pass.
    Element(Arc<Type>),
}
382

383
/// One label of an enum type.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct EnumValue {
    pub oid: u32,
    pub name: String,
    /// Presumably the label's position within the enum — confirm in the loader.
    pub sort_order: i32,
}
389

390
/// An enum type as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Enum {
    pub oid: u32,
    pub schema_oid: u32,
    pub name: String,
    pub values: Vec<EnumValue>,
    pub array_element_type_oid: Option<u32>,
    pub comment: Option<String>,
    pub permissions: TypePermissions,
    pub directives: EnumDirectives,
}
401

402
/// Per-enum options from a comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct EnumDirectives {
    /// Optional name override for the enum.
    pub name: Option<String>,
    /// Bidirectional mapping between value names.
    /// NOTE(review): which side is the database label is not visible here — confirm.
    pub mappings: Option<BiBTreeMap<String, String>>,
}
407

408
/// A composite type. `Context::is_composite` looks these up by `oid`.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Composite {
    pub oid: u32,
    pub schema_oid: u32,
}
413

414
/// An index on a table.
///
/// `Table::primary_key` also synthesizes one from a `primary_key_columns` directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Index {
    pub table_oid: u32,
    /// Names of the indexed columns.
    pub column_names: Vec<String>,
    pub is_unique: bool,
    pub is_primary_key: bool,
}
421

422
/// One side (local or referenced) of a foreign key.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKeyTableInfo {
    pub oid: u32,
    // The table's actual name
    pub name: String,
    pub schema: String,
    /// Row-level security flag of THIS side's table.
    pub is_rls_enabled: bool,
    /// Columns on this side participating in the key.
    pub column_names: Vec<String>,
}
431

432
/// Optional names for the two sides of a foreign key relationship.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKeyDirectives {
    pub local_name: Option<String>,
    pub foreign_name: Option<String>,
}
437

438
/// A foreign key, either loaded from SQL or declared through a table comment directive
/// (see `Context::foreign_keys`).
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct ForeignKey {
    pub directives: ForeignKeyDirectives,
    pub local_table_meta: ForeignKeyTableInfo,
    pub referenced_table_meta: ForeignKeyTableInfo,
}
444

445
/// `@graphql({"totalCount": { "enabled": true } })`
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectiveTotalCount {
    pub enabled: bool,
}
449

450
/// `@graphql({"aggregate": { "enabled": true } })`
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectiveAggregate {
    pub enabled: bool,
}
454

455
/// A foreign key declared in a table comment directive (see `TableDirectives`).
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectiveForeignKey {
    // Equivalent to ForeignKeyDirectives.local_name
    pub local_name: Option<String>,
    /// Columns on the table carrying the directive.
    pub local_columns: Vec<String>,

    // Equivalent to ForeignKeyDirectives.foreign_name
    pub foreign_name: Option<String>,
    pub foreign_schema: String,
    pub foreign_table: String,
    /// Columns on the referenced table.
    pub foreign_columns: Vec<String>,
}
467

468
/// Per-table options from a `@graphql({...})` comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableDirectives {
    // @graphql({"name": "Foo" })
    pub name: Option<String>,

    // @graphql({"description": "the address of ..." })
    pub description: Option<String>,

    // @graphql({"totalCount": { "enabled": true } })
    pub total_count: Option<TableDirectiveTotalCount>,

    // @graphql({"aggregate": { "enabled": true } })
    pub aggregate: Option<TableDirectiveAggregate>,

    // @graphql({"primary_key_columns": ["id"]})
    // Used by Table::primary_key when the table has no real primary key.
    pub primary_key_columns: Option<Vec<String>>,

    // @graphql({"max_rows": 20})
    // Overrides the schema-level max_rows (see Table::max_rows).
    pub max_rows: Option<u64>,

    /*
    @graphql(
      {
        "foreign_keys": [
          {
            <REQUIRED>
            "local_columns": ["account_id"],
            "foreign_schema": "public",
            "foreign_table": "account",
            "foreign_columns": ["id"],

            <OPTIONAL>
            "local_name": "foo",
            "foreign_name": "bar",
          },
        ]
      }
    )
    */
    pub foreign_keys: Option<Vec<TableDirectiveForeignKey>>,
}
509

510
/// A table-like relation as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Table {
    pub oid: u32,
    pub name: String,
    pub schema_oid: u32,
    pub schema: String,
    pub columns: Vec<Arc<Column>>,
    pub comment: Option<String>,
    pub is_rls_enabled: bool,
    pub relkind: String, // r = table, v = view, m = mat view, f = foreign table
    /// Oid of the table's row type.
    pub reltype: u32,
    pub permissions: TablePermissions,
    pub indexes: Vec<Index>,
    /// Defaults to empty when the field is absent from the loaded JSON.
    #[serde(default)]
    pub functions: Vec<Arc<Function>>,
    pub directives: TableDirectives,
}
527

528
impl Table {
529
    pub fn primary_key(&self) -> Option<Index> {
4,304✔
530
        let real_pkey = self.indexes.iter().find(|x| x.is_primary_key);
4,304✔
531

532
        if real_pkey.is_some() {
4,304✔
533
            return real_pkey.cloned();
4,119✔
534
        }
185✔
535

536
        // Check for a primary key definition in comment directives
537
        if let Some(column_names) = &self.directives.primary_key_columns {
185✔
538
            // validate that columns exist on the table
539
            let mut valid_column_names: Vec<&String> = vec![];
164✔
540
            for column_name in column_names {
328✔
541
                for column in &self.columns {
563✔
542
                    if column_name == &column.name {
399✔
543
                        valid_column_names.push(&column.name);
164✔
544
                    }
235✔
545
                }
546
            }
547
            if valid_column_names.len() != column_names.len() {
164✔
548
                // At least one of the column names didn't exist on the table
549
                // so the primary key directive is not valid
550
                // Ideally we'd throw an error here instead
551
                None
×
552
            } else {
553
                Some(Index {
164✔
554
                    table_oid: self.oid,
164✔
555
                    column_names: column_names.clone(),
164✔
556
                    is_unique: true,
164✔
557
                    is_primary_key: true,
164✔
558
                })
164✔
559
            }
560
        } else {
561
            None
21✔
562
        }
563
    }
4,304✔
564

565
    pub fn primary_key_columns(&self) -> Vec<&Arc<Column>> {
1,421✔
566
        self.primary_key()
1,421✔
567
            .map(|x| x.column_names)
1,421✔
568
            .unwrap_or_default()
1,421✔
569
            .iter()
1,421✔
570
            .map(|col_name| {
1,435✔
571
                self.columns
1,435✔
572
                    .iter()
1,435✔
573
                    .find(|col| &col.name == col_name)
1,449✔
574
                    .expect("Failed to unwrap pkey by column names")
1,435✔
575
            })
1,435✔
576
            .collect::<Vec<&Arc<Column>>>()
1,421✔
577
    }
1,421✔
578

579
    pub fn has_supported_pk_types_for_by_pk(&self) -> bool {
232✔
580
        let pk_columns = self.primary_key_columns();
232✔
581
        if pk_columns.is_empty() {
232✔
NEW
582
            return false;
×
583
        }
232✔
584

585
        // Check that all primary key columns have supported types
586
        pk_columns.iter().all(|col| {
234✔
587
            SupportedPrimaryKeyType::from_type_name(&col.type_name).is_some()
234✔
588
        })
234✔
589
    }
232✔
590

591
    pub fn is_any_column_selectable(&self) -> bool {
1,861✔
592
        self.columns.iter().any(|x| x.permissions.is_selectable)
1,881✔
593
    }
1,861✔
594

595
    pub fn is_any_column_insertable(&self) -> bool {
406✔
596
        self.columns.iter().any(|x| x.permissions.is_insertable)
415✔
597
    }
406✔
598

599
    pub fn is_any_column_updatable(&self) -> bool {
406✔
600
        self.columns.iter().any(|x| x.permissions.is_updatable)
417✔
601
    }
406✔
602

603
    /// Get the effective max_rows value for this table.
604
    /// If table-specific max_rows is set, use that.
605
    /// Otherwise, fall back to schema-level max_rows.
606
    /// If neither is set, use the global default(set in load_sql_context.sql)
607
    pub fn max_rows(&self, schema: &Schema) -> u64 {
292✔
608
        self.directives.max_rows.unwrap_or(schema.directives.max_rows)
292✔
609
    }
292✔
610
}
611

612
/// Column types that allow single-record lookups by primary key (`by_pk`).
///
/// Used by `Table::has_supported_pk_types_for_by_pk`.
#[derive(Debug, PartialEq)]
pub enum SupportedPrimaryKeyType {
    // Integer types
    Int,      // int, int4, integer
    BigInt,   // bigint, int8
    SmallInt, // smallint, int2
    // String types
    Text,     // text
    VarChar,  // varchar
    Char,     // char, bpchar
    CiText,   // citext
    // UUID
    UUID,     // uuid
}
626

627
impl SupportedPrimaryKeyType {
628
    fn from_type_name(type_name: &str) -> Option<Self> {
234✔
629
        match type_name {
234✔
630
            // Integer types
631
            "int" | "int4" | "integer" => Some(Self::Int),
234✔
632
            "bigint" | "int8" => Some(Self::BigInt),
15✔
633
            "smallint" | "int2" => Some(Self::SmallInt),
13✔
634
            // String types
635
            "text" => Some(Self::Text),
13✔
636
            "varchar" => Some(Self::VarChar),
12✔
637
            "char" | "bpchar" => Some(Self::Char),
12✔
638
            "citext" => Some(Self::CiText),
12✔
639
            // UUID
640
            "uuid" => Some(Self::UUID),
12✔
641
            // Any other type is not supported
NEW
642
            _ => None,
×
643
        }
644
    }
234✔
645
}
646

647
/// Per-schema options from a comment directive.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct SchemaDirectives {
    // @graphql({"inflect_names": true})
    pub inflect_names: bool,
    // @graphql({"max_rows": 20})
    // Fallback for tables without their own max_rows (see Table::max_rows).
    pub max_rows: u64,
}
654

655
/// A schema as loaded from the SQL context.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Schema {
    pub oid: u32,
    pub name: String,
    pub comment: Option<String>,
    pub directives: SchemaDirectives,
}
662

663
/// Session configuration loaded by `load_sql_config`.
///
/// Its hash (via `calculate_hash`) is the cache key for `load_sql_context`, so any
/// change to these fields produces a freshly loaded `Context`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Config {
    pub search_path: Vec<String>,
    pub role: String,
    pub schema_version: i32,
}
669

670
/// Database metadata loaded by `load_sql_context`.
///
/// `Hash` is implemented manually and only covers `config`.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct Context {
    pub config: Config,
    /// Schemas keyed by oid.
    pub schemas: HashMap<u32, Schema>,
    /// Tables keyed by oid.
    pub tables: HashMap<u32, Arc<Table>>,
    /// Foreign keys loaded from SQL only; `Context::foreign_keys()` adds the
    /// directive-declared ones and filters by selectability.
    foreign_keys: Vec<Arc<ForeignKey>>,
    /// Types keyed by oid.
    pub types: HashMap<u32, Arc<Type>>,
    /// Enums keyed by oid.
    pub enums: HashMap<u32, Arc<Enum>>,
    pub composites: Vec<Arc<Composite>>,
    pub functions: Vec<Arc<Function>>,
}
681

682
impl Hash for Context {
683
    fn hash<H: Hasher>(&self, state: &mut H) {
2,305✔
684
        // Only the config is needed to has ha Context
685
        self.config.hash(state);
2,305✔
686
    }
2,305✔
687
}
688

689
impl Context {
690
    /// Collect all foreign keys referencing (inbound or outbound) a table
691
    pub fn foreign_keys(&self) -> Vec<Arc<ForeignKey>> {
372✔
692
        let mut fkeys: Vec<Arc<ForeignKey>> = self.foreign_keys.clone();
372✔
693

694
        // Add foreign keys defined in comment directives
695
        for table in self.tables.values() {
624✔
696
            let directive_fkeys: Vec<TableDirectiveForeignKey> =
624✔
697
                match &table.directives.foreign_keys {
624✔
698
                    Some(keys) => keys.clone(),
28✔
699
                    None => vec![],
596✔
700
                };
701

702
            for directive_fkey in directive_fkeys.iter() {
624✔
703
                let referenced_t = match self.get_table_by_name(
28✔
704
                    &directive_fkey.foreign_schema,
28✔
705
                    &directive_fkey.foreign_table,
28✔
706
                ) {
28✔
707
                    Some(t) => t,
28✔
708
                    None => {
709
                        // No table found with requested name. Skip.
710
                        continue;
×
711
                    }
712
                };
713

714
                let referenced_t_column_names: HashSet<&String> =
28✔
715
                    referenced_t.columns.iter().map(|x| &x.name).collect();
80✔
716

717
                // Verify all foreign column references are valid
718
                if !directive_fkey
28✔
719
                    .foreign_columns
28✔
720
                    .iter()
28✔
721
                    .all(|col| referenced_t_column_names.contains(col))
28✔
722
                {
723
                    // Skip if invalid references exist
724
                    continue;
×
725
                }
28✔
726

727
                let fk = ForeignKey {
28✔
728
                    local_table_meta: ForeignKeyTableInfo {
28✔
729
                        oid: table.oid,
28✔
730
                        name: table.name.clone(),
28✔
731
                        schema: table.schema.clone(),
28✔
732
                        is_rls_enabled: table.is_rls_enabled,
28✔
733
                        column_names: directive_fkey.local_columns.clone(),
28✔
734
                    },
28✔
735
                    referenced_table_meta: ForeignKeyTableInfo {
28✔
736
                        oid: referenced_t.oid,
28✔
737
                        name: referenced_t.name.clone(),
28✔
738
                        schema: referenced_t.schema.clone(),
28✔
739
                        is_rls_enabled: table.is_rls_enabled,
28✔
740
                        column_names: directive_fkey.foreign_columns.clone(),
28✔
741
                    },
28✔
742
                    directives: ForeignKeyDirectives {
28✔
743
                        local_name: directive_fkey.local_name.clone(),
28✔
744
                        foreign_name: directive_fkey.foreign_name.clone(),
28✔
745
                    },
28✔
746
                };
28✔
747

748
                fkeys.push(Arc::new(fk));
28✔
749
            }
750
        }
751

752
        fkeys
372✔
753
            .into_iter()
372✔
754
            .filter(|fk| self.fkey_is_selectable(fk))
372✔
755
            .collect()
372✔
756
    }
372✔
757

758
    /// Check if a type is a composite type
759
    pub fn is_composite(&self, type_oid: u32) -> bool {
2,800✔
760
        self.composites.iter().any(|x| x.oid == type_oid)
2,800✔
761
    }
2,800✔
762

763
    pub fn get_table_by_name(
28✔
764
        &self,
28✔
765
        schema_name: &String,
28✔
766
        table_name: &String,
28✔
767
    ) -> Option<&Arc<Table>> {
28✔
768
        self.tables
28✔
769
            .values()
28✔
770
            .find(|x| &x.schema == schema_name && &x.name == table_name)
66✔
771
    }
28✔
772

773
    pub fn get_table_by_oid(&self, oid: u32) -> Option<&Arc<Table>> {
800✔
774
        self.tables.get(&oid)
800✔
775
    }
800✔
776

777
    /// Check if the local side of a foreign key is comprised of unique columns
778
    pub fn fkey_is_locally_unique(&self, fkey: &ForeignKey) -> bool {
128✔
779
        let table: &Arc<Table> = match self.get_table_by_oid(fkey.local_table_meta.oid) {
128✔
780
            Some(table) => table,
128✔
781
            None => {
782
                return false;
×
783
            }
784
        };
785

786
        let fkey_columns: HashSet<&String> = fkey.local_table_meta.column_names.iter().collect();
128✔
787

788
        for index in table.indexes.iter().filter(|x| x.is_unique) {
134✔
789
            let index_column_names: HashSet<&String> = index.column_names.iter().collect();
134✔
790

791
            if index_column_names
134✔
792
                .iter()
134✔
793
                .all(|col_name| fkey_columns.contains(col_name))
136✔
794
            {
795
                return true;
10✔
796
            }
124✔
797
        }
798
        false
118✔
799
    }
128✔
800

801
    /// Are both sides of the foreign key composed of selectable columns
802
    pub fn fkey_is_selectable(&self, fkey: &ForeignKey) -> bool {
218✔
803
        let local_table: &Arc<Table> = match self.get_table_by_oid(fkey.local_table_meta.oid) {
218✔
804
            Some(table) => table,
218✔
805
            None => {
806
                return false;
×
807
            }
808
        };
809

810
        let referenced_table: &Arc<Table> =
218✔
811
            match self.get_table_by_oid(fkey.referenced_table_meta.oid) {
218✔
812
                Some(table) => table,
218✔
813
                None => {
814
                    return false;
×
815
                }
816
            };
817

818
        let fkey_local_columns = &fkey.local_table_meta.column_names;
218✔
819
        let fkey_referenced_columns = &fkey.referenced_table_meta.column_names;
218✔
820

821
        let local_columns_selectable: HashSet<&String> = local_table
218✔
822
            .columns
218✔
823
            .iter()
218✔
824
            .filter(|x| x.permissions.is_selectable)
714✔
825
            .map(|col| &col.name)
708✔
826
            .collect();
218✔
827

828
        let referenced_columns_selectable: HashSet<&String> = referenced_table
218✔
829
            .columns
218✔
830
            .iter()
218✔
831
            .filter(|x| x.permissions.is_selectable)
1,096✔
832
            .map(|col| &col.name)
1,090✔
833
            .collect();
218✔
834

835
        fkey_local_columns
218✔
836
            .iter()
218✔
837
            .all(|col| local_columns_selectable.contains(col))
218✔
838
            && fkey_referenced_columns
218✔
839
                .iter()
218✔
840
                .all(|col| referenced_columns_selectable.contains(col))
218✔
841
    }
218✔
842
}
843

844
/// This method is similar to `Spi::get_one` with the only difference
845
/// being that it calls `client.select` instead of `client.update`.
846
/// The `client.update` method generates a new transaction id so
847
/// calling `Spi::get_one` is not possible when postgres is in
848
/// recovery mode.
849
pub(crate) fn get_one_readonly<A: FromDatum + IntoDatum>(
953✔
850
    query: &str,
953✔
851
) -> std::result::Result<Option<A>, pgrx::spi::Error> {
953✔
852
    Spi::connect(|client| client.select(query, Some(1), None)?.first().get_one())
953✔
853
}
953✔
854

855
pub fn load_sql_config() -> Config {
656✔
856
    let query = include_str!("../sql/load_sql_config.sql");
656✔
857
    let sql_result: serde_json::Value = get_one_readonly::<JsonB>(query)
656✔
858
        .expect("failed to read sql config")
656✔
859
        .expect("sql config is missing")
656✔
860
        .0;
656✔
861
    let config: Config =
656✔
862
        serde_json::from_value(sql_result).expect("failed to convert sql config into json");
656✔
863
    config
656✔
864
}
656✔
865

866
/// Computes a `u64` fingerprint of `t` with a freshly created [`DefaultHasher`].
///
/// All `DefaultHasher` instances created via `new` start from the same state, so equal
/// values hash equally within one build. The exact values are not guaranteed to be stable
/// across Rust releases, so do not persist them.
pub fn calculate_hash<T: Hash>(t: &T) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(t, &mut state);
    Hasher::finish(&state)
}
871

872
#[cached(
873
    type = "SizedCache<u64, Result<Arc<Context>, String>>",
874
    create = "{ SizedCache::with_size(250) }",
875
    convert = r#"{ calculate_hash(_config) }"#
876
)]
877
pub fn load_sql_context(_config: &Config) -> Result<Arc<Context>, String> {
297✔
878
    // cache value for next query
879
    let query = include_str!("../sql/load_sql_context.sql");
297✔
880
    let sql_result: serde_json::Value = get_one_readonly::<JsonB>(query)
297✔
881
        .expect("failed to read sql context")
297✔
882
        .expect("sql context is missing")
297✔
883
        .0;
297✔
884
    let context: Result<Context, serde_json::Error> = serde_json::from_value(sql_result);
297✔
885

886
    /// This pass cross-reference types with its details
887
    fn type_details(mut context: Context) -> Context {
295✔
888
        let mut array_types = HashMap::new();
295✔
889
        // We process types to cross-reference their details
890
        for (oid, type_) in context.types.iter_mut() {
181,379✔
891
            if let Some(mtype) = Arc::get_mut(type_) {
181,379✔
892
                // It should be possible to get a mutable reference to type at this point
893
                // as there are no other references to this Arc at this point.
894

895
                // Depending on type's category, locate the appropriate piece of information about it
896
                mtype.details = match mtype.category {
181,379✔
897
                    TypeCategory::Enum => context.enums.get(oid).cloned().map(TypeDetails::Enum),
27✔
898
                    TypeCategory::Composite => context
1✔
899
                        .composites
1✔
900
                        .iter()
1✔
901
                        .find(|c| c.oid == *oid)
1✔
902
                        .cloned()
1✔
903
                        .map(TypeDetails::Composite),
1✔
904
                    TypeCategory::Table => context.tables.get(oid).cloned().map(TypeDetails::Table),
61,897✔
905
                    TypeCategory::Array => {
906
                        // We can't cross-reference with `context.types` here as it is already mutably borrowed,
907
                        // so we instead memorize the reference to process later
908
                        if let Some(element_oid) = mtype.array_element_type_oid {
86,707✔
909
                            array_types.insert(*oid, element_oid);
86,707✔
910
                        }
86,707✔
911
                        None
86,707✔
912
                    }
913
                    _ => None,
32,747✔
914
                };
915
            }
×
916
        }
917

918
        // Ensure the types are ordered so that we don't run into a situation where we can't
919
        // update the type anymore as it has been referenced but the type details weren't completed yet
920
        let referenced_types = array_types.values().copied().collect::<Vec<_>>();
295✔
921
        let mut ordered_types = array_types
295✔
922
            .iter()
295✔
923
            .map(|(k, v)| (*k, *v))
86,707✔
924
            .collect::<Vec<_>>();
295✔
925
        // We sort them by their presence in referencing. If the type has been referenced,
926
        // it should be at the top.
927
        ordered_types.sort_by(|(k1, _), (k2, _)| {
197,191✔
928
            if referenced_types.contains(k1) && referenced_types.contains(k2) {
197,191✔
929
                Ordering::Equal
420✔
930
            } else if referenced_types.contains(k1) {
196,771✔
931
                Ordering::Less
42,893✔
932
            } else {
933
                Ordering::Greater
153,878✔
934
            }
935
        });
197,191✔
936

937
        // Now we're ready to process array types
938
        for (array_oid, element_oid) in ordered_types {
87,002✔
939
            // We remove the element type from the map to ensure there is no mutability conflict for when
940
            // we get a mutable reference to the array type. We will put it back after we're done with it,
941
            // a few lines below.
942
            if let Some(element_t) = context.types.remove(&element_oid) {
86,707✔
943
                if let Some(array_t) = context.types.get_mut(&array_oid) {
86,707✔
944
                    if let Some(array) = Arc::get_mut(array_t) {
86,707✔
945
                        // It should be possible to get a mutable reference to type at this point
86,707✔
946
                        // as there are no other references to this Arc at this point.
86,707✔
947
                        array.details = Some(TypeDetails::Element(element_t.clone()));
86,707✔
948
                    } else {
86,707✔
949
                        // For some reason, we weren't able to get it. It means something have changed
950
                        // in our logic and we're presenting an assertion violation. Let's report it.
951
                        // It's a bug.
952
                        pgrx::warning!(
×
953
                            "Assertion violation: array type with OID {} is already referenced",
×
954
                            array_oid
955
                        );
956
                        continue;
×
957
                    }
958
                    // Put the element type back. NB: Very important to keep this line! It'll be used
959
                    // further down the loop. There is a check at the end of each loop's iteration that
960
                    // we actually did this. Being defensive.
961
                    context.types.insert(element_oid, element_t);
86,707✔
962
                } else {
963
                    // We weren't able to find the OID of the array, which is odd because we just got
964
                    // it from the context. This means we messed something up and it is a bug. Report it.
965
                    pgrx::warning!(
×
966
                        "Assertion violation: array type with OID {} is not found",
×
967
                        array_oid
968
                    );
969
                    continue;
×
970
                }
971
            } else {
972
                // We weren't able to find the OID of the element type, which is also odd because we just got
973
                // it from the context. This means it's a bug as well. Report it.
974
                pgrx::warning!(
×
975
                        "Assertion violation: referenced element type with OID {} of array type with OID {} is not found",
×
976
                        element_oid, array_oid);
977
                continue;
×
978
            }
979

980
            // Here we are asserting that we did in fact return the element type back to the list. Part of being
981
            // defensive here.
982
            if !context.types.contains_key(&element_oid) {
86,707✔
983
                pgrx::warning!("Assertion violation: referenced element type with OID {} was not returned to the list of types", element_oid );
×
984
                continue;
×
985
            }
86,707✔
986
        }
987

988
        context
295✔
989
    }
295✔
990

991
    /// This pass cross-reference column types
992
    fn column_types(mut context: Context) -> Context {
295✔
993
        // We process tables to cross-reference their columns' types
994
        for (_oid, table) in context.tables.iter_mut() {
295✔
995
            if let Some(mtable) = Arc::get_mut(table) {
238✔
996
                // It should be possible to get a mutable reference to table at this point
997
                // as there are no other references to this Arc at this point.
998

999
                // We will now iterate over columns
1000
                for column in mtable.columns.iter_mut() {
821✔
1001
                    if let Some(mcolumn) = Arc::get_mut(column) {
821✔
1002
                        // It should be possible to get a mutable reference to column at this point
821✔
1003
                        // as there are no other references to this Arc at this point.
821✔
1004

821✔
1005
                        // Find a matching type
821✔
1006
                        mcolumn.type_ = context.types.get(&mcolumn.type_oid).cloned();
821✔
1007
                    }
821✔
1008
                }
1009
            }
×
1010
        }
1011
        context
295✔
1012
    }
295✔
1013

1014
    /// This pass populates functions for tables
1015
    fn populate_table_functions(mut context: Context) -> Context {
295✔
1016
        let mut arg_type_to_func: HashMap<u32, Vec<&Arc<Function>>> = HashMap::new();
295✔
1017
        for function in context.functions.iter().filter(|f| f.num_args == 1) {
1,446✔
1018
            let functions = arg_type_to_func.entry(function.arg_types[0]).or_default();
235✔
1019
            functions.push(function);
235✔
1020
        }
235✔
1021
        for table in &mut context.tables.values_mut() {
295✔
1022
            if let Some(table) = Arc::get_mut(table) {
238✔
1023
                if let Some(functions) = arg_type_to_func.get(&table.reltype) {
238✔
1024
                    for function in functions {
71✔
1025
                        table.functions.push(Arc::clone(function));
36✔
1026
                    }
36✔
1027
                }
203✔
1028
            }
×
1029
        }
1030
        context
295✔
1031
    }
295✔
1032

1033
    context
297✔
1034
        .map(type_details)
297✔
1035
        .map(column_types)
297✔
1036
        .map(populate_table_functions)
297✔
1037
        .map(Arc::new)
297✔
1038
        .map_err(|e| {
297✔
1039
            format!(
2✔
1040
                "Error while loading schema, check comment directives. {}",
2✔
1041
                e
1042
            )
1043
        })
2✔
1044
}
297✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc