• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

uber / cadence / 018e1aca-4829-4e25-951b-23cb08010d99

07 Mar 2024 09:20PM UTC coverage: 63.21% (-0.7%) from 63.932%
018e1aca-4829-4e25-951b-23cb08010d99

push

buildkite

web-flow
Add unit tests for common/persistence/sql/factory.go (#5751)

92665 of 146599 relevant lines covered (63.21%)

2349.68 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

24.0
/common/persistence/sql/sql_queue_store.go
// Copyright (c) 2019 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20

21
package sql
22

23
import (
24
        "context"
25
        "database/sql"
26
        "fmt"
27

28
        "github.com/uber/cadence/common/log"
29
        "github.com/uber/cadence/common/persistence"
30
        "github.com/uber/cadence/common/persistence/sql/sqlplugin"
31
        "github.com/uber/cadence/common/types"
32
)
33

34
type (
35
        sqlQueueStore struct {
36
                queueType persistence.QueueType
37
                logger    log.Logger
38
                sqlStore
39
        }
40
)
41

42
func newQueueStore(
43
        db sqlplugin.DB,
44
        logger log.Logger,
45
        queueType persistence.QueueType,
46
) (persistence.Queue, error) {
34✔
47
        return &sqlQueueStore{
34✔
48
                sqlStore: sqlStore{
34✔
49
                        db:     db,
34✔
50
                        logger: logger,
34✔
51
                },
34✔
52
                queueType: queueType,
34✔
53
                logger:    logger,
34✔
54
        }, nil
34✔
55
}
34✔
56

57
func (q *sqlQueueStore) EnqueueMessage(
58
        ctx context.Context,
59
        messagePayload []byte,
60
) error {
2✔
61
        return q.txExecute(ctx, sqlplugin.DbDefaultShard, "EnqueueMessage", func(tx sqlplugin.Tx) error {
4✔
62
                lastMessageID, err := tx.GetLastEnqueuedMessageIDForUpdate(ctx, q.queueType)
2✔
63
                if err != nil {
4✔
64
                        if err == sql.ErrNoRows {
4✔
65
                                lastMessageID = -1
2✔
66
                        } else {
2✔
67
                                return err
×
68
                        }
×
69
                }
70

71
                ackLevels, err := tx.GetAckLevels(ctx, q.queueType, true)
2✔
72
                if err != nil {
2✔
73
                        return err
×
74
                }
×
75

76
                _, err = tx.InsertIntoQueue(ctx, newQueueRow(q.queueType, getNextID(ackLevels, lastMessageID), messagePayload))
2✔
77
                return err
2✔
78
        })
79
}
80

81
func (q *sqlQueueStore) ReadMessages(
82
        ctx context.Context,
83
        lastMessageID int64,
84
        maxCount int,
85
) ([]*persistence.InternalQueueMessage, error) {
×
86

×
87
        rows, err := q.db.GetMessagesFromQueue(ctx, q.queueType, lastMessageID, maxCount)
×
88
        if err != nil {
×
89
                return nil, convertCommonErrors(q.db, "ReadMessages", "", err)
×
90
        }
×
91

92
        var messages []*persistence.InternalQueueMessage
×
93
        for _, row := range rows {
×
94
                messages = append(messages, &persistence.InternalQueueMessage{ID: row.MessageID, Payload: row.MessagePayload})
×
95
        }
×
96
        return messages, nil
×
97
}
98

99
func newQueueRow(
100
        queueType persistence.QueueType,
101
        messageID int64,
102
        payload []byte,
103
) *sqlplugin.QueueRow {
2✔
104

2✔
105
        return &sqlplugin.QueueRow{QueueType: queueType, MessageID: messageID, MessagePayload: payload}
2✔
106
}
2✔
107

108
func (q *sqlQueueStore) DeleteMessagesBefore(
109
        ctx context.Context,
110
        messageID int64,
111
) error {
×
112

×
113
        _, err := q.db.DeleteMessagesBefore(ctx, q.queueType, messageID)
×
114
        if err != nil {
×
115
                return convertCommonErrors(q.db, "DeleteMessagesBefore", "", err)
×
116
        }
×
117
        return nil
×
118
}
119

120
func (q *sqlQueueStore) UpdateAckLevel(
121
        ctx context.Context,
122
        messageID int64,
123
        clusterName string,
124
) error {
×
125
        return q.txExecute(ctx, sqlplugin.DbDefaultShard, "UpdateAckLevel", func(tx sqlplugin.Tx) error {
×
126
                clusterAckLevels, err := tx.GetAckLevels(ctx, q.queueType, true)
×
127
                if err != nil {
×
128
                        return err
×
129
                }
×
130

131
                if clusterAckLevels == nil {
×
132
                        return tx.InsertAckLevel(ctx, q.queueType, messageID, clusterName)
×
133
                }
×
134

135
                // Ignore possibly delayed message
136
                if ackLevel, ok := clusterAckLevels[clusterName]; ok && ackLevel >= messageID {
×
137
                        return nil
×
138
                }
×
139

140
                clusterAckLevels[clusterName] = messageID
×
141
                return tx.UpdateAckLevels(ctx, q.queueType, clusterAckLevels)
×
142
        })
143
}
144

145
func (q *sqlQueueStore) GetAckLevels(
146
        ctx context.Context,
147
) (map[string]int64, error) {
×
148
        result, err := q.db.GetAckLevels(ctx, q.queueType, false)
×
149
        if err != nil {
×
150
                return nil, convertCommonErrors(q.db, "GetAckLevels", "", err)
×
151
        }
×
152
        return result, nil
×
153
}
154

155
func (q *sqlQueueStore) EnqueueMessageToDLQ(
156
        ctx context.Context,
157
        messagePayload []byte,
158
) error {
×
159
        return q.txExecute(ctx, sqlplugin.DbDefaultShard, "EnqueueMessageToDLQ", func(tx sqlplugin.Tx) error {
×
160
                var err error
×
161
                lastMessageID, err := tx.GetLastEnqueuedMessageIDForUpdate(ctx, q.getDLQTypeFromQueueType())
×
162
                if err != nil {
×
163
                        if err == sql.ErrNoRows {
×
164
                                lastMessageID = -1
×
165
                        } else {
×
166
                                return err
×
167
                        }
×
168
                }
169
                _, err = tx.InsertIntoQueue(ctx, newQueueRow(q.getDLQTypeFromQueueType(), lastMessageID+1, messagePayload))
×
170
                return err
×
171
        })
172
}
173

174
func (q *sqlQueueStore) ReadMessagesFromDLQ(
175
        ctx context.Context,
176
        firstMessageID int64,
177
        lastMessageID int64,
178
        pageSize int,
179
        pageToken []byte,
180
) ([]*persistence.InternalQueueMessage, []byte, error) {
×
181

×
182
        if len(pageToken) != 0 {
×
183
                lastReadMessageID, err := deserializePageToken(pageToken)
×
184
                if err != nil {
×
185
                        return nil, nil, &types.InternalServiceError{
×
186
                                Message: fmt.Sprintf("invalid next page token %v", pageToken)}
×
187
                }
×
188
                firstMessageID = lastReadMessageID
×
189
        }
190

191
        rows, err := q.db.GetMessagesBetween(ctx, q.getDLQTypeFromQueueType(), firstMessageID, lastMessageID, pageSize)
×
192
        if err != nil {
×
193
                return nil, nil, convertCommonErrors(q.db, "ReadMessagesFromDLQ", "", err)
×
194
        }
×
195

196
        var messages []*persistence.InternalQueueMessage
×
197
        for _, row := range rows {
×
198
                messages = append(messages, &persistence.InternalQueueMessage{ID: row.MessageID, Payload: row.MessagePayload})
×
199
        }
×
200

201
        var newPagingToken []byte
×
202
        if messages != nil && len(messages) >= pageSize {
×
203
                lastReadMessageID := messages[len(messages)-1].ID
×
204
                newPagingToken = serializePageToken(int64(lastReadMessageID))
×
205
        }
×
206
        return messages, newPagingToken, nil
×
207
}
208

209
func (q *sqlQueueStore) DeleteMessageFromDLQ(
210
        ctx context.Context,
211
        messageID int64,
212
) error {
×
213
        _, err := q.db.DeleteMessage(ctx, q.getDLQTypeFromQueueType(), messageID)
×
214
        if err != nil {
×
215
                return convertCommonErrors(q.db, "DeleteMessageFromDLQ", "", err)
×
216
        }
×
217
        return nil
×
218
}
219

220
func (q *sqlQueueStore) RangeDeleteMessagesFromDLQ(
221
        ctx context.Context,
222
        firstMessageID int64,
223
        lastMessageID int64,
224
) error {
×
225
        _, err := q.db.RangeDeleteMessages(ctx, q.getDLQTypeFromQueueType(), firstMessageID, lastMessageID)
×
226
        if err != nil {
×
227
                return convertCommonErrors(q.db, "RangeDeleteMessagesFromDLQ", "", err)
×
228
        }
×
229
        return nil
×
230
}
231

232
func (q *sqlQueueStore) UpdateDLQAckLevel(
233
        ctx context.Context,
234
        messageID int64,
235
        clusterName string,
236
) error {
×
237
        return q.txExecute(ctx, sqlplugin.DbDefaultShard, "UpdateDLQAckLevel", func(tx sqlplugin.Tx) error {
×
238
                clusterAckLevels, err := tx.GetAckLevels(ctx, q.getDLQTypeFromQueueType(), true)
×
239
                if err != nil {
×
240
                        return err
×
241
                }
×
242

243
                if clusterAckLevels == nil {
×
244
                        return tx.InsertAckLevel(ctx, q.getDLQTypeFromQueueType(), messageID, clusterName)
×
245
                }
×
246

247
                // Ignore possibly delayed message
248
                if ackLevel, ok := clusterAckLevels[clusterName]; ok && ackLevel >= messageID {
×
249
                        return nil
×
250
                }
×
251

252
                clusterAckLevels[clusterName] = messageID
×
253
                return tx.UpdateAckLevels(ctx, q.getDLQTypeFromQueueType(), clusterAckLevels)
×
254
        })
255
}
256

257
func (q *sqlQueueStore) GetDLQAckLevels(
258
        ctx context.Context,
259
) (map[string]int64, error) {
×
260
        result, err := q.db.GetAckLevels(ctx, q.getDLQTypeFromQueueType(), false)
×
261
        if err != nil {
×
262
                return nil, convertCommonErrors(q.db, "GetDLQAckLevels", "", err)
×
263
        }
×
264
        return result, nil
×
265
}
266

267
func (q *sqlQueueStore) GetDLQSize(
268
        ctx context.Context,
269
) (int64, error) {
2✔
270
        result, err := q.db.GetQueueSize(ctx, q.getDLQTypeFromQueueType())
2✔
271
        if err != nil {
2✔
272
                return 0, convertCommonErrors(q.db, "GetDLQSize", "", err)
×
273
        }
×
274
        return result, nil
2✔
275
}
276

277
func (q *sqlQueueStore) getDLQTypeFromQueueType() persistence.QueueType {
2✔
278
        return -q.queueType
2✔
279
}
2✔
280

281
// getNextID returns the ID to assign to the next enqueued message: one more
// than the largest value among lastMessageID and every ack level in acks.
//
// If, for whatever reason, the ack levels get ahead of the actual messages,
// this ensures the next ID still follows the highest ack level. A nil or
// empty acks map yields lastMessageID + 1.
func getNextID(acks map[string]int64, lastMessageID int64) int64 {
	highest := lastMessageID
	for _, ackLevel := range acks {
		if ackLevel > highest {
			highest = ackLevel
		}
	}
	return highest + 1
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc