• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / pg_replicate / 15162015069

21 May 2025 12:20PM UTC coverage: 53.936% (+13.9%) from 40.016%
15162015069

push

github

web-flow
Merge pull request #121 from supabase/riccardo/feat/add-integration-test

318 of 422 new or added lines in 14 files covered. (75.36%)

2 existing lines in 2 files now uncovered.

3494 of 6478 relevant lines covered (53.94%)

26.11 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.96
/pg_replicate/src/clients/postgres.rs
1
use std::collections::HashMap;
2

3
use pg_escape::{quote_identifier, quote_literal};
4
use postgres::schema::{ColumnSchema, TableId, TableName, TableSchema};
5
use postgres::tokio::options::PgDatabaseOptions;
6
use postgres_replication::LogicalReplicationStream;
7
use rustls::{pki_types::CertificateDer, ClientConfig};
8
use thiserror::Error;
9
use tokio_postgres::{
10
    config::ReplicationMode,
11
    types::{Kind, PgLsn, Type},
12
    Client as PostgresClient, Config, CopyOutStream, NoTls, SimpleQueryMessage,
13
};
14
use tokio_postgres_rustls::MakeRustlsConnect;
15
use tracing::{info, warn};
16

17
pub struct SlotInfo {
18
    pub confirmed_flush_lsn: PgLsn,
19
}
20

21
/// A client for Postgres logical replication
22
pub struct ReplicationClient {
23
    postgres_client: PostgresClient,
24
    in_txn: bool,
25
}
26

27
#[derive(Debug, Error)]
28
pub enum ReplicationClientError {
29
    #[error("tokio_postgres error: {0}")]
30
    TokioPostgresError(#[from] tokio_postgres::Error),
31

32
    #[error("column {0} is missing from table {1}")]
33
    MissingColumn(String, String),
34

35
    #[error("publication {0} doesn't exist")]
36
    MissingPublication(String),
37

38
    #[error("oid column is not a valid u32")]
39
    OidColumnNotU32,
40

41
    #[error("replica identity '{0}' not supported")]
42
    ReplicaIdentityNotSupported(String),
43

44
    #[error("type modifier column is not a valid u32")]
45
    TypeModifierColumnNotI32,
46

47
    #[error("column {0}'s type with oid {1} in relation {2} is not supported")]
48
    UnsupportedType(String, u32, String),
49

50
    #[error("table {0} doesn't exist")]
51
    MissingTable(TableName),
52

53
    #[error("not a valid PgLsn")]
54
    InvalidPgLsn,
55

56
    #[error("failed to create slot")]
57
    FailedToCreateSlot,
58

59
    #[error("rustls error: {0}")]
60
    RustlsError(#[from] rustls::Error),
61
}
62

63
impl ReplicationClient {
64
    /// Connect to a postgres database in logical replication mode without TLS
65
    pub async fn connect_no_tls(
3✔
66
        options: PgDatabaseOptions,
3✔
67
    ) -> Result<ReplicationClient, ReplicationClientError> {
3✔
68
        info!("connecting to postgres without TLS");
3✔
69

70
        let mut config: Config = options.into();
3✔
71
        config.replication_mode(ReplicationMode::Logical);
3✔
72

73
        let (postgres_client, connection) = config.connect(NoTls).await?;
3✔
74

75
        tokio::spawn(async move {
3✔
76
            info!("waiting for connection to terminate");
3✔
77
            if let Err(e) = connection.await {
3✔
78
                warn!("connection error: {}", e);
×
NEW
79
                return;
×
80
            }
3✔
81
            info!("connection terminated successfully")
3✔
82
        });
3✔
83

3✔
84
        info!("successfully connected to postgres");
3✔
85

86
        Ok(ReplicationClient {
3✔
87
            postgres_client,
3✔
88
            in_txn: false,
3✔
89
        })
3✔
90
    }
3✔
91

92
    /// Connect to a postgres database in logical replication mode with TLS
93
    pub async fn connect_tls(
×
NEW
94
        options: PgDatabaseOptions,
×
95
        trusted_root_certs: Vec<CertificateDer<'static>>,
×
96
    ) -> Result<ReplicationClient, ReplicationClientError> {
×
97
        info!("connecting to postgres with TLS");
×
98

NEW
99
        let mut config: Config = options.into();
×
NEW
100
        config.replication_mode(ReplicationMode::Logical);
×
101

×
102
        let mut root_store = rustls::RootCertStore::empty();
×
103
        for trusted_root_cert in trusted_root_certs {
×
104
            root_store.add(trusted_root_cert)?;
×
105
        }
106
        let tls_config = ClientConfig::builder()
×
107
            .with_root_certificates(root_store)
×
108
            .with_no_client_auth();
×
109

×
110
        let tls = MakeRustlsConnect::new(tls_config);
×
111

112
        let (postgres_client, connection) = config.connect(tls).await?;
×
113

114
        tokio::spawn(async move {
×
115
            info!("waiting for connection to terminate");
×
116
            if let Err(e) = connection.await {
×
117
                warn!("connection error: {}", e);
×
NEW
118
                return;
×
119
            }
×
NEW
120
            info!("connection terminated successfully")
×
121
        });
×
122

×
123
        info!("successfully connected to postgres");
×
124

125
        Ok(ReplicationClient {
×
126
            postgres_client,
×
127
            in_txn: false,
×
128
        })
×
129
    }
×
130

131
    /// Starts a read-only trasaction with repeatable read isolation level
132
    pub async fn begin_readonly_transaction(&mut self) -> Result<(), ReplicationClientError> {
4✔
133
        self.postgres_client
4✔
134
            .simple_query("begin read only isolation level repeatable read;")
4✔
135
            .await?;
4✔
136
        self.in_txn = true;
4✔
137
        Ok(())
4✔
138
    }
4✔
139

140
    /// Commits a transaction
141
    pub async fn commit_txn(&mut self) -> Result<(), ReplicationClientError> {
3✔
142
        if self.in_txn {
3✔
143
            self.postgres_client.simple_query("commit;").await?;
3✔
144
            self.in_txn = false;
3✔
145
        }
×
146
        Ok(())
3✔
147
    }
3✔
148

149
    /// Rolls back the current transaction.
    ///
    /// Does nothing when no transaction was started through this client.
    async fn rollback_txn(&mut self) -> Result<(), ReplicationClientError> {
        if !self.in_txn {
            return Ok(());
        }

        self.postgres_client.simple_query("rollback;").await?;
        self.in_txn = false;
        Ok(())
    }
1✔
156

157
    /// Returns a [CopyOutStream] for a table
158
    pub async fn get_table_copy_stream(
2✔
159
        &self,
2✔
160
        table_name: &TableName,
2✔
161
        column_schemas: &[ColumnSchema],
2✔
162
    ) -> Result<CopyOutStream, ReplicationClientError> {
2✔
163
        let column_list = column_schemas
2✔
164
            .iter()
2✔
165
            .map(|col| quote_identifier(&col.name))
4✔
166
            .collect::<Vec<_>>()
2✔
167
            .join(", ");
2✔
168

2✔
169
        let copy_query = format!(
2✔
170
            r#"COPY {} ({column_list}) TO STDOUT WITH (FORMAT text);"#,
2✔
171
            table_name.as_quoted_identifier(),
2✔
172
        );
2✔
173

174
        let stream = self.postgres_client.copy_out_simple(&copy_query).await?;
2✔
175

176
        Ok(stream)
2✔
177
    }
2✔
178

179
    /// Returns a vector of columns of a table, optionally filtered by a publication's column list
180
    pub async fn get_column_schemas(
3✔
181
        &self,
3✔
182
        table_id: TableId,
3✔
183
        publication: Option<&str>,
3✔
184
    ) -> Result<Vec<ColumnSchema>, ReplicationClientError> {
3✔
185
        let (pub_cte, pub_pred) = if let Some(publication) = publication {
3✔
186
            (
1✔
187
                format!(
1✔
188
                    "with pub_attrs as (
1✔
189
                        select unnest(r.prattrs)
1✔
190
                        from pg_publication_rel r
1✔
191
                        left join pg_publication p on r.prpubid = p.oid
1✔
192
                        where p.pubname = {publication}
1✔
193
                        and r.prrelid = {table_id}
1✔
194
                    )",
1✔
195
                    publication = quote_literal(publication),
1✔
196
                ),
1✔
197
                "and (
1✔
198
                    case (select count(*) from pub_attrs)
1✔
199
                    when 0 then true
1✔
200
                    else (a.attnum in (select * from pub_attrs))
1✔
201
                    end
1✔
202
                )",
1✔
203
            )
1✔
204
        } else {
205
            ("".into(), "")
2✔
206
        };
207

208
        let column_info_query = format!(
3✔
209
            "{pub_cte}
3✔
210
            select a.attname,
3✔
211
                a.atttypid,
3✔
212
                a.atttypmod,
3✔
213
                a.attnotnull,
3✔
214
                coalesce(i.indisprimary, false) as primary
3✔
215
            from pg_attribute a
3✔
216
            left join pg_index i
3✔
217
                on a.attrelid = i.indrelid
3✔
218
                and a.attnum = any(i.indkey)
3✔
219
                and i.indisprimary = true
3✔
220
            where a.attnum > 0::int2
3✔
221
            and not a.attisdropped
3✔
222
            and a.attgenerated = ''
3✔
223
            and a.attrelid = {table_id}
3✔
224
            {pub_pred}
3✔
225
            order by a.attnum
3✔
226
            ",
3✔
227
        );
3✔
228

3✔
229
        let mut column_schemas = vec![];
3✔
230

231
        for message in self
12✔
232
            .postgres_client
3✔
233
            .simple_query(&column_info_query)
3✔
234
            .await?
3✔
235
        {
236
            if let SimpleQueryMessage::Row(row) = message {
12✔
237
                let name = row
6✔
238
                    .try_get("attname")?
6✔
239
                    .ok_or(ReplicationClientError::MissingColumn(
6✔
240
                        "attname".to_string(),
6✔
241
                        "pg_attribute".to_string(),
6✔
242
                    ))?
6✔
243
                    .to_string();
6✔
244

245
                let type_oid = row
6✔
246
                    .try_get("atttypid")?
6✔
247
                    .ok_or(ReplicationClientError::MissingColumn(
6✔
248
                        "atttypid".to_string(),
6✔
249
                        "pg_attribute".to_string(),
6✔
250
                    ))?
6✔
251
                    .parse()
6✔
252
                    .map_err(|_| ReplicationClientError::OidColumnNotU32)?;
6✔
253

254
                //TODO: For now we assume all types are simple, fix it later
255
                let typ = Type::from_oid(type_oid).unwrap_or(Type::new(
6✔
256
                    format!("unnamed(oid: {type_oid})"),
6✔
257
                    type_oid,
6✔
258
                    Kind::Simple,
6✔
259
                    "pg_catalog".to_string(),
6✔
260
                ));
6✔
261

262
                let modifier = row
6✔
263
                    .try_get("atttypmod")?
6✔
264
                    .ok_or(ReplicationClientError::MissingColumn(
6✔
265
                        "atttypmod".to_string(),
6✔
266
                        "pg_attribute".to_string(),
6✔
267
                    ))?
6✔
268
                    .parse()
6✔
269
                    .map_err(|_| ReplicationClientError::TypeModifierColumnNotI32)?;
6✔
270

271
                let nullable =
6✔
272
                    row.try_get("attnotnull")?
6✔
273
                        .ok_or(ReplicationClientError::MissingColumn(
6✔
274
                            "attnotnull".to_string(),
6✔
275
                            "pg_attribute".to_string(),
6✔
276
                        ))?
6✔
277
                        == "f";
6✔
278

279
                let primary =
6✔
280
                    row.try_get("primary")?
6✔
281
                        .ok_or(ReplicationClientError::MissingColumn(
6✔
282
                            "indisprimary".to_string(),
6✔
283
                            "pg_index".to_string(),
6✔
284
                        ))?
6✔
285
                        == "t";
6✔
286

6✔
287
                column_schemas.push(ColumnSchema {
6✔
288
                    name,
6✔
289
                    typ,
6✔
290
                    modifier,
6✔
291
                    nullable,
6✔
292
                    primary,
6✔
293
                })
6✔
294
            }
6✔
295
        }
296

297
        Ok(column_schemas)
3✔
298
    }
3✔
299

300
    pub async fn get_table_schemas(
3✔
301
        &self,
3✔
302
        table_names: &[TableName],
3✔
303
        publication: Option<&str>,
3✔
304
    ) -> Result<HashMap<TableId, TableSchema>, ReplicationClientError> {
3✔
305
        let mut table_schemas = HashMap::new();
3✔
306

307
        for table_name in table_names {
6✔
308
            let table_schema = self
3✔
309
                .get_table_schema(table_name.clone(), publication)
3✔
310
                .await?;
3✔
311
            if !table_schema.has_primary_keys() {
3✔
312
                warn!(
×
313
                    "table {} with id {} will not be copied because it has no primary key",
×
314
                    table_schema.table_name, table_schema.table_id
315
                );
316
                continue;
×
317
            }
3✔
318
            table_schemas.insert(table_schema.table_id, table_schema);
3✔
319
        }
320

321
        Ok(table_schemas)
3✔
322
    }
3✔
323

324
    async fn get_table_schema(
3✔
325
        &self,
3✔
326
        table_name: TableName,
3✔
327
        publication: Option<&str>,
3✔
328
    ) -> Result<TableSchema, ReplicationClientError> {
3✔
329
        let table_id = self
3✔
330
            .get_table_id(&table_name)
3✔
331
            .await?
3✔
332
            .ok_or(ReplicationClientError::MissingTable(table_name.clone()))?;
3✔
333
        let column_schemas = self.get_column_schemas(table_id, publication).await?;
3✔
334
        Ok(TableSchema {
3✔
335
            table_name,
3✔
336
            table_id,
3✔
337
            column_schemas,
3✔
338
        })
3✔
339
    }
3✔
340

341
    /// Returns the table id (called relation id in Postgres) of a table
342
    /// Also checks whether the replica identity is default or full and
343
    /// returns an error if not.
344
    pub async fn get_table_id(
3✔
345
        &self,
3✔
346
        table: &TableName,
3✔
347
    ) -> Result<Option<TableId>, ReplicationClientError> {
3✔
348
        let quoted_schema = quote_literal(&table.schema);
3✔
349
        let quoted_name = quote_literal(&table.name);
3✔
350

3✔
351
        let table_info_query = format!(
3✔
352
            "select c.oid,
3✔
353
                c.relreplident
3✔
354
            from pg_class c
3✔
355
            join pg_namespace n
3✔
356
                on (c.relnamespace = n.oid)
3✔
357
            where n.nspname = {}
3✔
358
                and c.relname = {}
3✔
359
            ",
3✔
360
            quoted_schema, quoted_name
3✔
361
        );
3✔
362

363
        for message in self.postgres_client.simple_query(&table_info_query).await? {
6✔
364
            if let SimpleQueryMessage::Row(row) = message {
6✔
365
                let replica_identity =
3✔
366
                    row.try_get("relreplident")?
3✔
367
                        .ok_or(ReplicationClientError::MissingColumn(
3✔
368
                            "relreplident".to_string(),
3✔
369
                            "pg_class".to_string(),
3✔
370
                        ))?;
3✔
371

372
                if !(replica_identity == "d" || replica_identity == "f") {
3✔
373
                    return Err(ReplicationClientError::ReplicaIdentityNotSupported(
×
374
                        replica_identity.to_string(),
×
375
                    ));
×
376
                }
3✔
377

378
                let oid: u32 = row
3✔
379
                    .try_get("oid")?
3✔
380
                    .ok_or(ReplicationClientError::MissingColumn(
3✔
381
                        "oid".to_string(),
3✔
382
                        "pg_class".to_string(),
3✔
383
                    ))?
3✔
384
                    .parse()
3✔
385
                    .map_err(|_| ReplicationClientError::OidColumnNotU32)?;
3✔
386
                return Ok(Some(oid));
3✔
387
            }
3✔
388
        }
389

390
        Ok(None)
×
391
    }
3✔
392

393
    /// Returns the slot info of an existing slot. The slot info currently only has the
394
    /// confirmed_flush_lsn column of the pg_replication_slots table.
395
    async fn get_slot(&self, slot_name: &str) -> Result<Option<SlotInfo>, ReplicationClientError> {
1✔
396
        let query = format!(
1✔
397
            r#"select confirmed_flush_lsn from pg_replication_slots where slot_name = {};"#,
1✔
398
            quote_literal(slot_name)
1✔
399
        );
1✔
400

401
        let query_result = self.postgres_client.simple_query(&query).await?;
1✔
402

403
        for res in &query_result {
3✔
404
            if let SimpleQueryMessage::Row(row) = res {
2✔
405
                let confirmed_flush_lsn = row
×
406
                    .get("confirmed_flush_lsn")
×
407
                    .ok_or(ReplicationClientError::MissingColumn(
×
408
                        "confirmed_flush_lsn".to_string(),
×
409
                        "pg_replication_slots".to_string(),
×
410
                    ))?
×
411
                    .parse()
×
412
                    .map_err(|_| ReplicationClientError::InvalidPgLsn)?;
×
413

414
                return Ok(Some(SlotInfo {
×
415
                    confirmed_flush_lsn,
×
416
                }));
×
417
            }
2✔
418
        }
419

420
        Ok(None)
1✔
421
    }
1✔
422

423
    /// Creates a logical replication slot. This will only succeed if the postgres connection
424
    /// is in logical replication mode. Otherwise it will fail with the following error:
425
    /// `syntax error at or near "CREATE_REPLICATION_SLOT"``
426
    ///
427
    /// Returns the consistent_point column as slot info.
428
    async fn create_slot(&self, slot_name: &str) -> Result<SlotInfo, ReplicationClientError> {
1✔
429
        let query = format!(
1✔
430
            r#"CREATE_REPLICATION_SLOT {} LOGICAL pgoutput USE_SNAPSHOT"#,
1✔
431
            quote_identifier(slot_name)
1✔
432
        );
1✔
433
        let results = self.postgres_client.simple_query(&query).await?;
1✔
434

435
        for result in results {
2✔
436
            if let SimpleQueryMessage::Row(row) = result {
2✔
437
                let consistent_point: PgLsn = row
1✔
438
                    .get("consistent_point")
1✔
439
                    .ok_or(ReplicationClientError::MissingColumn(
1✔
440
                        "consistent_point".to_string(),
1✔
441
                        "create_replication_slot".to_string(),
1✔
442
                    ))?
1✔
443
                    .parse()
1✔
444
                    .map_err(|_| ReplicationClientError::InvalidPgLsn)?;
1✔
445
                return Ok(SlotInfo {
1✔
446
                    confirmed_flush_lsn: consistent_point,
1✔
447
                });
1✔
448
            }
1✔
449
        }
450
        Err(ReplicationClientError::FailedToCreateSlot)
×
451
    }
1✔
452

453
    /// Either return the slot info of an existing slot or creates a new
454
    /// slot and returns its slot info.
455
    pub async fn get_or_create_slot(
1✔
456
        &mut self,
1✔
457
        slot_name: &str,
1✔
458
    ) -> Result<SlotInfo, ReplicationClientError> {
1✔
459
        if let Some(slot_info) = self.get_slot(slot_name).await? {
1✔
460
            Ok(slot_info)
×
461
        } else {
462
            self.rollback_txn().await?;
1✔
463
            self.begin_readonly_transaction().await?;
1✔
464
            Ok(self.create_slot(slot_name).await?)
1✔
465
        }
466
    }
1✔
467

468
    /// Returns all table names in a publication
469
    pub async fn get_publication_table_names(
1✔
470
        &self,
1✔
471
        publication: &str,
1✔
472
    ) -> Result<Vec<TableName>, ReplicationClientError> {
1✔
473
        let publication_query = format!(
1✔
474
            "select schemaname, tablename from pg_publication_tables where pubname = {};",
1✔
475
            quote_literal(publication)
1✔
476
        );
1✔
477

1✔
478
        let mut table_names = vec![];
1✔
479
        for msg in self
3✔
480
            .postgres_client
1✔
481
            .simple_query(&publication_query)
1✔
482
            .await?
1✔
483
        {
484
            if let SimpleQueryMessage::Row(row) = msg {
3✔
485
                let schema = row
1✔
486
                    .get(0)
1✔
487
                    .ok_or(ReplicationClientError::MissingColumn(
1✔
488
                        "schemaname".to_string(),
1✔
489
                        "pg_publication_tables".to_string(),
1✔
490
                    ))?
1✔
491
                    .to_string();
1✔
492

493
                let name = row
1✔
494
                    .get(1)
1✔
495
                    .ok_or(ReplicationClientError::MissingColumn(
1✔
496
                        "tablename".to_string(),
1✔
497
                        "pg_publication_tables".to_string(),
1✔
498
                    ))?
1✔
499
                    .to_string();
1✔
500

1✔
501
                table_names.push(TableName { schema, name })
1✔
502
            }
2✔
503
        }
504

505
        Ok(table_names)
1✔
506
    }
1✔
507

508
    pub async fn publication_exists(
1✔
509
        &self,
1✔
510
        publication: &str,
1✔
511
    ) -> Result<bool, ReplicationClientError> {
1✔
512
        let publication_exists_query = format!(
1✔
513
            "select 1 as exists from pg_publication where pubname = {};",
1✔
514
            quote_literal(publication)
1✔
515
        );
1✔
516
        for msg in self
2✔
517
            .postgres_client
1✔
518
            .simple_query(&publication_exists_query)
1✔
519
            .await?
1✔
520
        {
521
            if let SimpleQueryMessage::Row(_) = msg {
2✔
522
                return Ok(true);
1✔
523
            }
1✔
524
        }
525
        Ok(false)
×
526
    }
1✔
527

528
    pub async fn get_logical_replication_stream(
1✔
529
        &self,
1✔
530
        publication: &str,
1✔
531
        slot_name: &str,
1✔
532
        start_lsn: PgLsn,
1✔
533
    ) -> Result<LogicalReplicationStream, ReplicationClientError> {
1✔
534
        let options = format!(
1✔
535
            r#"("proto_version" '1', "publication_names" {})"#,
1✔
536
            quote_literal(quote_identifier(publication).as_ref()),
1✔
537
        );
1✔
538

1✔
539
        let query = format!(
1✔
540
            r#"START_REPLICATION SLOT {} LOGICAL {} {}"#,
1✔
541
            quote_identifier(slot_name),
1✔
542
            start_lsn,
1✔
543
            options
1✔
544
        );
1✔
545

546
        let copy_stream = self
1✔
547
            .postgres_client
1✔
548
            .copy_both_simple::<bytes::Bytes>(&query)
1✔
549
            .await?;
1✔
550

551
        let stream = LogicalReplicationStream::new(copy_stream);
1✔
552

1✔
553
        Ok(stream)
1✔
554
    }
1✔
555
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc