• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

uber / cadence / 018e1f2e-cd91-4a89-a804-d5e16440a906

08 Mar 2024 05:49PM UTC coverage: 63.815% (+0.2%) from 63.663%
018e1f2e-cd91-4a89-a804-d5e16440a906

push

buildkite

web-flow
Code cleanup for sql package (#5756)

93026 of 145775 relevant lines covered (63.81%)

2363.45 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

54.32
/common/persistence/sql/sql_task_store.go
1
// Copyright (c) 2020 Uber Technologies, Inc.
2
// Portions of the Software are attributed to Copyright (c) 2020 Temporal Technologies Inc.
3
//
4
// Permission is hereby granted, free of charge, to any person obtaining a copy
5
// of this software and associated documentation files (the "Software"), to deal
6
// in the Software without restriction, including without limitation the rights
7
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8
// copies of the Software, and to permit persons to whom the Software is
9
// furnished to do so, subject to the following conditions:
10
//
11
// The above copyright notice and this permission notice shall be included in
12
// all copies or substantial portions of the Software.
13
//
14
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20
// THE SOFTWARE.
21

22
package sql
23

24
import (
25
        "context"
26
        "database/sql"
27
        "fmt"
28
        "math"
29
        "time"
30

31
        "github.com/uber/cadence/common"
32
        "github.com/uber/cadence/common/log"
33
        "github.com/uber/cadence/common/persistence"
34
        "github.com/uber/cadence/common/persistence/serialization"
35
        "github.com/uber/cadence/common/persistence/sql/sqlplugin"
36
        "github.com/uber/cadence/common/types"
37
)
38

39
type sqlTaskStore struct {
40
        sqlStore
41
        nShards int
42
}
43

44
// stickyTasksListsTTL is the TTL attached to sticky task list rows when the
// DB supports TTLs, and the offset stickyTaskListExpiry adds to the current
// time when computing a sticky task list's expiry timestamp.
var stickyTasksListsTTL = 24 * time.Hour
47

48
// newTaskPersistence creates a new instance of TaskManager
49
func newTaskPersistence(
50
        db sqlplugin.DB,
51
        nShards int,
52
        log log.Logger,
53
        parser serialization.Parser,
54
) (persistence.TaskStore, error) {
34✔
55
        return &sqlTaskStore{
34✔
56
                sqlStore: sqlStore{
34✔
57
                        db:     db,
34✔
58
                        logger: log,
34✔
59
                        parser: parser,
34✔
60
                },
34✔
61
                nShards: nShards,
34✔
62
        }, nil
34✔
63
}
34✔
64

65
func (m *sqlTaskStore) GetTaskListSize(ctx context.Context, request *persistence.GetTaskListSizeRequest) (*persistence.GetTaskListSizeResponse, error) {
2,296✔
66
        dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskListName, m.db.GetTotalNumDBShards())
2,296✔
67
        domainID := serialization.MustParseUUID(request.DomainID)
2,296✔
68
        size, err := m.db.GetTasksCount(ctx, &sqlplugin.TasksFilter{
2,296✔
69
                ShardID:      dbShardID,
2,296✔
70
                DomainID:     domainID,
2,296✔
71
                TaskListName: request.TaskListName,
2,296✔
72
                TaskType:     int64(request.TaskListType),
2,296✔
73
                MinTaskID:    &request.AckLevel,
2,296✔
74
        })
2,296✔
75
        if err != nil {
2,296✔
76
                return nil, convertCommonErrors(m.db, "GetTaskListSize", "", err)
×
77
        }
×
78
        return &persistence.GetTaskListSizeResponse{Size: size}, nil
2,296✔
79
}
80

81
func (m *sqlTaskStore) LeaseTaskList(
82
        ctx context.Context,
83
        request *persistence.LeaseTaskListRequest,
84
) (*persistence.LeaseTaskListResponse, error) {
873✔
85
        var rangeID int64
873✔
86
        var ackLevel int64
873✔
87
        dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskList, m.db.GetTotalNumDBShards())
873✔
88

873✔
89
        domainID := serialization.MustParseUUID(request.DomainID)
873✔
90
        rows, err := m.db.SelectFromTaskLists(ctx, &sqlplugin.TaskListsFilter{
873✔
91
                ShardID:  dbShardID,
873✔
92
                DomainID: &domainID,
873✔
93
                Name:     &request.TaskList,
873✔
94
                TaskType: common.Int64Ptr(int64(request.TaskType))})
873✔
95
        if err != nil {
1,727✔
96
                if err == sql.ErrNoRows {
1,708✔
97
                        tlInfo := &serialization.TaskListInfo{
854✔
98
                                AckLevel:        ackLevel,
854✔
99
                                Kind:            int16(request.TaskListKind),
854✔
100
                                ExpiryTimestamp: time.Unix(0, 0),
854✔
101
                                LastUpdated:     time.Now(),
854✔
102
                        }
854✔
103
                        blob, err := m.parser.TaskListInfoToBlob(tlInfo)
854✔
104
                        if err != nil {
854✔
105
                                return nil, err
×
106
                        }
×
107
                        row := sqlplugin.TaskListsRow{
854✔
108
                                ShardID:      dbShardID,
854✔
109
                                DomainID:     domainID,
854✔
110
                                Name:         request.TaskList,
854✔
111
                                TaskType:     int64(request.TaskType),
854✔
112
                                Data:         blob.Data,
854✔
113
                                DataEncoding: string(blob.Encoding),
854✔
114
                        }
854✔
115
                        rows = []sqlplugin.TaskListsRow{row}
854✔
116
                        if m.db.SupportsTTL() && request.TaskListKind == persistence.TaskListKindSticky {
854✔
117
                                rowWithTTL := sqlplugin.TaskListsRowWithTTL{
×
118
                                        TaskListsRow: row,
×
119
                                        TTL:          stickyTasksListsTTL,
×
120
                                }
×
121
                                if _, err := m.db.InsertIntoTaskListsWithTTL(ctx, &rowWithTTL); err != nil {
×
122
                                        return nil, convertCommonErrors(m.db, "LeaseTaskListWithTTL", fmt.Sprintf("Failed to make task list %v of type %v.", request.TaskList, request.TaskType), err)
×
123
                                }
×
124
                        } else {
854✔
125
                                if _, err := m.db.InsertIntoTaskLists(ctx, &row); err != nil {
854✔
126
                                        return nil, convertCommonErrors(m.db, "LeaseTaskList", fmt.Sprintf("Failed to make task list %v of type %v.", request.TaskList, request.TaskType), err)
×
127
                                }
×
128
                        }
129
                } else {
×
130
                        return nil, convertCommonErrors(m.db, "LeaseTaskList", "Failed to check if task list existed.", err)
×
131
                }
×
132
        }
133

134
        row := rows[0]
873✔
135
        if request.RangeID > 0 && request.RangeID != row.RangeID {
873✔
136
                return nil, &persistence.ConditionFailedError{
×
137
                        Msg: fmt.Sprintf("leaseTaskList:renew failed:taskList:%v, taskListType:%v, haveRangeID:%v, gotRangeID:%v",
×
138
                                request.TaskList, request.TaskType, rangeID, row.RangeID),
×
139
                }
×
140
        }
×
141

142
        tlInfo, err := m.parser.TaskListInfoFromBlob(row.Data, row.DataEncoding)
873✔
143
        if err != nil {
873✔
144
                return nil, err
×
145
        }
×
146

147
        var resp *persistence.LeaseTaskListResponse
873✔
148
        err = m.txExecute(ctx, dbShardID, "LeaseTaskList", func(tx sqlplugin.Tx) error {
1,746✔
149
                rangeID = row.RangeID
873✔
150
                ackLevel = tlInfo.GetAckLevel()
873✔
151
                // We need to separately check the condition and do the
873✔
152
                // update because we want to throw different error codes.
873✔
153
                // Since we need to do things separately (in a transaction), we need to take a lock.
873✔
154
                err1 := lockTaskList(ctx, tx, dbShardID, domainID, request.TaskList, request.TaskType, rangeID)
873✔
155
                if err1 != nil {
873✔
156
                        return err1
×
157
                }
×
158
                now := time.Now()
873✔
159
                tlInfo.LastUpdated = now
873✔
160
                blob, err1 := m.parser.TaskListInfoToBlob(tlInfo)
873✔
161
                if err1 != nil {
873✔
162
                        return err1
×
163
                }
×
164
                row := &sqlplugin.TaskListsRow{
873✔
165
                        ShardID:      dbShardID,
873✔
166
                        DomainID:     row.DomainID,
873✔
167
                        RangeID:      row.RangeID + 1,
873✔
168
                        Name:         row.Name,
873✔
169
                        TaskType:     row.TaskType,
873✔
170
                        Data:         blob.Data,
873✔
171
                        DataEncoding: string(blob.Encoding),
873✔
172
                }
873✔
173
                var result sql.Result
873✔
174
                if tlInfo.GetKind() == persistence.TaskListKindSticky && m.db.SupportsTTL() {
873✔
175
                        result, err1 = tx.UpdateTaskListsWithTTL(ctx, &sqlplugin.TaskListsRowWithTTL{
×
176
                                TaskListsRow: *row,
×
177
                                TTL:          stickyTasksListsTTL,
×
178
                        })
×
179
                } else {
873✔
180
                        result, err1 = tx.UpdateTaskLists(ctx, row)
873✔
181
                }
873✔
182
                if err1 != nil {
873✔
183
                        return err1
×
184
                }
×
185
                rowsAffected, err1 := result.RowsAffected()
873✔
186
                if err1 != nil {
873✔
187
                        return fmt.Errorf("rowsAffected error: %v", err1)
×
188
                }
×
189
                if rowsAffected == 0 {
873✔
190
                        return fmt.Errorf("%v rows affected instead of 1", rowsAffected)
×
191
                }
×
192
                resp = &persistence.LeaseTaskListResponse{TaskListInfo: &persistence.TaskListInfo{
873✔
193
                        DomainID:    request.DomainID,
873✔
194
                        Name:        request.TaskList,
873✔
195
                        TaskType:    request.TaskType,
873✔
196
                        RangeID:     rangeID + 1,
873✔
197
                        AckLevel:    ackLevel,
873✔
198
                        Kind:        request.TaskListKind,
873✔
199
                        LastUpdated: now,
873✔
200
                }}
873✔
201
                return nil
873✔
202
        })
203
        return resp, err
873✔
204
}
205

206
func (m *sqlTaskStore) UpdateTaskList(
207
        ctx context.Context,
208
        request *persistence.UpdateTaskListRequest,
209
) (*persistence.UpdateTaskListResponse, error) {
3,159✔
210
        dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.TaskListInfo.DomainID, request.TaskListInfo.Name, m.db.GetTotalNumDBShards())
3,159✔
211
        domainID := serialization.MustParseUUID(request.TaskListInfo.DomainID)
3,159✔
212
        tlInfo := &serialization.TaskListInfo{
3,159✔
213
                AckLevel:        request.TaskListInfo.AckLevel,
3,159✔
214
                Kind:            int16(request.TaskListInfo.Kind),
3,159✔
215
                ExpiryTimestamp: time.Unix(0, 0),
3,159✔
216
                LastUpdated:     time.Now(),
3,159✔
217
        }
3,159✔
218
        if request.TaskListInfo.Kind == persistence.TaskListKindSticky {
3,231✔
219
                tlInfo.ExpiryTimestamp = stickyTaskListExpiry()
72✔
220
        }
72✔
221

222
        var resp *persistence.UpdateTaskListResponse
3,159✔
223
        blob, err := m.parser.TaskListInfoToBlob(tlInfo)
3,159✔
224
        if err != nil {
3,159✔
225
                return nil, err
×
226
        }
×
227
        err = m.txExecute(ctx, dbShardID, "UpdateTaskList", func(tx sqlplugin.Tx) error {
6,318✔
228
                err1 := lockTaskList(
3,159✔
229
                        ctx, tx, dbShardID, domainID, request.TaskListInfo.Name, request.TaskListInfo.TaskType, request.TaskListInfo.RangeID)
3,159✔
230
                if err1 != nil {
3,159✔
231
                        return err1
×
232
                }
×
233
                var result sql.Result
3,159✔
234
                row := &sqlplugin.TaskListsRow{
3,159✔
235
                        ShardID:      dbShardID,
3,159✔
236
                        DomainID:     domainID,
3,159✔
237
                        RangeID:      request.TaskListInfo.RangeID,
3,159✔
238
                        Name:         request.TaskListInfo.Name,
3,159✔
239
                        TaskType:     int64(request.TaskListInfo.TaskType),
3,159✔
240
                        Data:         blob.Data,
3,159✔
241
                        DataEncoding: string(blob.Encoding),
3,159✔
242
                }
3,159✔
243
                if m.db.SupportsTTL() && request.TaskListInfo.Kind == persistence.TaskListKindSticky {
3,159✔
244
                        result, err1 = tx.UpdateTaskListsWithTTL(ctx, &sqlplugin.TaskListsRowWithTTL{
×
245
                                TaskListsRow: *row,
×
246
                                TTL:          stickyTasksListsTTL,
×
247
                        })
×
248
                } else {
3,159✔
249
                        result, err1 = tx.UpdateTaskLists(ctx, row)
3,159✔
250
                }
3,159✔
251
                if err1 != nil {
3,159✔
252
                        return err1
×
253
                }
×
254
                rowsAffected, err1 := result.RowsAffected()
3,159✔
255
                if err1 != nil {
3,159✔
256
                        return err1
×
257
                }
×
258
                if rowsAffected != 1 {
3,159✔
259
                        return fmt.Errorf("%v rows were affected instead of 1", rowsAffected)
×
260
                }
×
261
                resp = &persistence.UpdateTaskListResponse{}
3,159✔
262
                return nil
3,159✔
263
        })
264
        return resp, err
3,159✔
265
}
266

267
type taskListPageToken struct {
268
        ShardID  int
269
        DomainID serialization.UUID
270
        Name     string
271
        TaskType int64
272
}
273

274
// ListTaskList lists tasklist from DB
275
// DomainID translates into byte array in SQL. The minUUID is not the minimum byte array.
276
func (m *sqlTaskStore) ListTaskList(
277
        ctx context.Context,
278
        request *persistence.ListTaskListRequest,
279
) (*persistence.ListTaskListResponse, error) {
×
280
        pageToken := taskListPageToken{DomainID: serialization.UUID{}}
×
281
        if len(request.PageToken) > 0 {
×
282
                if err := gobDeserialize(request.PageToken, &pageToken); err != nil {
×
283
                        return nil, &types.InternalServiceError{Message: fmt.Sprintf("error deserializing page token: %v", err)}
×
284
                }
×
285
        } else {
×
286
                pageToken = taskListPageToken{TaskType: math.MinInt16, DomainID: serialization.UUID{}}
×
287
        }
×
288
        var err error
×
289
        var rows []sqlplugin.TaskListsRow
×
290
        for pageToken.ShardID < m.nShards {
×
291
                rows, err = m.db.SelectFromTaskLists(ctx, &sqlplugin.TaskListsFilter{
×
292
                        ShardID:             pageToken.ShardID,
×
293
                        DomainIDGreaterThan: &pageToken.DomainID,
×
294
                        NameGreaterThan:     &pageToken.Name,
×
295
                        TaskTypeGreaterThan: &pageToken.TaskType,
×
296
                        PageSize:            &request.PageSize,
×
297
                })
×
298
                if err != nil {
×
299
                        return nil, convertCommonErrors(m.db, "ListTaskList", "", err)
×
300
                }
×
301
                if len(rows) > 0 {
×
302
                        break
×
303
                }
304
                pageToken = taskListPageToken{ShardID: pageToken.ShardID + 1, TaskType: math.MinInt16, DomainID: serialization.UUID{}}
×
305
        }
306

307
        var nextPageToken []byte
×
308
        switch {
×
309
        case len(rows) >= request.PageSize:
×
310
                lastRow := &rows[request.PageSize-1]
×
311
                nextPageToken, err = gobSerialize(&taskListPageToken{
×
312
                        ShardID:  pageToken.ShardID,
×
313
                        DomainID: lastRow.DomainID,
×
314
                        Name:     lastRow.Name,
×
315
                        TaskType: lastRow.TaskType,
×
316
                })
×
317
        case pageToken.ShardID+1 < m.nShards:
×
318
                nextPageToken, err = gobSerialize(&taskListPageToken{ShardID: pageToken.ShardID + 1, TaskType: math.MinInt16, DomainID: serialization.UUID{}})
×
319
        }
320

321
        if err != nil {
×
322
                return nil, &types.InternalServiceError{Message: fmt.Sprintf("error serializing nextPageToken:%v", err)}
×
323
        }
×
324

325
        resp := &persistence.ListTaskListResponse{
×
326
                Items:         make([]persistence.TaskListInfo, len(rows)),
×
327
                NextPageToken: nextPageToken,
×
328
        }
×
329

×
330
        for i := range rows {
×
331
                info, err := m.parser.TaskListInfoFromBlob(rows[i].Data, rows[i].DataEncoding)
×
332
                if err != nil {
×
333
                        return nil, err
×
334
                }
×
335
                resp.Items[i].DomainID = rows[i].DomainID.String()
×
336
                resp.Items[i].Name = rows[i].Name
×
337
                resp.Items[i].TaskType = int(rows[i].TaskType)
×
338
                resp.Items[i].RangeID = rows[i].RangeID
×
339
                resp.Items[i].Kind = int(info.GetKind())
×
340
                resp.Items[i].AckLevel = info.GetAckLevel()
×
341
                resp.Items[i].Expiry = info.GetExpiryTimestamp()
×
342
                resp.Items[i].LastUpdated = info.GetLastUpdated()
×
343
        }
344

345
        return resp, nil
×
346
}
347

348
func (m *sqlTaskStore) DeleteTaskList(
349
        ctx context.Context,
350
        request *persistence.DeleteTaskListRequest,
351
) error {
×
352
        shardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskListName, m.db.GetTotalNumDBShards())
×
353
        domainID := serialization.MustParseUUID(request.DomainID)
×
354
        result, err := m.db.DeleteFromTaskLists(ctx, &sqlplugin.TaskListsFilter{
×
355
                ShardID:  shardID,
×
356
                DomainID: &domainID,
×
357
                Name:     &request.TaskListName,
×
358
                TaskType: common.Int64Ptr(int64(request.TaskListType)),
×
359
                RangeID:  &request.RangeID,
×
360
        })
×
361
        if err != nil {
×
362
                return convertCommonErrors(m.db, "DeleteTaskList", "", err)
×
363
        }
×
364
        nRows, err := result.RowsAffected()
×
365
        if err != nil {
×
366
                return &types.InternalServiceError{Message: fmt.Sprintf("rowsAffected returned error:%v", err)}
×
367
        }
×
368
        if nRows != 1 {
×
369
                return &types.InternalServiceError{Message: fmt.Sprintf("delete failed: %v rows affected instead of 1", nRows)}
×
370
        }
×
371
        return nil
×
372
}
373

374
func (m *sqlTaskStore) CreateTasks(
375
        ctx context.Context,
376
        request *persistence.InternalCreateTasksRequest,
377
) (*persistence.CreateTasksResponse, error) {
534✔
378
        var tasksRows []sqlplugin.TasksRow
534✔
379
        var tasksRowsWithTTL []sqlplugin.TasksRowWithTTL
534✔
380
        if m.db.SupportsTTL() {
534✔
381
                tasksRowsWithTTL = make([]sqlplugin.TasksRowWithTTL, len(request.Tasks))
×
382
        } else {
534✔
383
                tasksRows = make([]sqlplugin.TasksRow, len(request.Tasks))
534✔
384
        }
534✔
385

386
        dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.TaskListInfo.DomainID, request.TaskListInfo.Name, m.db.GetTotalNumDBShards())
534✔
387

534✔
388
        for i, v := range request.Tasks {
1,069✔
389
                var expiryTime time.Time
535✔
390
                var ttl time.Duration
535✔
391
                if v.Data.ScheduleToStartTimeout.Seconds() > 0 {
1,070✔
392
                        ttl = v.Data.ScheduleToStartTimeout
535✔
393
                        if m.db.SupportsTTL() {
535✔
394
                                maxAllowedTTL, err := m.db.MaxAllowedTTL()
×
395
                                if err != nil {
×
396
                                        return nil, err
×
397
                                }
×
398
                                if ttl > *maxAllowedTTL {
×
399
                                        ttl = *maxAllowedTTL
×
400
                                }
×
401
                        }
402
                        expiryTime = time.Now().Add(ttl)
535✔
403
                }
404
                blob, err := m.parser.TaskInfoToBlob(&serialization.TaskInfo{
535✔
405
                        WorkflowID:       v.Data.WorkflowID,
535✔
406
                        RunID:            serialization.MustParseUUID(v.Data.RunID),
535✔
407
                        ScheduleID:       v.Data.ScheduleID,
535✔
408
                        ExpiryTimestamp:  expiryTime,
535✔
409
                        CreatedTimestamp: time.Now(),
535✔
410
                        PartitionConfig:  v.Data.PartitionConfig,
535✔
411
                })
535✔
412
                if err != nil {
535✔
413
                        return nil, err
×
414
                }
×
415

416
                currTasksRow := sqlplugin.TasksRow{
535✔
417
                        ShardID:      dbShardID,
535✔
418
                        DomainID:     serialization.MustParseUUID(v.Data.DomainID),
535✔
419
                        TaskListName: request.TaskListInfo.Name,
535✔
420
                        TaskType:     int64(request.TaskListInfo.TaskType),
535✔
421
                        TaskID:       v.TaskID,
535✔
422
                        Data:         blob.Data,
535✔
423
                        DataEncoding: string(blob.Encoding),
535✔
424
                }
535✔
425
                if m.db.SupportsTTL() {
535✔
426
                        currTasksRowWithTTL := sqlplugin.TasksRowWithTTL{
×
427
                                TasksRow: currTasksRow,
×
428
                        }
×
429
                        if ttl > 0 {
×
430
                                currTasksRowWithTTL.TTL = &ttl
×
431
                        }
×
432
                        tasksRowsWithTTL[i] = currTasksRowWithTTL
×
433
                } else {
535✔
434
                        tasksRows[i] = currTasksRow
535✔
435
                }
535✔
436

437
        }
438
        var resp *persistence.CreateTasksResponse
534✔
439
        err := m.txExecute(ctx, dbShardID, "CreateTasks", func(tx sqlplugin.Tx) error {
1,068✔
440
                if m.db.SupportsTTL() {
534✔
441
                        if _, err := tx.InsertIntoTasksWithTTL(ctx, tasksRowsWithTTL); err != nil {
×
442
                                return err
×
443
                        }
×
444
                } else {
534✔
445
                        if _, err := tx.InsertIntoTasks(ctx, tasksRows); err != nil {
534✔
446
                                return err
×
447
                        }
×
448
                }
449

450
                // Lock task list before committing.
451
                err1 := lockTaskList(ctx, tx,
534✔
452
                        dbShardID,
534✔
453
                        serialization.MustParseUUID(request.TaskListInfo.DomainID),
534✔
454
                        request.TaskListInfo.Name,
534✔
455
                        request.TaskListInfo.TaskType, request.TaskListInfo.RangeID)
534✔
456
                if err1 != nil {
534✔
457
                        return err1
×
458
                }
×
459
                resp = &persistence.CreateTasksResponse{}
534✔
460
                return nil
534✔
461
        })
462
        return resp, err
534✔
463
}
464

465
func (m *sqlTaskStore) GetTasks(
466
        ctx context.Context,
467
        request *persistence.GetTasksRequest,
468
) (*persistence.InternalGetTasksResponse, error) {
555✔
469
        shardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskList, m.db.GetTotalNumDBShards())
555✔
470
        rows, err := m.db.SelectFromTasks(ctx, &sqlplugin.TasksFilter{
555✔
471
                ShardID:      shardID,
555✔
472
                DomainID:     serialization.MustParseUUID(request.DomainID),
555✔
473
                TaskListName: request.TaskList,
555✔
474
                TaskType:     int64(request.TaskType),
555✔
475
                MinTaskID:    &request.ReadLevel,
555✔
476
                MaxTaskID:    request.MaxReadLevel,
555✔
477
                PageSize:     &request.BatchSize,
555✔
478
        })
555✔
479
        if err != nil {
555✔
480
                return nil, convertCommonErrors(m.db, "GetTasks", "", err)
×
481
        }
×
482

483
        var tasks = make([]*persistence.InternalTaskInfo, len(rows))
555✔
484
        for i, v := range rows {
1,092✔
485
                info, err := m.parser.TaskInfoFromBlob(v.Data, v.DataEncoding)
537✔
486
                if err != nil {
537✔
487
                        return nil, err
×
488
                }
×
489
                tasks[i] = &persistence.InternalTaskInfo{
537✔
490
                        DomainID:        request.DomainID,
537✔
491
                        WorkflowID:      info.GetWorkflowID(),
537✔
492
                        RunID:           info.RunID.String(),
537✔
493
                        TaskID:          v.TaskID,
537✔
494
                        ScheduleID:      info.GetScheduleID(),
537✔
495
                        Expiry:          info.GetExpiryTimestamp(),
537✔
496
                        CreatedTime:     info.GetCreatedTimestamp(),
537✔
497
                        PartitionConfig: info.GetPartitionConfig(),
537✔
498
                }
537✔
499
        }
500

501
        return &persistence.InternalGetTasksResponse{Tasks: tasks}, nil
555✔
502
}
503

504
func (m *sqlTaskStore) CompleteTask(
505
        ctx context.Context,
506
        request *persistence.CompleteTaskRequest,
507
) error {
×
508
        taskID := request.TaskID
×
509
        taskList := request.TaskList
×
510
        shardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(taskList.DomainID, taskList.Name, m.db.GetTotalNumDBShards())
×
511
        _, err := m.db.DeleteFromTasks(ctx, &sqlplugin.TasksFilter{
×
512
                ShardID:      shardID,
×
513
                DomainID:     serialization.MustParseUUID(taskList.DomainID),
×
514
                TaskListName: taskList.Name,
×
515
                TaskType:     int64(taskList.TaskType),
×
516
                TaskID:       &taskID})
×
517
        if err != nil {
×
518
                return convertCommonErrors(m.db, "CompleteTask", "", err)
×
519
        }
×
520
        return nil
×
521
}
522

523
func (m *sqlTaskStore) CompleteTasksLessThan(
524
        ctx context.Context,
525
        request *persistence.CompleteTasksLessThanRequest,
526
) (*persistence.CompleteTasksLessThanResponse, error) {
292✔
527
        shardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskListName, m.db.GetTotalNumDBShards())
292✔
528
        result, err := m.db.DeleteFromTasks(ctx, &sqlplugin.TasksFilter{
292✔
529
                ShardID:              shardID,
292✔
530
                DomainID:             serialization.MustParseUUID(request.DomainID),
292✔
531
                TaskListName:         request.TaskListName,
292✔
532
                TaskType:             int64(request.TaskType),
292✔
533
                TaskIDLessThanEquals: &request.TaskID,
292✔
534
                Limit:                &request.Limit,
292✔
535
        })
292✔
536
        if err != nil {
292✔
537
                return nil, convertCommonErrors(m.db, "CompleteTasksLessThan", "", err)
×
538
        }
×
539
        nRows, err := result.RowsAffected()
292✔
540
        if err != nil {
292✔
541
                return nil, &types.InternalServiceError{
×
542
                        Message: fmt.Sprintf("rowsAffected returned error: %v", err),
×
543
                }
×
544
        }
×
545
        return &persistence.CompleteTasksLessThanResponse{TasksCompleted: int(nRows)}, nil
292✔
546
}
547

548
// GetOrphanTasks gets tasks from the tasks table that belong to a task_list no longer present
549
// in the task_lists table.
550
// TODO: Limit this query to a specific shard at a time. See https://github.com/uber/cadence/issues/4064
551
func (m *sqlTaskStore) GetOrphanTasks(ctx context.Context, request *persistence.GetOrphanTasksRequest) (*persistence.GetOrphanTasksResponse, error) {
×
552
        rows, err := m.db.GetOrphanTasks(ctx, &sqlplugin.OrphanTasksFilter{
×
553
                Limit: &request.Limit,
×
554
        })
×
555
        if err != nil {
×
556
                return nil, convertCommonErrors(m.db, "GetOrphanTasks", "", err)
×
557
        }
×
558

559
        var tasks = make([]*persistence.TaskKey, len(rows))
×
560
        for i, v := range rows {
×
561
                tasks[i] = &persistence.TaskKey{
×
562
                        DomainID:     v.DomainID.String(),
×
563
                        TaskListName: v.TaskListName,
×
564
                        TaskType:     int(v.TaskType),
×
565
                        TaskID:       v.TaskID,
×
566
                }
×
567
        }
×
568

569
        return &persistence.GetOrphanTasksResponse{Tasks: tasks}, nil
×
570
}
571

572
func lockTaskList(ctx context.Context, tx sqlplugin.Tx, shardID int, domainID serialization.UUID, name string, taskListType int, oldRangeID int64) error {
4,566✔
573
        rangeID, err := tx.LockTaskLists(ctx, &sqlplugin.TaskListsFilter{
4,566✔
574
                ShardID: shardID, DomainID: &domainID, Name: &name, TaskType: common.Int64Ptr(int64(taskListType))})
4,566✔
575

4,566✔
576
        switch err {
4,566✔
577
        case nil:
4,566✔
578
                if rangeID != oldRangeID {
4,566✔
579
                        return &persistence.ConditionFailedError{
×
580
                                Msg: fmt.Sprintf("Task list range ID was %v when it was should have been %v", rangeID, oldRangeID),
×
581
                        }
×
582
                }
×
583
                return nil
4,566✔
584
        case sql.ErrNoRows:
×
585
                return &persistence.ConditionFailedError{
×
586
                        Msg: "Task list does not exist.",
×
587
                }
×
588
        default:
×
589
                return convertCommonErrors(tx, "lockTaskList", "", err)
×
590
        }
591
}
592

593
func stickyTaskListExpiry() time.Time {
72✔
594
        return time.Now().Add(stickyTasksListsTTL)
72✔
595
}
72✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc