• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

gatewayd-io / gatewayd-plugin-cache / 22264668986

21 Feb 2026 09:24PM UTC coverage: 49.538% (+8.3%) from 41.21%
22264668986

push

github

web-flow
Improvements (#87)

* Fix return -> continue in UpdateCache to prevent goroutine death

The UpdateCache goroutine used `return` instead of `continue` on
error paths (empty database, query parse failure, table extraction
failure). A single transient error would permanently kill the
goroutine, stopping all server-response caching for the plugin's
lifetime. Resolves #64.

* Fix IsCacheNeeded to check parsed query instead of cache key

IsCacheNeeded was called with the full cache key (server:db:request_bytes)
instead of the actual SQL query string. Date/time function detection
against raw binary bytes was unreliable. Now the query is parsed first
via GetQueryFromRequest and uppercased before checking.

* Fix ExitOnStartupError ordering and close(nil) panic in main.go

Move config field assignments (including ExitOnStartupError) before
the API client initialization block so the flag is set when first
checked. Also guard the deferred channel close to prevent a panic
when config is nil and the channel was never initialized.

* Add CacheErrorsCounter, fix double-counting misses, and fix typos

Introduce a dedicated CacheErrorsCounter to separate Redis operation
failures from genuine cache misses. Previously CacheMissesCounter was
incremented for SET/DEL/SCAN errors, making the metric unreliable.
Also fix double-counting in OnTrafficFromClient where a single miss
incremented CacheMissesCounter twice. Fix typos: cachedRespnseKey ->
cachedResponseKey, DateFucntion -> DateFunction.

* Extract duplicated startup error handling into helper function

Replace three copies of the ExitOnStartupError check + manual close +
os.Exit(1) pattern with a single handleStartupError helper that
accepts variadic io.Closer resources. Reduces boilerplate and
ensures consistent behavior across all startup error paths.

* Add explicit config validation with sensible defaults

Previously cast.To* silently returned zero-values for invalid config.
Now redisURL, expiry, and scanCount are... (continued)

20 of 82 new or added lines in 2 files covered. (24.39%)

5 existing lines in 2 files now uncovered.

268 of 541 relevant lines covered (49.54%)

2.16 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.14
/plugin/plugin.go
1
package plugin
2

3
import (
	"context"
	"encoding/base64"
	"errors"
	"strings"
	"sync"
	"time"

	sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
	"github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
	sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin"
	v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
	apiV1 "github.com/gatewayd-io/gatewayd/api/v1"
	"github.com/hashicorp/go-hclog"
	goplugin "github.com/hashicorp/go-plugin"
	goRedis "github.com/redis/go-redis/v9"
	"github.com/spf13/cast"
	"google.golang.org/grpc"
)
21

22
type Plugin struct {
23
        goplugin.GRPCPlugin
24
        v1.GatewayDPluginServiceServer
25

26
        Logger hclog.Logger
27

28
        APIClient apiV1.GatewayDAdminAPIServiceClient
29

30
        // Cache configuration.
31
        RedisClient        *goRedis.Client
32
        RedisURL           string
33
        Expiry             time.Duration
34
        DefaultDBName      string
35
        ScanCount          int64
36
        ExitOnStartupError bool
37

38
        UpdateCacheChannel chan *v1.Struct
39
        WaitGroup          *sync.WaitGroup
40

41
        // Periodic invalidator configuration.
42
        PeriodicInvalidatorEnabled    bool
43
        PeriodicInvalidatorStartDelay time.Duration
44
        PeriodicInvalidatorInterval   time.Duration
45
}
46

47
type CachePlugin struct {
48
        goplugin.NetRPCUnsupportedPlugin
49
        Impl Plugin
50
}
51

52
// pgDateTimeFunctions is the set of PostgreSQL date/time functions and
// keywords whose value depends on when the query runs. IsCacheNeeded refuses
// to cache responses of queries that reference any of them. Keys are
// upper-case because the SQL is upper-cased before it is checked.
// See https://www.postgresql.org/docs/current/functions-datetime.html
var pgDateTimeFunctions = map[string]struct{}{
	// SQL-standard current date/time values (valid without parentheses).
	"CURRENT_DATE":      {},
	"CURRENT_TIME":      {},
	"CURRENT_TIMESTAMP": {},
	"LOCALTIME":         {},
	"LOCALTIMESTAMP":    {},

	// PostgreSQL current date/time functions.
	"CLOCK_TIMESTAMP":       {},
	"NOW":                   {},
	"STATEMENT_TIMESTAMP":   {},
	"TIMEOFDAY":             {},
	"TRANSACTION_TIMESTAMP": {},

	// AGE(timestamp) with a single argument is relative to the current date.
	"AGE": {},
}
67

68
// NewCachePlugin returns a new instance of the CachePlugin.
69
func NewCachePlugin(impl Plugin) *CachePlugin {
7✔
70
        return &CachePlugin{
7✔
71
                NetRPCUnsupportedPlugin: goplugin.NetRPCUnsupportedPlugin{},
7✔
72
                Impl:                    impl,
7✔
73
        }
7✔
74
}
7✔
75

76
// GRPCServer registers the plugin with the gRPC server.
77
func (p *CachePlugin) GRPCServer(_ *goplugin.GRPCBroker, s *grpc.Server) error {
×
78
        v1.RegisterGatewayDPluginServiceServer(s, &p.Impl)
×
79
        return nil
×
80
}
×
81

82
// GRPCClient returns the plugin client.
83
func (p *CachePlugin) GRPCClient(_ context.Context, _ *goplugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
×
84
        return v1.NewGatewayDPluginServiceClient(c), nil
×
85
}
×
86

87
// GetPluginConfig returns the plugin config.
88
func (p *Plugin) GetPluginConfig(
89
        _ context.Context, _ *v1.Struct,
90
) (*v1.Struct, error) {
1✔
91
        GetPluginConfigCounter.Inc()
1✔
92
        return v1.NewStruct(PluginConfig)
1✔
93
}
1✔
94

95
// OnTrafficFromClient is called when a request is received by GatewayD from the client.
96
func (p *Plugin) OnTrafficFromClient(
97
        ctx context.Context, req *v1.Struct,
98
) (*v1.Struct, error) {
4✔
99
        OnTrafficFromClientCounter.Inc()
4✔
100
        req, err := postgres.HandleClientMessage(req, p.Logger)
4✔
101
        if err != nil {
4✔
102
                p.Logger.Info("Failed to handle client message", "error", err)
×
103
        }
×
104

105
        // This is used as a fallback if the database is not found in the startup message.
106
        database := p.DefaultDBName
4✔
107
        if database == "" {
8✔
108
                client := cast.ToStringMapString(sdkPlugin.GetAttr(req, "client", nil))
4✔
109
                database = p.getDBFromStartupMessage(ctx, req, database, client)
4✔
110

4✔
111
                // Get the database from the cache if it's not found in the startup message or
4✔
112
                // if the current request is not a startup message.
4✔
113
                if database == "" {
6✔
114
                        database, err = p.RedisClient.Get(ctx, client["remote"]).Result()
2✔
115
                        if err != nil {
2✔
NEW
116
                                CacheErrorsCounter.Inc()
×
117
                                p.Logger.Debug("Failed to get cache", "error", err)
×
118
                        }
×
119
                        CacheGetsCounter.Inc()
2✔
120
                        p.Logger.Debug("Get the database in the cache for the current session",
2✔
121
                                "database", database, "client", client["remote"])
2✔
122
                }
123
        }
124

125
        // If the database is still not found, return the response as is without caching.
126
        // This might also happen if the cache is cleared while the client is still connected.
127
        // In this case, the client should reconnect and the error will go away.
128
        preconditions := sdkPlugin.GetAttr(req, "sslRequest", "") != "" ||
4✔
129
                sdkPlugin.GetAttr(req, "saslInitialResponse", "") != "" ||
4✔
130
                sdkPlugin.GetAttr(req, "cancelRequest", "") != ""
4✔
131
        if database == "" && !preconditions {
4✔
132
                p.Logger.Error(
×
133
                        "Database name not found or set in cache, startup message or plugin config. Skipping cache")
×
134
                p.Logger.Error("Consider setting the database name in the plugin config or disabling the plugin if you don't need it")
×
135
                return req, nil
×
136
        }
×
137

138
        query := cast.ToString(sdkPlugin.GetAttr(req, "query", ""))
4✔
139
        request := cast.ToString(sdkPlugin.GetAttr(req, "request", ""))
4✔
140
        server := cast.ToStringMapString(sdkPlugin.GetAttr(req, "server", ""))
4✔
141
        cacheKey := strings.Join([]string{server["remote"], database, request}, ":")
4✔
142

4✔
143
        if query == "" {
6✔
144
                return req, nil
2✔
145
        }
2✔
146

147
        p.Logger.Trace("Query", "query", query)
2✔
148

2✔
149
        // Clear the cache if the query is an insert, update or delete query.
2✔
150
        p.invalidateDML(ctx, query)
2✔
151

2✔
152
        // Check if the query is cached.
2✔
153
        response, err := p.RedisClient.Get(ctx, cacheKey).Bytes()
2✔
154
        if err != nil {
3✔
155
                p.Logger.Debug("Failed to get cached response", "error", err)
1✔
156
        }
1✔
157
        CacheGetsCounter.Inc()
2✔
158

2✔
159
        if response == nil {
3✔
160
                // If the query is not cached, return the request as is.
1✔
161
                CacheMissesCounter.Inc()
1✔
162
                return req, nil
1✔
163
        }
1✔
164

165
        // If the query is cached, return the cached response.
166
        signals, err := v1.NewList([]any{
1✔
167
                sdkAct.Terminate().ToMap(),
1✔
168
                sdkAct.Log("debug", "Returning cached response", map[string]any{
1✔
169
                        "cacheKey": []byte(cacheKey),
1✔
170
                        "plugin":   PluginID.GetName(),
1✔
171
                }).ToMap(),
1✔
172
        })
1✔
173
        if err != nil {
1✔
NEW
174
                CacheErrorsCounter.Inc()
×
175
                p.Logger.Error("Failed to create signals", "error", err)
×
176
        } else {
1✔
177
                CacheHitsCounter.Inc()
1✔
178
                // Return the cached response.
1✔
179
                req.Fields[sdkAct.Signals] = v1.NewListValue(signals)
1✔
180
                req.Fields["response"] = v1.NewBytesValue(response)
1✔
181
        }
1✔
182
        return req, nil
1✔
183
}
184

185
// IsCacheNeeded determines if caching is needed.
186
func IsCacheNeeded(upperQuery string) bool {
16✔
187
        // Iterate over each function name in the set of PostgreSQL date/time functions.
16✔
188
        for function := range pgDateTimeFunctions {
141✔
189
                if strings.Contains(upperQuery, function) {
136✔
190
                        // If the query contains a date/time function, caching is not needed.
11✔
191
                        return false
11✔
192
                }
11✔
193
        }
194
        return true
5✔
195
}
196

197
func (p *Plugin) UpdateCache(ctx context.Context) {
3✔
198
        defer p.WaitGroup.Done()
3✔
199
        for {
10✔
200
                serverResponse, ok := <-p.UpdateCacheChannel
7✔
201
                if !ok {
10✔
202
                        p.Logger.Info("Channel closed, returning from function")
3✔
203
                        return
3✔
204
                }
3✔
205

206
                OnTrafficFromServerCounter.Inc()
4✔
207
                resp, err := postgres.HandleServerMessage(serverResponse, p.Logger)
4✔
208
                if err != nil {
4✔
209
                        p.Logger.Info("Failed to handle server message", "error", err)
×
210
                }
×
211

212
                rowDescription := cast.ToString(sdkPlugin.GetAttr(resp, "rowDescription", ""))
4✔
213
                dataRow := cast.ToStringSlice(sdkPlugin.GetAttr(resp, "dataRow", []interface{}{}))
4✔
214
                errorResponse := cast.ToString(sdkPlugin.GetAttr(resp, "errorResponse", ""))
4✔
215
                request, isOk := sdkPlugin.GetAttr(resp, "request", nil).([]byte)
4✔
216
                if !isOk {
4✔
217
                        request = []byte{}
×
218
                }
×
219

220
                response, isOk := sdkPlugin.GetAttr(resp, "response", nil).([]byte)
4✔
221
                if !isOk {
4✔
222
                        response = []byte{}
×
223
                }
×
224
                server := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "server", ""))
4✔
225

4✔
226
                // This is used as a fallback if the database is not found in the startup message.
4✔
227

4✔
228
                database := p.DefaultDBName
4✔
229
                if database == "" {
8✔
230
                        client := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "client", ""))
4✔
231
                        if client != nil && client["remote"] != "" {
8✔
232
                                database, err = p.RedisClient.Get(ctx, client["remote"]).Result()
4✔
233
                                if err != nil {
5✔
234
                                        CacheErrorsCounter.Inc()
1✔
235
                                        p.Logger.Debug("Failed to get cached response", "error", err)
1✔
236
                                }
1✔
237
                                CacheGetsCounter.Inc()
4✔
238
                        }
239
                }
240

241
                // If the database is still not found, return the response as is without caching.
242
                // This might also happen if the cache is cleared while the client is still connected.
243
                // In this case, the client should reconnect and the error will go away.
244
                if database == "" {
5✔
245
                        p.Logger.Debug("Database name not found or set in cache, startup message or plugin config. " +
1✔
246
                                "Skipping cache")
1✔
247
                        p.Logger.Debug("Consider setting the database name in the " +
1✔
248
                                "plugin config or disabling the plugin if you don't need it")
1✔
249
                        continue
1✔
250
                }
251

252
                cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":")
3✔
253
                if errorResponse != "" || rowDescription == "" || dataRow == nil || len(dataRow) == 0 {
3✔
NEW
254
                        continue
×
255
                }
256

257
                query, err := postgres.GetQueryFromRequest(request)
3✔
258
                if err != nil {
3✔
NEW
259
                        p.Logger.Debug("Failed to get query from request", "error", err)
×
NEW
260
                        continue
×
261
                }
262

263
                if !IsCacheNeeded(strings.ToUpper(query)) {
4✔
264
                        continue
1✔
265
                }
266

267
                // The request was successful and the response contains data. Cache the response.
268
                if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil {
2✔
NEW
269
                        CacheErrorsCounter.Inc()
×
NEW
270
                        p.Logger.Debug("Failed to set cache", "error", err)
×
NEW
271
                }
×
272
                CacheSetsCounter.Inc()
2✔
273

2✔
274
                tables, err := postgres.GetTablesFromQuery(query)
2✔
275
                if err != nil {
2✔
NEW
276
                        p.Logger.Debug("Failed to get tables from query", "error", err)
×
NEW
277
                        continue
×
278
                }
279

280
                // Cache the table(s) used in each cached request. This is used to invalidate
281
                // the cache when a rows is inserted, updated or deleted into that table.
282
                for _, table := range tables {
4✔
283
                        requestQueryCacheKey := strings.Join([]string{table, cacheKey}, ":")
2✔
284
                        if err := p.RedisClient.Set(
2✔
285
                                ctx, requestQueryCacheKey, "", p.Expiry).Err(); err != nil {
2✔
NEW
286
                                CacheErrorsCounter.Inc()
×
NEW
287
                                p.Logger.Debug("Failed to set cache", "error", err)
×
UNCOV
288
                        }
×
289
                        CacheSetsCounter.Inc()
2✔
290
                }
291
        }
292
}
293

294
// OnTrafficFromServer is called when a response is received by GatewayD from the server.
295
func (p *Plugin) OnTrafficFromServer(
296
        _ context.Context, resp *v1.Struct,
297
) (*v1.Struct, error) {
2✔
298
        p.Logger.Debug("Traffic is coming from the server side")
2✔
299
        p.UpdateCacheChannel <- resp
2✔
300
        return resp, nil
2✔
301
}
2✔
302

303
func (p *Plugin) OnClosed(ctx context.Context, req *v1.Struct) (*v1.Struct, error) {
2✔
304
        OnClosedCounter.Inc()
2✔
305
        client := cast.ToStringMapString(sdkPlugin.GetAttr(req, "client", nil))
2✔
306
        if client != nil {
4✔
307
                if err := p.RedisClient.Del(ctx, client["remote"]).Err(); err != nil {
2✔
308
                        p.Logger.Debug("Failed to delete cache", "error", err)
×
NEW
309
                        CacheErrorsCounter.Inc()
×
310
                }
×
311
                p.Logger.Debug("Client closed", "client", client["remote"])
2✔
312
                CacheDeletesCounter.Inc()
2✔
313
        }
314
        return req, nil
2✔
315
}
316

317
// invalidateDML invalidates the cache for the tables that are affected by the DML.
318
// This is done by getting the cached queries for each table and deleting them.
319
func (p *Plugin) invalidateDML(ctx context.Context, query string) {
4✔
320
        // Check if the query is a UPDATE, INSERT or DELETE.
4✔
321
        queryDecoded, err := base64.StdEncoding.DecodeString(query)
4✔
322
        if err != nil {
4✔
323
                p.Logger.Debug("Failed to decode query", "error", err)
×
324
                return
×
325
        }
×
326

327
        queryMessage := cast.ToStringMapString(string(queryDecoded))
4✔
328
        p.Logger.Trace("Query message", "query", queryMessage)
4✔
329

4✔
330
        queryString := strings.ToUpper(queryMessage["String"])
4✔
331
        // Ignore SELECT and WITH/SELECT queries.
4✔
332
        // TODO: This is a naive approach, but query parsing has a cost.
4✔
333
        if strings.HasPrefix(queryString, "SELECT") ||
4✔
334
                (strings.HasPrefix(queryString, "WITH") &&
4✔
335
                        strings.Contains(queryString, "SELECT")) {
7✔
336
                return
3✔
337
        }
3✔
338

339
        tables, err := postgres.GetTablesFromQuery(queryMessage["String"])
1✔
340
        if err != nil {
1✔
341
                p.Logger.Debug("Failed to get tables from query", "error", err)
×
342
                return
×
343
        }
×
344

345
        p.Logger.Trace("Tables", "tables", tables)
1✔
346
        for _, table := range tables {
2✔
347
                // Invalidate the cache for the table.
1✔
348
                // TODO: This is not efficient. We should be able to invalidate the cache
1✔
349
                // for a specific key instead of invalidating the entire table.
1✔
350
                pipeline := p.RedisClient.Pipeline()
1✔
351
                var cursor uint64
1✔
352
                for {
2✔
353
                        scanResult := p.RedisClient.Scan(ctx, cursor, table+":*", p.ScanCount)
1✔
354
                        if scanResult.Err() != nil {
1✔
NEW
355
                                CacheErrorsCounter.Inc()
×
356
                                p.Logger.Debug("Failed to scan keys", "error", scanResult.Err())
×
357
                                break
×
358
                        }
359
                        CacheScanCounter.Inc()
1✔
360

1✔
361
                        // Per each key, delete the cache entry and the table cache key itself.
1✔
362
                        var keys []string
1✔
363
                        keys, cursor = scanResult.Val()
1✔
364
                        CacheScanKeysCounter.Add(float64(len(keys)))
1✔
365
                        for _, tableKey := range keys {
2✔
366
                                // Invalidate the cache for the table.
1✔
367
                                cachedResponseKey := strings.TrimPrefix(tableKey, table+":")
1✔
368
                                pipeline.Del(ctx, cachedResponseKey)
1✔
369
                                // Invalidate the table cache key itself.
1✔
370
                                pipeline.Del(ctx, tableKey)
1✔
371
                        }
1✔
372

373
                        if cursor == 0 {
2✔
374
                                break
1✔
375
                        }
376
                }
377

378
                result, err := pipeline.Exec(ctx)
1✔
379
                if err != nil {
1✔
380
                        p.Logger.Debug("Failed to execute pipeline", "error", err)
×
381
                }
×
382

383
                for _, res := range result {
3✔
384
                        if res.Err() != nil {
2✔
NEW
385
                                CacheErrorsCounter.Inc()
×
386
                        } else {
2✔
387
                                CacheDeletesCounter.Inc()
2✔
388
                        }
2✔
389
                }
390
        }
391
}
392

393
// getDBFromStartupMessage gets the database name from the startup message.
394
func (p *Plugin) getDBFromStartupMessage(
395
        ctx context.Context,
396
        req *v1.Struct,
397
        database string,
398
        client map[string]string,
399
) string {
4✔
400
        // Try to get the database from the startup message, which is only sent once by the client.
4✔
401
        // Store the database in the cache so that we can use it for subsequent requests.
4✔
402
        startupMessageEncoded := cast.ToString(sdkPlugin.GetAttr(req, "startupMessage", ""))
4✔
403
        if startupMessageEncoded == "" {
6✔
404
                return database
2✔
405
        }
2✔
406

407
        startupMessageBytes, err := base64.StdEncoding.DecodeString(startupMessageEncoded)
2✔
408
        if err != nil {
2✔
409
                p.Logger.Debug("Failed to decode startup message", "error", err)
×
410
                return database
×
411
        }
×
412

413
        startupMessage := cast.ToStringMap(string(startupMessageBytes))
2✔
414
        p.Logger.Trace("Startup message", "startupMessage", startupMessage, "client", client)
2✔
415
        if startupMessage != nil && client != nil {
4✔
416
                startupMsgParams := cast.ToStringMapString(startupMessage["Parameters"])
2✔
417
                if startupMsgParams != nil &&
2✔
418
                        startupMsgParams["database"] != "" &&
2✔
419
                        client["remote"] != "" {
4✔
420
                        if err := p.RedisClient.Set(
2✔
421
                                ctx, client["remote"],
2✔
422
                                startupMsgParams["database"],
2✔
423
                                time.Duration(0),
2✔
424
                        ).Err(); err != nil {
2✔
NEW
425
                                CacheErrorsCounter.Inc()
×
426
                                p.Logger.Debug("Failed to set cache", "error", err)
×
427
                        }
×
428
                        CacheSetsCounter.Inc()
2✔
429
                        p.Logger.Debug("Set the database in the cache for the current session",
2✔
430
                                "database", database, "client", client["remote"])
2✔
431
                        return startupMsgParams["database"]
2✔
432
                }
433
        }
434

435
        return database
×
436
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc