• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

gatewayd-io / gatewayd / 22284447224

22 Feb 2026 08:10PM UTC coverage: 59.676% (+0.2%) from 59.468%
22284447224

Pull #731

github

mostafa
Fetch latest version
Pull Request #731: Extensive plugin tests

4 of 4 new or added lines in 1 file covered. (100.0%)

38 existing lines in 4 files now uncovered.

5791 of 9704 relevant lines covered (59.68%)

17.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.31
/network/proxy.go
1
package network
2

3
import (
4
        "bytes"
5
        "context"
6
        "errors"
7
        "io"
8
        "net"
9
        "time"
10

11
        "github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
12
        v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
13
        "github.com/gatewayd-io/gatewayd/config"
14
        gerr "github.com/gatewayd-io/gatewayd/errors"
15
        "github.com/gatewayd-io/gatewayd/metrics"
16
        "github.com/gatewayd-io/gatewayd/plugin"
17
        "github.com/gatewayd-io/gatewayd/pool"
18
        "github.com/getsentry/sentry-go"
19
        "github.com/go-co-op/gocron"
20
        "github.com/rs/zerolog"
21
        "go.opentelemetry.io/otel"
22
        "go.opentelemetry.io/otel/trace"
23
)
24

25
//nolint:interfacebloat
26
type IProxy interface {
27
        Connect(conn *ConnWrapper) *gerr.GatewayDError
28
        Disconnect(conn *ConnWrapper) *gerr.GatewayDError
29
        PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.GatewayDError
30
        PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.GatewayDError
31
        IsHealthy(cl *Client) (*Client, *gerr.GatewayDError)
32
        IsExhausted() bool
33
        Shutdown()
34
        AvailableConnectionsString() []string
35
        BusyConnectionsString() []string
36
        GetGroupName() string
37
        GetBlockName() string
38
        // ExpireBackendReadDeadline expires the backend connection's read deadline
39
        // to unblock any goroutine blocked on Receive().
40
        ExpireBackendReadDeadline(conn *ConnWrapper)
41
        // ClearBackendDeadline clears all deadlines on the backend connection so
42
        // it can be reused for session reset (DISCARD ALL) or other operations.
43
        ClearBackendDeadline(conn *ConnWrapper)
44
}
45

46
type Proxy struct {
47
        GroupName            string
48
        BlockName            string
49
        AvailableConnections pool.IPool
50
        busyConnections      pool.IPool
51
        Logger               zerolog.Logger
52
        PluginRegistry       *plugin.Registry
53
        scheduler            *gocron.Scheduler
54
        ctx                  context.Context //nolint:containedctx
55
        PluginTimeout        time.Duration
56
        HealthCheckPeriod    time.Duration
57

58
        // ClientConfig is used for reconnection
59
        ClientConfig *config.Client
60
}
61

62
var _ IProxy = (*Proxy)(nil)
63

64
// NewProxy creates a new proxy.
65
func NewProxy(
66
        ctx context.Context,
67
        pxy Proxy,
68
) *Proxy {
3✔
69
        proxyCtx, span := otel.Tracer(config.TracerName).Start(ctx, "NewProxy")
3✔
70
        defer span.End()
3✔
71

3✔
72
        proxy := Proxy{
3✔
73
                GroupName:            pxy.GroupName,
3✔
74
                BlockName:            pxy.BlockName,
3✔
75
                AvailableConnections: pxy.AvailableConnections,
3✔
76
                busyConnections:      pool.NewPool(proxyCtx, config.EmptyPoolCapacity),
3✔
77
                Logger:               pxy.Logger,
3✔
78
                PluginRegistry:       pxy.PluginRegistry,
3✔
79
                scheduler:            gocron.NewScheduler(time.UTC),
3✔
80
                ctx:                  proxyCtx,
3✔
81
                PluginTimeout:        pxy.PluginTimeout,
3✔
82
                ClientConfig:         pxy.ClientConfig,
3✔
83
                HealthCheckPeriod:    pxy.HealthCheckPeriod,
3✔
84
        }
3✔
85

3✔
86
        connHealthCheck := func() {
3✔
87
                now := time.Now()
×
88
                proxy.Logger.Trace().Msg("Running the client health check to recycle connection(s).")
×
89
                span.AddEvent("Running the client health check to recycle connection(s).")
×
90
                proxy.AvailableConnections.ForEach(func(_, value any) bool {
×
91
                        client, ok := value.(*Client)
×
92
                        if !ok {
×
93
                                proxy.Logger.Error().Msg("Failed to cast the client to the Client type")
×
94
                                return true
×
95
                        }
×
96

97
                        // Connection is probably dead by now.
98
                        proxy.AvailableConnections.Remove(client.ID)
×
99
                        client.Close()
×
100

×
101
                        // Create a new client.
×
102
                        client = NewClient(
×
103
                                proxyCtx, proxy.ClientConfig, proxy.Logger,
×
104
                                NewRetry(
×
105
                                        Retry{
×
106
                                                Retries: proxy.ClientConfig.Retries,
×
107
                                                Backoff: config.If(
×
108
                                                        proxy.ClientConfig.Backoff > 0,
×
109
                                                        proxy.ClientConfig.Backoff,
×
110
                                                        config.DefaultBackoff,
×
111
                                                ),
×
112
                                                BackoffMultiplier:  proxy.ClientConfig.BackoffMultiplier,
×
113
                                                DisableBackoffCaps: proxy.ClientConfig.DisableBackoffCaps,
×
114
                                                Logger:             proxy.Logger,
×
115
                                        },
×
116
                                ),
×
117
                        )
×
118
                        if client != nil && client.ID != "" {
×
119
                                if err := proxy.AvailableConnections.Put(client.ID, client); err != nil {
×
120
                                        proxy.Logger.Err(err).Msg("Failed to update the client connection")
×
121
                                        // Close the client, because we don't want to have orphaned connections.
×
122
                                        client.Close()
×
123
                                }
×
124
                        } else {
×
125
                                proxy.Logger.Error().Msg("Failed to create a new client connection")
×
126
                                span.RecordError(gerr.ErrClientNotConnected)
×
127
                        }
×
128
                        return true
×
129
                })
130
                proxy.Logger.Trace().
×
131
                        Str("duration", time.Since(now).String()).
×
132
                        Msg("Finished the client health check")
×
133
                span.AddEvent("Finished the client health check")
×
134
                metrics.ProxyHealthChecks.WithLabelValues(
×
135
                        proxy.GetGroupName(), proxy.GetBlockName()).Inc()
×
136
        }
137

138
        // Schedule the client health check.
139
        startDelay := time.Now().Add(proxy.HealthCheckPeriod)
3✔
140
        _, err := proxy.scheduler.
3✔
141
                Every(proxy.HealthCheckPeriod).
3✔
142
                SingletonMode().
3✔
143
                StartAt(startDelay).
3✔
144
                Do(connHealthCheck)
3✔
145
        if err != nil {
3✔
146
                proxy.Logger.Error().Err(err).Msg("Failed to schedule the client health check")
×
147
                sentry.CaptureException(err)
×
148
                span.RecordError(err)
×
149
        }
×
150

151
        // Start the scheduler.
152
        proxy.scheduler.StartAsync()
3✔
153
        proxy.Logger.Info().Fields(
3✔
154
                map[string]any{
3✔
155
                        "startDelay":        startDelay.Format(time.RFC3339),
3✔
156
                        "healthCheckPeriod": proxy.HealthCheckPeriod.String(),
3✔
157
                },
3✔
158
        ).Msg("Started the client health check scheduler")
3✔
159
        span.AddEvent("Started the client health check scheduler")
3✔
160

3✔
161
        return &proxy
3✔
162
}
163

164
func (pr *Proxy) GetBlockName() string {
48✔
165
        return pr.BlockName
48✔
166
}
48✔
167

168
func (pr *Proxy) GetGroupName() string {
42✔
169
        return pr.GroupName
42✔
170
}
42✔
171

172
// Connect maps a server connection from the available connection pool to a incoming connection.
173
// It returns an error if the pool is exhausted.
174
func (pr *Proxy) Connect(conn *ConnWrapper) *gerr.GatewayDError {
2✔
175
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "Connect")
2✔
176
        defer span.End()
2✔
177

2✔
178
        var clientID string
2✔
179
        // Get the first available client from the pool.
2✔
180
        pr.AvailableConnections.ForEach(func(key, _ any) bool {
4✔
181
                if cid, ok := key.(string); ok {
4✔
182
                        clientID = cid
2✔
183
                        return false // stop the loop.
2✔
184
                }
2✔
185
                return true
×
186
        })
187

188
        var client *Client
2✔
189
        if pr.IsExhausted() {
2✔
190
                // Pool is exhausted
×
191
                span.AddEvent(gerr.ErrPoolExhausted.Error())
×
192
                return gerr.ErrPoolExhausted
×
193
        }
×
194
        // Get the client from the pool with the given clientID.
195
        if cl, ok := pr.AvailableConnections.Pop(clientID).(*Client); ok {
4✔
196
                client = cl
2✔
197
        }
2✔
198

199
        client, err := pr.IsHealthy(client)
2✔
200
        if err != nil {
2✔
201
                pr.Logger.Error().Err(err).Msg("Failed to connect to the client")
×
202
                span.RecordError(err)
×
203
        }
×
204

205
        if err := pr.busyConnections.Put(conn, client); err != nil {
2✔
206
                // This should never happen.
×
207
                span.RecordError(err)
×
208
                return err
×
209
        }
×
210

211
        metrics.ProxiedConnections.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()
2✔
212

2✔
213
        fields := map[string]any{
2✔
214
                "function": "proxy.connect",
2✔
215
                "client":   "unknown",
2✔
216
                "server":   RemoteAddr(conn.Conn()),
2✔
217
        }
2✔
218
        if client.ID != "" {
4✔
219
                fields["client"] = client.ID[:7]
2✔
220
        }
2✔
221
        pr.Logger.Debug().Fields(fields).Msg("Client has been assigned")
2✔
222

2✔
223
        pr.Logger.Debug().Fields(
2✔
224
                map[string]any{
2✔
225
                        "function": "proxy.connect",
2✔
226
                        "count":    pr.AvailableConnections.Size(),
2✔
227
                },
2✔
228
        ).Msg("Available client connections")
2✔
229
        pr.Logger.Debug().Fields(
2✔
230
                map[string]any{
2✔
231
                        "function": "proxy.connect",
2✔
232
                        "count":    pr.busyConnections.Size(),
2✔
233
                },
2✔
234
        ).Msg("Busy client connections")
2✔
235

2✔
236
        return nil
2✔
237
}
238

239
// Disconnect removes the client from the busy connection pool and tries to recycle
240
// the server connection.
241
func (pr *Proxy) Disconnect(conn *ConnWrapper) *gerr.GatewayDError {
2✔
242
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "Disconnect")
2✔
243
        defer span.End()
2✔
244

2✔
245
        client := pr.busyConnections.Pop(conn)
2✔
246
        if client == nil {
2✔
UNCOV
247
                // If this ever happens, it means that the client connection
×
UNCOV
248
                // is pre-empted from the busy connections pool.
×
UNCOV
249
                pr.Logger.Debug().Msg("Client connection is pre-empted from the busy connections pool")
×
UNCOV
250
                span.RecordError(gerr.ErrClientNotFound)
×
UNCOV
251
                return gerr.ErrClientNotFound
×
UNCOV
252
        }
×
253

254
        if client, ok := client.(*Client); ok {
4✔
255
                pr.recycleClientConnection(client, span)
2✔
256
        } else {
2✔
257
                // This should never happen, but if it does,
×
258
                // then there are some serious issues with the pool.
×
259
                pr.Logger.Error().Msg("Failed to cast the client to the Client type")
×
260
                span.RecordError(gerr.ErrCastFailed)
×
261
                return gerr.ErrCastFailed
×
262
        }
×
263

264
        metrics.ProxiedConnections.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Dec()
2✔
265

2✔
266
        pr.Logger.Debug().Fields(
2✔
267
                map[string]any{
2✔
268
                        "function": "proxy.disconnect",
2✔
269
                        "count":    pr.AvailableConnections.Size(),
2✔
270
                },
2✔
271
        ).Msg("Available client connections")
2✔
272
        pr.Logger.Debug().Fields(
2✔
273
                map[string]any{
2✔
274
                        "function": "proxy.disconnect",
2✔
275
                        "count":    pr.busyConnections.Size(),
2✔
276
                },
2✔
277
        ).Msg("Busy client connections")
2✔
278

2✔
279
        return nil
2✔
280
}
281

282
// PassThroughToServer sends the data from the client to the server.
283
func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.GatewayDError {
6✔
284
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough")
6✔
285
        defer span.End()
6✔
286

6✔
287
        var client *Client
6✔
288
        // Check if the proxy has a egress client for the incoming connection.
6✔
289
        if pr.busyConnections.Get(conn) == nil {
6✔
290
                span.RecordError(gerr.ErrClientNotFound)
×
291
                return gerr.ErrClientNotFound
×
292
        }
×
293

294
        // Get the client from the busy connection pool.
295
        if cl, ok := pr.busyConnections.Get(conn).(*Client); ok {
12✔
296
                client = cl
6✔
297
        } else {
6✔
298
                span.RecordError(gerr.ErrCastFailed)
×
299
                return gerr.ErrCastFailed
×
300
        }
×
301
        span.AddEvent("Got the client from the busy connection pool")
6✔
302

6✔
303
        if !client.IsConnected() {
6✔
304
                return gerr.ErrClientNotConnected
×
305
        }
×
306

307
        // Receive the request from the client.
308
        request, origErr := pr.receiveTrafficFromClient(conn.Conn())
6✔
309
        span.AddEvent("Received traffic from client")
6✔
310

6✔
311
        // Run the OnTrafficFromClient hooks.
6✔
312
        pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.PluginTimeout)
6✔
313
        defer cancel()
6✔
314

6✔
315
        result, err := pr.PluginRegistry.Run(
6✔
316
                pluginTimeoutCtx,
6✔
317
                trafficData(
6✔
318
                        conn.Conn(),
6✔
319
                        client,
6✔
320
                        []Field{
6✔
321
                                {
6✔
322
                                        Name:  "request",
6✔
323
                                        Value: request,
6✔
324
                                },
6✔
325
                        },
6✔
326
                        origErr),
6✔
327
                v1.HookName_HOOK_NAME_ON_TRAFFIC_FROM_CLIENT)
6✔
328
        if err != nil {
6✔
329
                pr.Logger.Error().Err(err).Msg("Error running hook")
×
330
                span.RecordError(err)
×
331
        }
×
332
        span.AddEvent("Ran the OnTrafficFromClient hooks")
6✔
333

6✔
334
        if origErr != nil && errors.Is(origErr, io.EOF) {
8✔
335
                // Client closed the connection.
2✔
336
                span.AddEvent("Client closed the connection")
2✔
337
                return gerr.ErrClientNotConnected.Wrap(origErr)
2✔
338
        }
2✔
339

340
        // Check if the client sent a SSL request and the server supports SSL.
341
        //nolint:nestif
342
        if conn.IsTLSEnabled() && postgres.IsPostgresSSLRequest(request) {
4✔
343
                // Perform TLS handshake.
×
344
                if err := conn.UpgradeToTLS(func(net.Conn) {
×
345
                        // Acknowledge the SSL request:
×
346
                        // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SSL
×
347
                        if sent, err := conn.Write([]byte{'S'}); err != nil {
×
348
                                pr.Logger.Error().Err(err).Msg("Failed to acknowledge the SSL request")
×
349
                                span.RecordError(err)
×
350
                        } else {
×
351
                                pr.Logger.Debug().Fields(
×
352
                                        map[string]any{
×
353
                                                "function": "upgradeToTLS",
×
354
                                                "local":    LocalAddr(conn.Conn()),
×
355
                                                "remote":   RemoteAddr(conn.Conn()),
×
356
                                                "length":   sent,
×
357
                                        },
×
358
                                ).Msg("Sent data to database")
×
359
                        }
×
360
                }); err != nil {
×
361
                        pr.Logger.Error().Err(err).Msg("Failed to perform the TLS handshake")
×
362
                        span.RecordError(err)
×
363
                }
×
364

365
                // Check if the TLS handshake was successful.
366
                if conn.IsTLSEnabled() {
×
367
                        pr.Logger.Debug().Fields(
×
368
                                map[string]any{
×
369
                                        "local":  LocalAddr(conn.Conn()),
×
370
                                        "remote": RemoteAddr(conn.Conn()),
×
371
                                },
×
372
                        ).Msg("Performed the TLS handshake")
×
373
                        span.AddEvent("Performed the TLS handshake")
×
374
                        metrics.TLSConnections.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()
×
375
                } else {
×
376
                        pr.Logger.Error().Fields(
×
377
                                map[string]any{
×
378
                                        "local":  LocalAddr(conn.Conn()),
×
379
                                        "remote": RemoteAddr(conn.Conn()),
×
380
                                },
×
381
                        ).Msg("Failed to perform the TLS handshake")
×
382
                        span.AddEvent("Failed to perform the TLS handshake")
×
383
                }
×
384

385
                // This return causes the client to start sending
386
                // StartupMessage over the TLS connection.
387
                return nil
×
388
        } else if !conn.IsTLSEnabled() && postgres.IsPostgresSSLRequest(request) {
4✔
389
                // Client sent a SSL request, but the server does not support SSL.
×
390

×
391
                pr.Logger.Warn().Fields(
×
392
                        map[string]any{
×
393
                                "local":  LocalAddr(conn.Conn()),
×
394
                                "remote": RemoteAddr(conn.Conn()),
×
395
                        },
×
396
                ).Msg("Server does not support SSL, but SSL was requested by the client")
×
397
                span.AddEvent("Server does not support SSL, but SSL was requested by the client")
×
398

×
399
                // Server does not support SSL, and SSL was preferred by the client,
×
400
                // so we need to switch to a plaintext connection:
×
401
                // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SSL
×
402
                if _, err := conn.Write([]byte{'N'}); err != nil {
×
403
                        pr.Logger.Warn().Err(err).Msg("Server does not support SSL, but SSL was required by the client")
×
404
                        span.RecordError(err)
×
405
                }
×
406

407
                // This return causes the client to start sending
408
                // StartupMessage over the plaintext connection.
409
                return nil
×
410
        }
411

412
        // Push the client's request to the stack.
413
        stack.Push(&Request{Data: request})
4✔
414

4✔
415
        // If the hook wants to terminate the connection, do it.
4✔
416
        if terminate, resp := pr.shouldTerminate(result); terminate {
4✔
417
                if resp != nil {
×
418
                        pr.Logger.Trace().Fields(
×
419
                                map[string]any{
×
420
                                        "function": "proxy.passthrough",
×
421
                                        "result":   resp,
×
422
                                },
×
423
                        ).Msg("Terminating connection with a result from the action")
×
424

×
425
                        // If the terminate action returned a result, use it.
×
426
                        result = resp
×
427
                }
×
428

429
                if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil {
×
430
                        metrics.ProxyPassThroughsToClient.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()
×
431
                        metrics.ProxyPassThroughTerminations.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()
×
432
                        metrics.BytesSentToClient.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(modReceived))
×
433
                        metrics.TotalTrafficBytes.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(modReceived))
×
434

×
435
                        span.AddEvent("Terminating connection")
×
436

×
437
                        // Remove the request from the stack if the response is modified.
×
438
                        stack.PopLastRequest()
×
439

×
440
                        return pr.sendTrafficToClient(conn.Conn(), modResponse, modReceived)
×
441
                }
×
442
                span.RecordError(gerr.ErrHookTerminatedConnection)
×
443
                return gerr.ErrHookTerminatedConnection
×
444
        }
445
        // If the hook modified the request, use the modified request.
446
        if modRequest := pr.getPluginModifiedRequest(result); modRequest != nil {
8✔
447
                request = modRequest
4✔
448
                span.AddEvent("Plugin(s) modified the request")
4✔
449
        }
4✔
450

451
        stack.UpdateLastRequest(&Request{Data: request})
4✔
452

4✔
453
        // Send the request to the server.
4✔
454
        _, err = pr.sendTrafficToServer(client, request)
4✔
455
        span.AddEvent("Sent traffic to server")
4✔
456

4✔
457
        pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.PluginTimeout)
4✔
458
        defer cancel()
4✔
459

4✔
460
        // Run the OnTrafficToServer hooks.
4✔
461
        result, err = pr.PluginRegistry.Run(
4✔
462
                pluginTimeoutCtx,
4✔
463
                trafficData(
4✔
464
                        conn.Conn(),
4✔
465
                        client,
4✔
466
                        []Field{
4✔
467
                                {
4✔
468
                                        Name:  "request",
4✔
469
                                        Value: request,
4✔
470
                                },
4✔
471
                        },
4✔
472
                        err),
4✔
473
                v1.HookName_HOOK_NAME_ON_TRAFFIC_TO_SERVER)
4✔
474
        if err != nil {
4✔
475
                pr.Logger.Error().Err(err).Msg("Error running hook")
×
476
                span.RecordError(err)
×
477
        }
×
478
        if result != nil {
8✔
479
                _ = pr.PluginRegistry.ActRegistry.RunAll(result)
4✔
480
        }
4✔
481

482
        span.AddEvent("Ran the OnTrafficToServer hooks")
4✔
483
        metrics.ProxyPassThroughsToServer.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()
4✔
484

4✔
485
        return nil
4✔
486
}
487

488
// PassThroughToClient sends the data from the server to the client.
489
func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.GatewayDError {
4✔
490
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "PassThrough")
4✔
491
        defer span.End()
4✔
492

4✔
493
        var client *Client
4✔
494
        // Check if the proxy has a egress client for the incoming connection.
4✔
495
        if pr.busyConnections.Get(conn) == nil {
4✔
UNCOV
496
                span.RecordError(gerr.ErrClientNotFound)
×
UNCOV
497
                return gerr.ErrClientNotFound
×
UNCOV
498
        }
×
499

500
        // Get the client from the busy connection pool.
501
        if cl, ok := pr.busyConnections.Get(conn).(*Client); ok {
8✔
502
                client = cl
4✔
503
        } else {
4✔
504
                span.RecordError(gerr.ErrCastFailed)
×
505
                return gerr.ErrCastFailed
×
506
        }
×
507
        span.AddEvent("Got the client from the busy connection pool")
4✔
508

4✔
509
        if !client.IsConnected() {
4✔
UNCOV
510
                return gerr.ErrClientNotConnected
×
UNCOV
511
        }
×
512

513
        // Receive the response from the server.
514
        received, response, err := pr.receiveTrafficFromServer(client)
4✔
515
        span.AddEvent("Received traffic from server")
4✔
516

4✔
517
        // If there is no data to send to the client,
4✔
518
        // we don't need to run the hooks and
4✔
519
        // we obviously have no data to send to the client.
4✔
520
        if received == 0 {
6✔
521
                span.AddEvent("No data to send to client")
2✔
522
                stack.PopLastRequest()
2✔
523
                if err != nil {
4✔
524
                        span.RecordError(err)
2✔
525
                        return err
2✔
526
                }
2✔
UNCOV
527
                return nil
×
528
        }
529

530
        // If there is an error, close the ingress connection.
531
        if err != nil {
2✔
532
                fields := map[string]any{"function": "proxy.passthrough"}
×
533
                if client.LocalAddr() != "" {
×
534
                        fields["localAddr"] = client.LocalAddr()
×
535
                }
×
536
                if client.RemoteAddr() != "" {
×
537
                        fields["remoteAddr"] = client.RemoteAddr()
×
538
                }
×
539
                pr.Logger.Debug().Fields(fields).Msg("No data to send to client")
×
540
                span.AddEvent("No data to send to client")
×
541
                span.RecordError(err)
×
542

×
543
                stack.PopLastRequest()
×
544

×
545
                return err
×
546
        }
547

548
        pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), pr.PluginTimeout)
2✔
549
        defer cancel()
2✔
550

2✔
551
        // Get the last request from the stack.
2✔
552
        lastRequest := stack.PopLastRequest()
2✔
553
        request := []byte{}
2✔
554
        if lastRequest != nil {
4✔
555
                request = lastRequest.Data
2✔
556
        }
2✔
557

558
        // Run the OnTrafficFromServer hooks.
559
        result, err := pr.PluginRegistry.Run(
2✔
560
                pluginTimeoutCtx,
2✔
561
                trafficData(
2✔
562
                        conn.Conn(),
2✔
563
                        client,
2✔
564
                        []Field{
2✔
565
                                {
2✔
566
                                        Name:  "request",
2✔
567
                                        Value: request,
2✔
568
                                },
2✔
569
                                {
2✔
570
                                        Name:  "response",
2✔
571
                                        Value: response[:received],
2✔
572
                                },
2✔
573
                        },
2✔
574
                        err),
2✔
575
                v1.HookName_HOOK_NAME_ON_TRAFFIC_FROM_SERVER)
2✔
576
        if err != nil {
2✔
577
                pr.Logger.Error().Err(err).Msg("Error running hook")
×
578
                span.RecordError(err)
×
579
        }
×
580
        if result != nil {
4✔
581
                result = pr.PluginRegistry.ActRegistry.RunAll(result)
2✔
582
        }
2✔
583
        span.AddEvent("Ran the OnTrafficFromServer hooks")
2✔
584

2✔
585
        // If the hook modified the response, use the modified response.
2✔
586
        if modResponse, modReceived := pr.getPluginModifiedResponse(result); modResponse != nil {
2✔
587
                response = modResponse
×
588
                received = modReceived
×
589
                span.AddEvent("Plugin(s) modified the response")
×
590
        }
×
591

592
        // Send the response to the client.
593
        errVerdict := pr.sendTrafficToClient(conn.Conn(), response, received)
2✔
594
        span.AddEvent("Sent traffic to client")
2✔
595

2✔
596
        // Run the OnTrafficToClient hooks.
2✔
597
        pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.PluginTimeout)
2✔
598
        defer cancel()
2✔
599

2✔
600
        result, err = pr.PluginRegistry.Run(
2✔
601
                pluginTimeoutCtx,
2✔
602
                trafficData(
2✔
603
                        conn.Conn(),
2✔
604
                        client,
2✔
605
                        []Field{
2✔
606
                                {
2✔
607
                                        Name:  "request",
2✔
608
                                        Value: request,
2✔
609
                                },
2✔
610
                                {
2✔
611
                                        Name:  "response",
2✔
612
                                        Value: response[:received],
2✔
613
                                },
2✔
614
                        },
2✔
615
                        nil,
2✔
616
                ),
2✔
617
                v1.HookName_HOOK_NAME_ON_TRAFFIC_TO_CLIENT)
2✔
618
        if err != nil {
2✔
619
                pr.Logger.Error().Err(err).Msg("Error running hook")
×
620
                span.RecordError(err)
×
621
        }
×
622
        if result != nil {
4✔
623
                _ = pr.PluginRegistry.ActRegistry.RunAll(result)
2✔
624
        }
2✔
625
        span.AddEvent("Ran the OnTrafficToClient hooks")
2✔
626

2✔
627
        if errVerdict != nil {
2✔
628
                span.RecordError(errVerdict)
×
629
        }
×
630

631
        metrics.ProxyPassThroughsToClient.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()
2✔
632

2✔
633
        return errVerdict
2✔
634
}
635

636
// IsHealthy checks if the pool is exhausted or the client is disconnected.
637
func (pr *Proxy) IsHealthy(client *Client) (*Client, *gerr.GatewayDError) {
3✔
638
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "IsHealthy")
3✔
639
        defer span.End()
3✔
640

3✔
641
        if pr.IsExhausted() {
3✔
642
                pr.Logger.Error().Msg("No more available connections")
×
643
                span.RecordError(gerr.ErrPoolExhausted)
×
644
                return client, gerr.ErrPoolExhausted
×
645
        }
×
646

647
        if !client.IsConnected() {
3✔
648
                pr.Logger.Error().Msg("Client is disconnected")
×
649
                span.RecordError(gerr.ErrClientNotConnected)
×
650
        }
×
651

652
        return client, nil
3✔
653
}
654

655
// IsExhausted checks if the available connection pool is exhausted.
656
func (pr *Proxy) IsExhausted() bool {
6✔
657
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "IsExhausted")
6✔
658
        defer span.End()
6✔
659
        return pr.AvailableConnections.Size() == 0 && pr.AvailableConnections.Cap() > 0
6✔
660
}
6✔
661

662
// Shutdown closes all connections and clears the connection pools.
663
func (pr *Proxy) Shutdown() {
5✔
664
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "Shutdown")
5✔
665
        defer span.End()
5✔
666

5✔
667
        pr.AvailableConnections.ForEach(func(_, value any) bool {
11✔
668
                if client, ok := value.(*Client); ok {
12✔
669
                        if client.IsConnected() {
12✔
670
                                client.Close()
6✔
671
                        }
6✔
672
                }
673
                return true
6✔
674
        })
675
        pr.AvailableConnections.Clear()
5✔
676
        pr.Logger.Debug().Msg("All available connections have been closed")
5✔
677

5✔
678
        pr.busyConnections.ForEach(func(key, value any) bool {
5✔
UNCOV
679
                if conn, ok := key.(net.Conn); ok {
×
680
                        // This will stop all the Conn.Read() and Conn.Write() calls.
×
681
                        if err := conn.SetDeadline(time.Now()); err != nil {
×
682
                                pr.Logger.Error().Err(err).Msg("Error setting the deadline")
×
683
                                span.RecordError(err)
×
684
                        }
×
685
                        if err := conn.Close(); err != nil {
×
686
                                pr.Logger.Error().Err(err).Msg("Failed to close the connection")
×
687
                                span.RecordError(err)
×
688
                        }
×
689
                }
UNCOV
690
                if client, ok := value.(*Client); ok {
×
UNCOV
691
                        if client != nil {
×
UNCOV
692
                                client.Close()
×
UNCOV
693
                        }
×
694
                }
UNCOV
695
                return true
×
696
        })
697
        pr.busyConnections.Clear()
5✔
698
        pr.scheduler.Stop()
5✔
699
        pr.scheduler.Clear()
5✔
700
        pr.Logger.Debug().Msg("All busy connections have been closed")
5✔
701
}
702

703
// AvailableConnectionsString returns a list of available connections. This list enumerates
704
// the local addresses of the outgoing connections to the server.
705
func (pr *Proxy) AvailableConnectionsString() []string {
×
706
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "AvailableConnections")
×
707
        defer span.End()
×
708

×
709
        connections := make([]string, 0)
×
710
        pr.AvailableConnections.ForEach(func(_, value any) bool {
×
711
                if cl, ok := value.(*Client); ok {
×
712
                        connections = append(connections, cl.LocalAddr())
×
713
                }
×
714
                return true
×
715
        })
716
        return connections
×
717
}
718

719
// BusyConnectionsString returns a list of busy connections. This list enumerates
720
// the remote addresses of the incoming connections from a database client like psql.
721
func (pr *Proxy) BusyConnectionsString() []string {
×
722
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "BusyConnectionsString")
×
723
        defer span.End()
×
724

×
725
        connections := make([]string, 0)
×
726
        pr.busyConnections.ForEach(func(key, _ any) bool {
×
727
                if conn, ok := key.(*ConnWrapper); ok {
×
728
                        connections = append(connections, RemoteAddr(conn.Conn()))
×
729
                }
×
730
                return true
×
731
        })
732
        return connections
×
733
}
734

735
// ExpireBackendReadDeadline sets the backend connection's read deadline to now,
736
// causing any pending Receive() call to return immediately with a deadline exceeded error.
737
// This is used to unblock the server->client goroutine when the client->server goroutine exits.
738
func (pr *Proxy) ExpireBackendReadDeadline(conn *ConnWrapper) {
2✔
739
        if cl, ok := pr.busyConnections.Get(conn).(*Client); ok && cl != nil && cl.conn != nil {
4✔
740
                if err := cl.conn.SetReadDeadline(time.Now()); err != nil {
2✔
741
                        pr.Logger.Error().Err(err).Msg("Failed to expire backend read deadline")
×
742
                }
×
743
        }
744
}
745

746
// ClearBackendDeadline clears any deadline on the backend connection so it can
747
// be reused for session reset (DISCARD ALL) or other operations after both
748
// traffic goroutines have exited.
749
func (pr *Proxy) ClearBackendDeadline(conn *ConnWrapper) {
2✔
750
        if cl, ok := pr.busyConnections.Get(conn).(*Client); ok && cl != nil && cl.conn != nil {
4✔
751
                if err := cl.conn.SetDeadline(time.Time{}); err != nil {
2✔
UNCOV
752
                        // During shutdown, the backend connection may already be closed by
×
UNCOV
753
                        // proxy.Shutdown(), so SetDeadline will fail with net.ErrClosed.
×
UNCOV
754
                        if errors.Is(err, net.ErrClosed) {
×
UNCOV
755
                                pr.Logger.Debug().Err(err).Msg("Backend connection already closed, skipping deadline clear")
×
UNCOV
756
                        } else {
×
757
                                pr.Logger.Error().Err(err).Msg("Failed to clear backend deadline")
×
758
                        }
×
759
                }
760
        }
761
}
762

763
// recycleClientConnection resets or reconnects the client connection and returns it to the pool.
764
func (pr *Proxy) recycleClientConnection(client *Client, span trace.Span) {
2✔
765
        // Try to reset the session without tearing down the TCP connection.
2✔
766
        // This sends DISCARD ALL and is much cheaper than a full reconnect.
2✔
767
        // If it fails (broken pipe, bad state, etc.), fall back to a full reconnect.
2✔
768
        if client.StartupParams == nil || !client.IsConnected() {
4✔
769
                // No startup params configured or client disconnected:
2✔
770
                // use the original reconnect behavior.
2✔
771
                if err := client.Reconnect(); err != nil {
2✔
772
                        pr.Logger.Error().Err(err).Msg("Failed to reconnect to the client")
×
773
                        span.RecordError(err)
×
774
                }
×
775
        } else if err := client.ResetSession(); err != nil {
×
776
                // Reset failed (broken pipe, bad state, etc.), fall back to full reconnect.
×
777
                pr.Logger.Warn().Err(err).Msg(
×
778
                        "Session reset failed, falling back to full reconnect")
×
779
                span.RecordError(err)
×
780
                if err := client.Reconnect(); err != nil {
×
781
                        pr.Logger.Error().Err(err).Msg("Failed to reconnect to the client")
×
782
                        span.RecordError(err)
×
783
                }
×
784
        }
785

786
        // If the client is not in the pool, put it back.
787
        if err := pr.AvailableConnections.Put(client.ID, client); err != nil {
2✔
788
                pr.Logger.Error().Err(err).Msg("Failed to put the client back in the pool")
×
789
                span.RecordError(err)
×
790
        }
×
791
}
792

793
// receiveTrafficFromClient is a function that waits to receive data from the client.
794
func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayDError) {
6✔
795
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "receiveTrafficFromClient")
6✔
796
        defer span.End()
6✔
797

6✔
798
        // request contains the data from the client.
6✔
799
        total := 0
6✔
800
        buffer := bytes.NewBuffer(nil)
6✔
801
        for {
12✔
802
                chunk := make([]byte, pr.ClientConfig.ReceiveChunkSize)
6✔
803
                read, err := conn.Read(chunk)
6✔
804
                if read > 0 {
10✔
805
                        total += read
4✔
806
                        buffer.Write(chunk[:read])
4✔
807
                }
4✔
808
                if read == 0 || err != nil {
8✔
809
                        pr.Logger.Debug().Err(err).Msg("Error reading from client")
2✔
810
                        span.RecordError(err)
2✔
811

2✔
812
                        metrics.BytesReceivedFromClient.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(read))
2✔
813
                        metrics.TotalTrafficBytes.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(read))
2✔
814

2✔
815
                        return buffer.Bytes(), gerr.ErrReadFailed.Wrap(err)
2✔
816
                }
2✔
817

818
                if read < pr.ClientConfig.ReceiveChunkSize {
8✔
819
                        break
4✔
820
                }
821

822
                if !pr.isConnectionHealthy(conn) {
×
823
                        break
×
824
                }
825
        }
826

827
        pr.Logger.Debug().Fields(
4✔
828
                map[string]any{
4✔
829
                        "length": total,
4✔
830
                        "local":  LocalAddr(conn),
4✔
831
                        "remote": RemoteAddr(conn),
4✔
832
                },
4✔
833
        ).Msg("Received data from client")
4✔
834

4✔
835
        span.AddEvent("Received data from client")
4✔
836

4✔
837
        metrics.BytesReceivedFromClient.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(total))
4✔
838
        metrics.TotalTrafficBytes.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(total))
4✔
839

4✔
840
        return buffer.Bytes(), nil
4✔
841
}
842

843
// sendTrafficToServer is a function that sends data to the server.
844
func (pr *Proxy) sendTrafficToServer(client *Client, request []byte) (int, *gerr.GatewayDError) {
4✔
845
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "sendTrafficToServer")
4✔
846
        defer span.End()
4✔
847

4✔
848
        if len(request) == 0 {
4✔
849
                pr.Logger.Trace().Msg("Empty request")
×
850
                return 0, nil
×
851
        }
×
852

853
        // Send the request to the server.
854
        sent, err := client.Send(request)
4✔
855
        if err != nil {
4✔
856
                pr.Logger.Error().Err(err).Msg("Error sending request to database")
×
857
                span.RecordError(err)
×
858
        }
×
859
        pr.Logger.Debug().Fields(
4✔
860
                map[string]any{
4✔
861
                        "function": "proxy.passthrough",
4✔
862
                        "length":   sent,
4✔
863
                        "local":    client.LocalAddr(),
4✔
864
                        "remote":   client.RemoteAddr(),
4✔
865
                },
4✔
866
        ).Msg("Sent data to database")
4✔
867

4✔
868
        span.AddEvent("Sent data to database")
4✔
869

4✔
870
        metrics.BytesSentToServer.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(sent))
4✔
871
        metrics.TotalTrafficBytes.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(sent))
4✔
872

4✔
873
        return sent, err
4✔
874
}
875

876
// receiveTrafficFromServer is a function that receives data from the server.
877
func (pr *Proxy) receiveTrafficFromServer(client *Client) (int, []byte, *gerr.GatewayDError) {
4✔
878
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "receiveTrafficFromServer")
4✔
879
        defer span.End()
4✔
880

4✔
881
        // Receive the response from the server.
4✔
882
        received, response, err := client.Receive()
4✔
883

4✔
884
        fields := map[string]any{
4✔
885
                "function": "proxy.passthrough",
4✔
886
                "length":   received,
4✔
887
        }
4✔
888
        if client.LocalAddr() != "" {
8✔
889
                fields["local"] = client.LocalAddr()
4✔
890
        }
4✔
891
        if client.RemoteAddr() != "" {
8✔
892
                fields["remote"] = client.RemoteAddr()
4✔
893
        }
4✔
894

895
        pr.Logger.Debug().Fields(fields).Msg("Received data from database")
4✔
896

4✔
897
        span.AddEvent("Received data from database")
4✔
898

4✔
899
        metrics.BytesReceivedFromServer.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(received))
4✔
900
        metrics.TotalTrafficBytes.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(received))
4✔
901

4✔
902
        return received, response, err
4✔
903
}
904

905
// sendTrafficToClient is a function that sends data to the client.
906
func (pr *Proxy) sendTrafficToClient(
907
        conn net.Conn, response []byte, received int,
908
) *gerr.GatewayDError {
2✔
909
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "sendTrafficToClient")
2✔
910
        defer span.End()
2✔
911

2✔
912
        // Send the response to the client async.
2✔
913
        sent := 0
2✔
914
        for {
6✔
915
                if sent >= received {
6✔
916
                        break
2✔
917
                }
918

919
                written, origErr := conn.Write(response[:received])
2✔
920
                if origErr != nil {
2✔
921
                        pr.Logger.Error().Err(origErr).Msg("Error writing to client")
×
922
                        span.RecordError(origErr)
×
923
                        return gerr.ErrServerSendFailed.Wrap(origErr)
×
924
                }
×
925

926
                sent += written
2✔
927
        }
928

929
        pr.Logger.Debug().Fields(
2✔
930
                map[string]any{
2✔
931
                        "function": "proxy.passthrough",
2✔
932
                        "length":   sent,
2✔
933
                        "local":    LocalAddr(conn),
2✔
934
                        "remote":   RemoteAddr(conn),
2✔
935
                },
2✔
936
        ).Msg("Sent data to client")
2✔
937

2✔
938
        span.AddEvent("Sent data to client")
2✔
939

2✔
940
        metrics.BytesSentToClient.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(received))
2✔
941
        metrics.TotalTrafficBytes.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Observe(float64(received))
2✔
942

2✔
943
        return nil
2✔
944
}
945

946
// shouldTerminate is a function that retrieves the terminate field from the hook result.
947
// Only the OnTrafficFromClient hook will terminate the request.
948
func (pr *Proxy) shouldTerminate(result map[string]any) (bool, map[string]any) {
4✔
949
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "shouldTerminate")
4✔
950
        defer span.End()
4✔
951

4✔
952
        if result == nil {
4✔
953
                return false, result
×
954
        }
×
955

956
        terminate := pr.PluginRegistry.ActRegistry.ShouldTerminate(result)
4✔
957
        actionResult := pr.PluginRegistry.ActRegistry.RunAll(result)
4✔
958
        if terminate {
4✔
959
                pr.Logger.Debug().Fields(
×
960
                        map[string]any{
×
961
                                "function": "proxy.passthrough",
×
962
                                "reason":   "terminate",
×
963
                        },
×
964
                ).Msg("Terminating request")
×
965
        }
×
966
        return terminate, actionResult
4✔
967
}
968

969
// getPluginModifiedRequest is a function that retrieves the modified request
970
// from the hook result.
971
func (pr *Proxy) getPluginModifiedRequest(result map[string]any) []byte {
4✔
972
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "getPluginModifiedRequest")
4✔
973
        defer span.End()
4✔
974

4✔
975
        // If the hook modified the request, use the modified request.
4✔
976
        if modRequest, errMsg := extractFieldValue(result, "request"); errMsg != "" {
4✔
977
                pr.Logger.Error().Str("error", errMsg).Msg("Error in hook")
×
978
        } else if modRequest != nil {
8✔
979
                return modRequest
4✔
980
        }
4✔
981

982
        return nil
×
983
}
984

985
// getPluginModifiedResponse is a function that retrieves the modified response
986
// from the hook result.
987
func (pr *Proxy) getPluginModifiedResponse(result map[string]any) ([]byte, int) {
2✔
988
        _, span := otel.Tracer(config.TracerName).Start(pr.ctx, "getPluginModifiedResponse")
2✔
989
        defer span.End()
2✔
990

2✔
991
        // If the hook returns a response, use it instead of the original response.
2✔
992
        if modResponse, errMsg := extractFieldValue(result, "response"); errMsg != "" {
2✔
993
                pr.Logger.Error().Str("error", errMsg).Msg("Error in hook")
×
994
        } else if modResponse != nil {
2✔
995
                return modResponse, len(modResponse)
×
996
        }
×
997

998
        return nil, 0
2✔
999
}
1000

1001
func (pr *Proxy) isConnectionHealthy(conn net.Conn) bool {
×
1002
        if n, err := conn.Read([]byte{}); n == 0 && err != nil {
×
1003
                pr.Logger.Debug().Fields(
×
1004
                        map[string]any{
×
1005
                                "remote": RemoteAddr(conn),
×
1006
                                "local":  LocalAddr(conn),
×
1007
                                "reason": "read 0 bytes",
×
1008
                        }).Msg("Connection to client is closed")
×
1009
                return false
×
1010
        }
×
1011

1012
        return true
×
1013
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc