• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Permify / permify / 12882779193

21 Jan 2025 08:19AM UTC coverage: 79.891% (-0.06%) from 79.953%
12882779193

push

github

web-flow
Merge pull request #1928 from Permify/feat-add-cursor-pagination-limit

refactor: update dataReader.go for improved pagination and limit hand…

24 of 43 new or added lines in 3 files covered. (55.81%)

2 existing lines in 1 file now uncovered.

8184 of 10244 relevant lines covered (79.89%)

122.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.45
/internal/storage/memory/dataReader.go
1
package memory
2

3
import (
4
        "context"
5
        "errors"
6
        "slices"
7
        "sort"
8
        "strconv"
9
        "time"
10

11
        "github.com/hashicorp/go-memdb"
12

13
        "github.com/Permify/permify/internal/storage/memory/constants"
14

15
        "github.com/Permify/permify/internal/storage"
16
        "github.com/Permify/permify/internal/storage/memory/snapshot"
17
        "github.com/Permify/permify/internal/storage/memory/utils"
18
        "github.com/Permify/permify/pkg/database"
19
        db "github.com/Permify/permify/pkg/database/memory"
20
        base "github.com/Permify/permify/pkg/pb/base/v1"
21
        "github.com/Permify/permify/pkg/token"
22
)
23

24
// DataReader -
25
type DataReader struct {
26
        database *db.Memory
27
}
28

29
// NewDataReader - Creates a new DataReader
30
func NewDataReader(database *db.Memory) *DataReader {
7✔
31
        return &DataReader{
7✔
32
                database: database,
7✔
33
        }
7✔
34
}
7✔
35

36
// QueryRelationships queries the database for relationships based on the provided filter.
37
func (r *DataReader) QueryRelationships(_ context.Context, tenantID string, filter *base.TupleFilter, _ string, pagination database.CursorPagination) (it *database.TupleIterator, err error) {
1✔
38
        txn := r.database.DB.Txn(false)
1✔
39
        defer txn.Abort()
1✔
40

1✔
41
        var lowerBound string
1✔
42

1✔
43
        if pagination.Cursor() != "" {
1✔
44
                var t database.ContinuousToken
×
45
                t, err = utils.EncodedContinuousToken{Value: pagination.Cursor()}.Decode()
×
46
                if err != nil {
×
47
                        return nil, err
×
48
                }
×
49
                lowerBound = t.(utils.ContinuousToken).Value
×
50
        }
51

52
        // Get the index and arguments based on the filter.
53
        index, args := utils.GetRelationTuplesIndexNameAndArgsByFilters(tenantID, filter)
1✔
54

1✔
55
        // Get the result iterator based on the index and arguments.
1✔
56
        var result memdb.ResultIterator
1✔
57
        result, err = txn.LowerBound(constants.RelationTuplesTable, index, args...)
1✔
58
        if err != nil {
1✔
59
                return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
×
60
        }
×
61

62
        // Filter the result iterator and add the tuples to the collection.
63
        tup := make([]storage.RelationTuple, 0, 10)
1✔
64
        fit := memdb.NewFilterIterator(result, utils.FilterRelationTuplesQuery(tenantID, filter))
1✔
65
        for obj := fit.Next(); obj != nil; obj = fit.Next() {
3✔
66
                t, ok := obj.(storage.RelationTuple)
2✔
67
                if !ok {
2✔
68
                        return nil, errors.New(base.ErrorCode_ERROR_CODE_TYPE_CONVERSATION.String())
×
69
                }
×
70
                tup = append(tup, t)
2✔
71
        }
72

73
        // Sort tuples based on the provided order field
74
        sort.Slice(tup, func(i, j int) bool {
2✔
75
                switch pagination.Sort() {
1✔
76
                case "entity_id":
×
77
                        return tup[i].EntityID < tup[j].EntityID
×
78
                case "subject_id":
×
79
                        return tup[i].SubjectID < tup[j].SubjectID
×
80
                default:
1✔
81
                        return false // No sorting if order field is invalid
1✔
82
                }
83
        })
84

85
        var tuples []*base.Tuple
1✔
86
        count := uint32(0)
1✔
87
        limit := pagination.Limit()
1✔
88

1✔
89
        for _, t := range tup {
3✔
90
                // Skip tuples below the lower bound
2✔
91
                switch pagination.Sort() {
2✔
92
                case "entity_id":
×
NEW
93
                        if t.EntityID < lowerBound {
×
NEW
94
                                continue
×
95
                        }
96
                case "subject_id":
×
NEW
97
                        if t.SubjectID < lowerBound {
×
NEW
98
                                continue
×
99
                        }
100
                }
101

102
                // Add tuple to result set
103
                tuples = append(tuples, t.ToTuple())
2✔
104

2✔
105
                // Enforce the limit if it's set
2✔
106
                count++
2✔
107
                if limit > 0 && count >= limit {
2✔
NEW
108
                        break
×
109
                }
110
        }
111

112
        return database.NewTupleCollection(tuples...).CreateTupleIterator(), nil
1✔
113
}
114

115
// ReadRelationships reads relationships from the database taking into account the pagination.
116
func (r *DataReader) ReadRelationships(_ context.Context, tenantID string, filter *base.TupleFilter, _ string, pagination database.Pagination) (collection *database.TupleCollection, ct database.EncodedContinuousToken, err error) {
4✔
117
        txn := r.database.DB.Txn(false)
4✔
118
        defer txn.Abort()
4✔
119

4✔
120
        var lowerBound uint64
4✔
121
        if pagination.Token() != "" {
5✔
122
                var t database.ContinuousToken
1✔
123
                t, err = utils.EncodedContinuousToken{Value: pagination.Token()}.Decode()
1✔
124
                if err != nil {
1✔
125
                        return nil, database.NewNoopContinuousToken().Encode(), err
×
126
                }
×
127
                lowerBound, err = strconv.ParseUint(t.(utils.ContinuousToken).Value, 10, 64)
1✔
128
                if err != nil {
1✔
129
                        return nil, database.NewNoopContinuousToken().Encode(), errors.New(base.ErrorCode_ERROR_CODE_INVALID_CONTINUOUS_TOKEN.String())
×
130
                }
×
131
        }
132

133
        index, args := utils.GetRelationTuplesIndexNameAndArgsByFilters(tenantID, filter)
4✔
134

4✔
135
        // Get the result iterator using lower bound.
4✔
136
        var result memdb.ResultIterator
4✔
137
        result, err = txn.LowerBound(constants.RelationTuplesTable, index, args...)
4✔
138
        if err != nil {
4✔
139
                return nil, database.NewNoopContinuousToken().Encode(), errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
×
140
        }
×
141

142
        // Filter the result iterator and add the tuples to the array.
143
        tup := make([]storage.RelationTuple, 0, 10)
4✔
144
        fit := memdb.NewFilterIterator(result, utils.FilterRelationTuplesQuery(tenantID, filter))
4✔
145
        for obj := fit.Next(); obj != nil; obj = fit.Next() {
20✔
146
                t, ok := obj.(storage.RelationTuple)
16✔
147
                if !ok {
16✔
148
                        return nil, database.NewNoopContinuousToken().Encode(), errors.New(base.ErrorCode_ERROR_CODE_TYPE_CONVERSATION.String())
×
149
                }
×
150
                tup = append(tup, t)
16✔
151
        }
152

153
        // Sort the tuples and append them to the collection.
154
        sort.Slice(tup, func(i, j int) bool {
16✔
155
                return tup[i].ID < tup[j].ID
12✔
156
        })
12✔
157

158
        tuples := make([]*base.Tuple, 0, pagination.PageSize()+1)
4✔
159
        for _, t := range tup {
18✔
160
                if t.ID >= lowerBound {
26✔
161
                        tuples = append(tuples, t.ToTuple())
12✔
162
                        if pagination.PageSize() != 0 && len(tuples) > int(pagination.PageSize()) {
13✔
163
                                return database.NewTupleCollection(tuples[:pagination.PageSize()]...), utils.NewContinuousToken(strconv.FormatUint(t.ID, 10)).Encode(), nil
1✔
164
                        }
1✔
165
                }
166
        }
167

168
        return database.NewTupleCollection(tuples...), database.NewNoopContinuousToken().Encode(), nil
3✔
169
}
170

171
// QuerySingleAttribute queries the database for a single attribute based on the provided filter.
172
func (r *DataReader) QuerySingleAttribute(_ context.Context, tenantID string, filter *base.AttributeFilter, _ string) (attribute *base.Attribute, err error) {
2✔
173
        txn := r.database.DB.Txn(false)
2✔
174
        defer txn.Abort()
2✔
175

2✔
176
        // Get the index and arguments based on the filter.
2✔
177
        index, args := utils.GetAttributesIndexNameAndArgsByFilters(tenantID, filter)
2✔
178

2✔
179
        // Get the result iterator based on the index and arguments.
2✔
180
        var result memdb.ResultIterator
2✔
181
        result, err = txn.Get(constants.AttributesTable, index, args...)
2✔
182
        if err != nil {
2✔
183
                return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
×
184
        }
×
185

186
        // Filter the result iterator and add the attributes to the collection.
187
        fit := memdb.NewFilterIterator(result, utils.FilterAttributesQuery(tenantID, filter))
2✔
188
        for obj := fit.Next(); obj != nil; {
3✔
189
                t, ok := obj.(storage.Attribute)
1✔
190
                if !ok {
1✔
191
                        return nil, errors.New(base.ErrorCode_ERROR_CODE_TYPE_CONVERSATION.String())
×
192
                }
×
193
                return t.ToAttribute(), nil
1✔
194
        }
195

196
        return nil, nil
1✔
197
}
198

199
// QueryAttributes queries the database for attributes based on the provided filter.
200
func (r *DataReader) QueryAttributes(_ context.Context, tenantID string, filter *base.AttributeFilter, _ string, pagination database.CursorPagination) (iterator *database.AttributeIterator, err error) {
1✔
201
        txn := r.database.DB.Txn(false)
1✔
202
        defer txn.Abort()
1✔
203

1✔
204
        var lowerBound string
1✔
205

1✔
206
        if pagination.Cursor() != "" {
1✔
207
                var t database.ContinuousToken
×
208
                t, err = utils.EncodedContinuousToken{Value: pagination.Cursor()}.Decode()
×
209
                if err != nil {
×
210
                        return nil, err
×
211
                }
×
212
                lowerBound = t.(utils.ContinuousToken).Value
×
213
        }
214

215
        // Get the index and arguments based on the filter.
216
        index, args := utils.GetAttributesIndexNameAndArgsByFilters(tenantID, filter)
1✔
217

1✔
218
        // Get the result iterator based on the index and arguments.
1✔
219
        var result memdb.ResultIterator
1✔
220
        result, err = txn.Get(constants.AttributesTable, index, args...)
1✔
221
        if err != nil {
1✔
222
                return nil, errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
×
223
        }
×
224

225
        // Filter the result iterator and add the attributes to the collection.
226
        attr := make([]storage.Attribute, 0, 10)
1✔
227
        fit := memdb.NewFilterIterator(result, utils.FilterAttributesQuery(tenantID, filter))
1✔
228
        for obj := fit.Next(); obj != nil; obj = fit.Next() {
3✔
229
                t, ok := obj.(storage.Attribute)
2✔
230
                if !ok {
2✔
231
                        return nil, errors.New(base.ErrorCode_ERROR_CODE_TYPE_CONVERSATION.String())
×
232
                }
×
233
                attr = append(attr, t)
2✔
234
        }
235

236
        // Sort attributes based on the provided order field
237
        sort.Slice(attr, func(i, j int) bool {
2✔
238
                switch pagination.Sort() {
1✔
239
                case "entity_id":
×
240
                        return attr[i].EntityID < attr[j].EntityID
×
241
                default:
1✔
242
                        return false // No sorting if order field is invalid
1✔
243
                }
244
        })
245

246
        var attrs []*base.Attribute
1✔
247
        count := uint32(0)
1✔
248
        limit := pagination.Limit()
1✔
249

1✔
250
        for _, t := range attr {
3✔
251
                // Skip attributes below the lower bound
2✔
252
                switch pagination.Sort() {
2✔
253
                case "entity_id":
×
NEW
254
                        if t.EntityID < lowerBound {
×
NEW
255
                                continue
×
256
                        }
257
                }
258

259
                // Add attribute to result set
260
                attrs = append(attrs, t.ToAttribute())
2✔
261

2✔
262
                // Enforce the limit if it's set
2✔
263
                count++
2✔
264
                if limit > 0 && count >= limit {
2✔
NEW
265
                        break
×
266
                }
267
        }
268

269
        return database.NewAttributeCollection(attrs...).CreateAttributeIterator(), nil
1✔
270
}
271

272
// ReadAttributes reads attributes from the database taking into account the pagination.
273
func (r *DataReader) ReadAttributes(_ context.Context, tenantID string, filter *base.AttributeFilter, _ string, pagination database.Pagination) (collection *database.AttributeCollection, ct database.EncodedContinuousToken, err error) {
4✔
274
        txn := r.database.DB.Txn(false)
4✔
275
        defer txn.Abort()
4✔
276

4✔
277
        var lowerBound uint64
4✔
278
        if pagination.Token() != "" {
5✔
279
                var t database.ContinuousToken
1✔
280
                t, err = utils.EncodedContinuousToken{Value: pagination.Token()}.Decode()
1✔
281
                if err != nil {
1✔
282
                        return nil, database.NewNoopContinuousToken().Encode(), err
×
283
                }
×
284
                lowerBound, err = strconv.ParseUint(t.(utils.ContinuousToken).Value, 10, 64)
1✔
285
                if err != nil {
1✔
286
                        return nil, database.NewNoopContinuousToken().Encode(), errors.New(base.ErrorCode_ERROR_CODE_INVALID_CONTINUOUS_TOKEN.String())
×
287
                }
×
288
        }
289

290
        // Get the index and arguments based on the filter.
291
        index, args := utils.GetAttributesIndexNameAndArgsByFilters(tenantID, filter)
4✔
292

4✔
293
        // Get the result iterator using lower bound.
4✔
294
        var result memdb.ResultIterator
4✔
295
        result, err = txn.LowerBound(constants.AttributesTable, index, args...)
4✔
296
        if err != nil {
4✔
297
                return nil, database.NewNoopContinuousToken().Encode(), errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
×
298
        }
×
299

300
        // Filter the result iterator and add the attributes to the array.
301
        attr := make([]storage.Attribute, 0, 10)
4✔
302
        fit := memdb.NewFilterIterator(result, utils.FilterAttributesQuery(tenantID, filter))
4✔
303
        for obj := fit.Next(); obj != nil; obj = fit.Next() {
19✔
304
                a, ok := obj.(storage.Attribute)
15✔
305
                if !ok {
15✔
306
                        return nil, database.NewNoopContinuousToken().Encode(), errors.New(base.ErrorCode_ERROR_CODE_TYPE_CONVERSATION.String())
×
307
                }
×
308
                attr = append(attr, a)
15✔
309
        }
310

311
        // Sort the attributes and append them to the collection.
312
        sort.Slice(attr, func(i, j int) bool {
25✔
313
                return attr[i].ID < attr[j].ID
21✔
314
        })
21✔
315

316
        attributes := make([]*base.Attribute, 0, pagination.PageSize()+1)
4✔
317
        for _, t := range attr {
17✔
318
                if t.ID >= lowerBound {
24✔
319
                        attributes = append(attributes, t.ToAttribute())
11✔
320
                        if pagination.PageSize() != 0 && len(attributes) > int(pagination.PageSize()) {
12✔
321
                                return database.NewAttributeCollection(attributes[:pagination.PageSize()]...), utils.NewContinuousToken(strconv.FormatUint(t.ID, 10)).Encode(), nil
1✔
322
                        }
1✔
323
                }
324
        }
325

326
        return database.NewAttributeCollection(attributes...), database.NewNoopContinuousToken().Encode(), nil
3✔
327
}
328

329
// QueryUniqueSubjectReferences is a function that searches for unique subject references in a given database.
330
func (r *DataReader) QueryUniqueSubjectReferences(_ context.Context, tenantID string, subjectReference *base.RelationReference, excluded []string, _ string, pagination database.Pagination) (ids []string, _ database.EncodedContinuousToken, err error) {
2✔
331
        txn := r.database.DB.Txn(false)
2✔
332
        defer txn.Abort()
2✔
333

2✔
334
        var lowerBound string
2✔
335
        if pagination.Token() != "" {
2✔
336
                var t database.ContinuousToken
×
337
                t, err := utils.EncodedContinuousToken{Value: pagination.Token()}.Decode()
×
338
                if err != nil {
×
339
                        return nil, database.NewNoopContinuousToken().Encode(), err
×
340
                }
×
341
                lowerBound = t.(utils.ContinuousToken).Value
×
342
        }
343

344
        // Get the result iterator based on the index and arguments.
345
        result, err := txn.LowerBound(constants.RelationTuplesTable, "id")
2✔
346
        if err != nil {
2✔
347
                return nil, database.NewNoopContinuousToken().Encode(), errors.New(base.ErrorCode_ERROR_CODE_EXECUTION.String())
×
348
        }
×
349

350
        var subjectIDs []string
2✔
351

2✔
352
        // Filter the result iterator and add the tuples to the collection.
2✔
353
        fit := memdb.NewFilterIterator(result, utils.FilterRelationTuplesQuery(tenantID, &base.TupleFilter{
2✔
354
                Subject: &base.SubjectFilter{
2✔
355
                        Type:     subjectReference.GetType(),
2✔
356
                        Relation: subjectReference.GetRelation(),
2✔
357
                },
2✔
358
        }))
2✔
359
        for obj := fit.Next(); obj != nil; obj = fit.Next() {
8✔
360
                t, ok := obj.(storage.RelationTuple)
6✔
361
                if !ok {
6✔
362
                        return nil, database.NewNoopContinuousToken().Encode(), errors.New(base.ErrorCode_ERROR_CODE_TYPE_CONVERSATION.String())
×
363
                }
×
364
                subjectIDs = append(subjectIDs, t.SubjectID)
6✔
365
        }
366

367
        // Sort the tuples and append them to the collection.
368
        sort.Slice(subjectIDs, func(i, j int) bool {
9✔
369
                return subjectIDs[i] < subjectIDs[j]
7✔
370
        })
7✔
371

372
        mp := make(map[string]bool)
2✔
373
        var lastID string
2✔
374

2✔
375
        for _, b := range subjectIDs {
8✔
376
                if slices.Contains(excluded, b) {
6✔
377
                        continue
×
378
                }
379
                if _, exists := mp[b]; !exists && b >= lowerBound {
11✔
380

5✔
381
                        ids = append(ids, b)
5✔
382
                        mp[b] = true
5✔
383

5✔
384
                        // Capture the last ID after adding pagesize + 1 elements
5✔
385
                        if len(ids) == int(pagination.PageSize())+1 {
7✔
386
                                lastID = b
2✔
387
                        }
2✔
388

389
                        // Stop appending if we've reached the page size
390
                        if pagination.PageSize() != 0 && len(ids) > int(pagination.PageSize()) {
5✔
391
                                return ids[:pagination.PageSize()], utils.NewContinuousToken(lastID).Encode(), nil
×
392
                        }
×
393
                }
394
        }
395

396
        return ids, database.NewNoopContinuousToken().Encode(), nil
2✔
397
}
398

399
// HeadSnapshot - Reads the latest version of the snapshot from the repository.
400
func (r *DataReader) HeadSnapshot(_ context.Context, _ string) (token.SnapToken, error) {
×
401
        return snapshot.NewToken(time.Now()), nil
×
402
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc