dgraph-io / dgraph · build 6000148643 (push, committed via web-flow)
28 Aug 2023 12:58PM UTC · coverage: 67.112% (+0.5%) from 66.655%
feat(dql): add @unique constraint support in schema for new predicates (#8827)

Partially Fixes #8827
Closes: DGRAPHCORE-206
Docs PR: https://github.com/dgraph-io/dgraph-docs/pull/638

This PR adds support for a uniqueness constraint via the @unique directive
in the DQL schema. The directive ensures that all values of a predicate are
distinct across the Dgraph cluster. This completes phase 1, which enables
adding a new predicate with the @unique directive. In phase 2, we will add
support for the @unique directive on existing predicates.

## Performance
Running the Live Loader on the 21 million dataset took 10m54s before this
change and 11m02s after. The change makes no significant difference for
non-unique predicates.

## How to Use
You can now specify @unique in the schema as follows: `email: string @unique
@index(hash) @upsert .`. Dgraph will then ensure that no mutation adds a
duplicate value for the `email` predicate.
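
To see the constraint end to end from a client, here is a minimal sketch. It
assumes a local Alpha at `localhost:9080` and the dgo v210 Go client; the
schema line is the one above, while the endpoint, email values, and blank
nodes are purely illustrative:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/dgraph-io/dgo/v210"
	"github.com/dgraph-io/dgo/v210/protos/api"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Hypothetical endpoint; adjust to your Alpha address.
	conn, err := grpc.Dial("localhost:9080",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	dg := dgo.NewDgraphClient(api.NewDgraphClient(conn))
	ctx := context.Background()

	// Apply the schema from this PR: @unique alongside @index(hash) and @upsert.
	if err := dg.Alter(ctx, &api.Operation{
		Schema: `email: string @unique @index(hash) @upsert .`,
	}); err != nil {
		log.Fatal(err)
	}

	// First mutation inserts the value.
	txn := dg.NewTxn()
	_, err = txn.Mutate(ctx, &api.Mutation{
		SetNquads: []byte(`_:u1 <email> "alice@example.com" .`),
		CommitNow: true,
	})
	fmt.Println("first insert:", err)

	// A second mutation with the same value should now be rejected by the
	// uniqueness check rather than creating a duplicate.
	txn = dg.NewTxn()
	_, err = txn.Mutate(ctx, &api.Mutation{
		SetNquads: []byte(`_:u2 <email> "alice@example.com" .`),
		CommitNow: true,
	})
	fmt.Println("duplicate insert:", err)
}
```

With the schema applied, the second commit is expected to fail with a
uniqueness error instead of silently adding a second `email` with the same
value.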

## Phase 2 [TODO]
- [ ] Check whether @unique can be added to the schema based on whether the
existing data contains duplicates. If it does, we disallow adding the
@unique directive and return a query that lets the user identify the
offending UIDs (see the sketch after this list).
- [ ] If index computation is in progress, disallow mutations on predicates
for which @unique is set.
- [ ] Fix ACL to ensure that we do not end up adding duplicate users.
- [ ] Ensure that the unique constraint is not violated during Bulk Loader runs.
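
For the first item above, this PR does not specify the exact audit query; as
a rough illustration only, a standard DQL `@groupby` aggregation can surface
values that occur more than once. A hedged sketch, reusing the `dg` client
from the example above (the helper name and `email` predicate are
hypothetical):

```go
import (
	"context"

	"github.com/dgraph-io/dgo/v210"
)

// findDuplicateEmails is a hypothetical audit helper, not part of this PR.
// It groups nodes by their email value and returns the per-value counts as
// JSON; any value with a count above 1 is a duplicate that would block
// adding @unique to the existing predicate.
func findDuplicateEmails(ctx context.Context, dg *dgo.Dgraph) ([]byte, error) {
	const q = `{
  dups(func: has(email)) @groupby(email) {
    count(uid)
  }
}`
	resp, err := dg.NewReadOnlyTxn().Query(ctx, q)
	if err != nil {
		return nil, err
	}
	return resp.GetJson(), nil
}
```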

---------

Co-authored-by: Aman Mangal <aman@dgraph.io>

347 of 347 new or added lines in 8 files covered. (100.0%)

58763 of 87560 relevant lines covered (67.11%)

2200726.47 hits per line

Source File

/dgraph/cmd/bulk/loader.go (54.97% of lines covered)
```go
/*
 * Copyright 2017-2023 Dgraph Labs, Inc. and Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package bulk

import (
	"bytes"
	"compress/gzip"
	"context"
	"encoding/json"
	"fmt"
	"hash/adler32"
	"io"
	"log"
	"math"
	"os"
	"path/filepath"
	"strconv"
	"sync"
	"time"

	"github.com/golang/glog"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"

	"github.com/dgraph-io/badger/v4"
	"github.com/dgraph-io/badger/v4/y"
	"github.com/dgraph-io/dgraph/chunker"
	"github.com/dgraph-io/dgraph/ee/enc"
	"github.com/dgraph-io/dgraph/filestore"
	gqlSchema "github.com/dgraph-io/dgraph/graphql/schema"
	"github.com/dgraph-io/dgraph/protos/pb"
	"github.com/dgraph-io/dgraph/schema"
	"github.com/dgraph-io/dgraph/x"
	"github.com/dgraph-io/dgraph/xidmap"
)

type options struct {
	DataFiles        string
	DataFormat       string
	SchemaFile       string
	GqlSchemaFile    string
	OutDir           string
	ReplaceOutDir    bool
	TmpDir           string
	NumGoroutines    int
	MapBufSize       uint64
	PartitionBufSize int64
	SkipMapPhase     bool
	CleanupTmp       bool
	NumReducers      int
	Version          bool
	StoreXids        bool
	ZeroAddr         string
	HttpAddr         string
	IgnoreErrors     bool
	CustomTokenizers string
	NewUids          bool
	ClientDir        string
	Encrypted        bool
	EncryptedOut     bool

	MapShards    int
	ReduceShards int

	Namespace uint64

	shardOutputDirs []string

	// ........... Badger options ..........
	// EncryptionKey is the key used for encryption. Enterprise only feature.
	EncryptionKey x.Sensitive
	// Badger options.
	Badger badger.Options
}

type state struct {
	opt           *options
	prog          *progress
	xids          *xidmap.XidMap
	schema        *schemaStore
	shards        *shardMap
	readerChunkCh chan *bytes.Buffer
	mapFileId     uint32 // Used atomically to name the output files of the mappers.
	dbs           []*badger.DB
	tmpDbs        []*badger.DB // Temporary DB to write the split lists to avoid ordering issues.
	writeTs       uint64       // All badger writes use this timestamp
	namespaces    *sync.Map    // To store the encountered namespaces.
}

type loader struct {
	*state
	mappers []*mapper
	zero    *grpc.ClientConn
}

func newLoader(opt *options) *loader {
	if opt == nil {
		log.Fatalf("Cannot create loader with nil options.")
	}

	fmt.Printf("Connecting to zero at %s\n", opt.ZeroAddr)
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	tlsConf, err := x.LoadClientTLSConfigForInternalPort(Bulk.Conf)
	x.Check(err)
	dialOpts := []grpc.DialOption{
		grpc.WithBlock(),
	}
	if tlsConf != nil {
		dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConf)))
	} else {
		dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	}
	zero, err := grpc.DialContext(ctx, opt.ZeroAddr, dialOpts...)
	x.Checkf(err, "Unable to connect to zero, Is it running at %s?", opt.ZeroAddr)
	st := &state{
		opt:    opt,
		prog:   newProgress(),
		shards: newShardMap(opt.MapShards),
		// Lots of gz readers, so not much channel buffer needed.
		readerChunkCh: make(chan *bytes.Buffer, opt.NumGoroutines),
		writeTs:       getWriteTimestamp(zero),
		namespaces:    &sync.Map{},
	}
	st.schema = newSchemaStore(readSchema(opt), opt, st)
	ld := &loader{
		state:   st,
		mappers: make([]*mapper, opt.NumGoroutines),
		zero:    zero,
	}
	for i := 0; i < opt.NumGoroutines; i++ {
		ld.mappers[i] = newMapper(st)
	}
	go ld.prog.report()
	return ld
}

func getWriteTimestamp(zero *grpc.ClientConn) uint64 {
	client := pb.NewZeroClient(zero)
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		ts, err := client.Timestamps(ctx, &pb.Num{Val: 1})
		cancel()
		if err == nil {
			return ts.GetStartId()
		}
		fmt.Printf("Error communicating with dgraph zero, retrying: %v", err)
		time.Sleep(time.Second)
	}
}

// leaseNamespaces is called at the end of the map phase. It leases the namespace ids till the
// maximum seen namespace id.
func (ld *loader) leaseNamespaces() {
	var maxNs uint64
	ld.namespaces.Range(func(key, value interface{}) bool {
		if ns := key.(uint64); ns > maxNs {
			maxNs = ns
		}
		return true
	})

	// If only the default namespace is seen, do nothing.
	if maxNs == 0 {
		return
	}

	client := pb.NewZeroClient(ld.zero)
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		ns, err := client.AssignIds(ctx, &pb.Num{Val: maxNs, Type: pb.Num_NS_ID})
		cancel()
		if err == nil {
			fmt.Printf("Assigned namespaces till %d\n", ns.GetEndId())
			return
		}
		fmt.Printf("Error communicating with dgraph zero, retrying: %v", err)
		time.Sleep(time.Second)
	}
}

func readSchema(opt *options) *schema.ParsedSchema {
	if opt.SchemaFile == "" {
		return genDQLSchema(opt)
	}

	f, err := filestore.Open(opt.SchemaFile)
	x.Check(err)
	defer func() {
		if err := f.Close(); err != nil {
			glog.Warningf("error while closing fd: %v", err)
		}
	}()

	key := opt.EncryptionKey
	if !opt.Encrypted {
		key = nil
	}
	r, err := enc.GetReader(key, f)
	x.Check(err)
	if filepath.Ext(opt.SchemaFile) == ".gz" {
		r, err = gzip.NewReader(r)
		x.Check(err)
	}

	buf, err := io.ReadAll(r)
	x.Check(err)

	result, err := schema.ParseWithNamespace(string(buf), opt.Namespace)
	x.Check(err)
	return result
}

func genDQLSchema(opt *options) *schema.ParsedSchema {
	gqlSchBytes := readGqlSchema(opt)
	nsToSchemas := parseGqlSchema(string(gqlSchBytes))

	var finalSch schema.ParsedSchema
	for ns, gqlSch := range nsToSchemas {
		if opt.Namespace != math.MaxUint64 {
			ns = opt.Namespace
		}

		h, err := gqlSchema.NewHandler(gqlSch, false)
		x.Check(err)

		_, err = gqlSchema.FromString(h.GQLSchema(), ns)
		x.Check(err)

		ps, err := schema.ParseWithNamespace(h.DGSchema(), ns)
		x.Check(err)

		finalSch.Preds = append(finalSch.Preds, ps.Preds...)
		finalSch.Types = append(finalSch.Types, ps.Types...)
	}

	return &finalSch
}

func (ld *loader) mapStage() {
	ld.prog.setPhase(mapPhase)
	var db *badger.DB
	if len(ld.opt.ClientDir) > 0 {
		x.Check(os.MkdirAll(ld.opt.ClientDir, 0700))

		var err error
		db, err = badger.Open(badger.DefaultOptions(ld.opt.ClientDir))
		x.Checkf(err, "Error while creating badger KV posting store")
	}
	ld.xids = xidmap.New(xidmap.XidMapOptions{
		UidAssigner: ld.zero,
		DB:          db,
		Dir:         filepath.Join(ld.opt.TmpDir, bufferDir),
	})

	fs := filestore.NewFileStore(ld.opt.DataFiles)

	files := fs.FindDataFiles(ld.opt.DataFiles, []string{".rdf", ".rdf.gz", ".json", ".json.gz"})
	if len(files) == 0 {
		fmt.Printf("No data files found in %s.\n", ld.opt.DataFiles)
		os.Exit(1)
	}

	// Because mappers must handle chunks that may be from different input files, they must all
	// assume the same data format, either RDF or JSON. Use the one specified by the user or by
	// the first load file.
	loadType := chunker.DataFormat(files[0], ld.opt.DataFormat)
	if loadType == chunker.UnknownFormat {
		// Don't try to detect JSON input in bulk loader.
		fmt.Printf("Need --format=rdf or --format=json to load %s", files[0])
		os.Exit(1)
	}

	var mapperWg sync.WaitGroup
	mapperWg.Add(len(ld.mappers))
	for _, m := range ld.mappers {
		go func(m *mapper) {
			m.run(loadType)
			mapperWg.Done()
		}(m)
	}

	// This is the main map loop.
	thr := y.NewThrottle(ld.opt.NumGoroutines)
	for i, file := range files {
		x.Check(thr.Do())
		fmt.Printf("Processing file (%d out of %d): %s\n", i+1, len(files), file)

		go func(file string) {
			defer thr.Done(nil)

			key := ld.opt.EncryptionKey
			if !ld.opt.Encrypted {
				key = nil
			}
			r, cleanup := fs.ChunkReader(file, key)
			defer cleanup()

			chunk := chunker.NewChunker(loadType, 1000)
			for {
				chunkBuf, err := chunk.Chunk(r)
				if chunkBuf != nil && chunkBuf.Len() > 0 {
					ld.readerChunkCh <- chunkBuf
				}
				if err == io.EOF {
					break
				} else if err != nil {
					x.Check(err)
				}
			}
		}(file)
	}
	x.Check(thr.Finish())

	// Send the graphql triples
	ld.processGqlSchema(loadType)

	close(ld.readerChunkCh)
	mapperWg.Wait()

	// Allow memory to GC before the reduce phase.
	for i := range ld.mappers {
		ld.mappers[i] = nil
	}
	x.Check(ld.xids.Flush())
	if db != nil {
		x.Check(db.Close())
	}
	ld.xids = nil
}

func parseGqlSchema(s string) map[uint64]string {
	var schemas []x.ExportedGQLSchema
	if err := json.Unmarshal([]byte(s), &schemas); err != nil {
		fmt.Println("Error while decoding the graphql schema. Assuming it to be in format < 21.03.")
		return map[uint64]string{x.GalaxyNamespace: s}
	}

	schemaMap := make(map[uint64]string)
	for _, schema := range schemas {
		if _, ok := schemaMap[schema.Namespace]; ok {
			fmt.Printf("Found multiple GraphQL schema for namespace %d.", schema.Namespace)
			continue
		}
		schemaMap[schema.Namespace] = schema.Schema
	}
	return schemaMap
}

func readGqlSchema(opt *options) []byte {
	f, err := filestore.Open(opt.GqlSchemaFile)
	x.Check(err)
	defer func() {
		if err := f.Close(); err != nil {
			glog.Warningf("error while closing fd: %v", err)
		}
	}()

	key := opt.EncryptionKey
	if !opt.Encrypted {
		key = nil
	}
	r, err := enc.GetReader(key, f)
	x.Check(err)
	if filepath.Ext(opt.GqlSchemaFile) == ".gz" {
		r, err = gzip.NewReader(r)
		x.Check(err)
	}

	buf, err := io.ReadAll(r)
	x.Check(err)
	return buf
}

func (ld *loader) processGqlSchema(loadType chunker.InputFormat) {
	if ld.opt.GqlSchemaFile == "" {
		return
	}

	rdfSchema := `_:gqlschema <dgraph.type> "dgraph.graphql" <%#x> .
	_:gqlschema <dgraph.graphql.xid> "dgraph.graphql.schema" <%#x> .
	_:gqlschema <dgraph.graphql.schema> %s <%#x> .
	`

	jsonSchema := `{
		"namespace": "%#x",
		"dgraph.type": "dgraph.graphql",
		"dgraph.graphql.xid": "dgraph.graphql.schema",
		"dgraph.graphql.schema": %s
	}`

	process := func(ns uint64, schema string) {
		// Ignore the schema if the namespace is not already seen.
		if _, ok := ld.schema.namespaces.Load(ns); !ok {
			fmt.Printf("No data exist for namespace: %d. Cannot load the graphql schema.", ns)
			return
		}
		gqlBuf := &bytes.Buffer{}
		schema = strconv.Quote(schema)
		switch loadType {
		case chunker.RdfFormat:
			x.Check2(gqlBuf.Write([]byte(fmt.Sprintf(rdfSchema, ns, ns, schema, ns))))
		case chunker.JsonFormat:
			x.Check2(gqlBuf.Write([]byte(fmt.Sprintf(jsonSchema, ns, schema))))
		}
		ld.readerChunkCh <- gqlBuf
	}

	buf := readGqlSchema(ld.opt)
	schemas := parseGqlSchema(string(buf))
	if ld.opt.Namespace == math.MaxUint64 {
		// Preserve the namespace.
		for ns, schema := range schemas {
			process(ns, schema)
		}
		return
	}

	switch len(schemas) {
	case 1:
		// User might have exported from a different namespace. So, schema.Namespace will not be
		// having the correct value.
		for _, schema := range schemas {
			process(ld.opt.Namespace, schema)
		}
	default:
		if _, ok := schemas[ld.opt.Namespace]; !ok {
			// We expect only a single GraphQL schema when loading into specific namespace.
			fmt.Printf("Didn't find GraphQL schema for namespace %d. Not loading GraphQL schema.",
				ld.opt.Namespace)
			return
		}
		process(ld.opt.Namespace, schemas[ld.opt.Namespace])
	}
}

func (ld *loader) reduceStage() {
	ld.prog.setPhase(reducePhase)

	r := reducer{
		state:     ld.state,
		streamIds: make(map[string]uint32),
	}
	x.Check(r.run())
}

func (ld *loader) writeSchema() {
	numDBs := uint32(len(ld.dbs))
	preds := make([][]string, numDBs)

	// Get all predicates that have data in some DB.
	m := make(map[string]struct{})
	for i, db := range ld.dbs {
		preds[i] = ld.schema.getPredicates(db)
		for _, p := range preds[i] {
			m[p] = struct{}{}
		}
	}

	// Find any predicates that don't have data in any DB
	// and distribute them among all the DBs.
	for p := range ld.schema.schemaMap {
		if _, ok := m[p]; !ok {
			i := adler32.Checksum([]byte(p)) % numDBs
			preds[i] = append(preds[i], p)
		}
	}

	// Write out each DB's final predicate list.
	for i, db := range ld.dbs {
		ld.schema.write(db, preds[i])
	}
}

func (ld *loader) cleanup() {
	for _, db := range ld.dbs {
		x.Check(db.Close())
	}
	for _, db := range ld.tmpDbs {
		opts := db.Opts()
		x.Check(db.Close())
		x.Check(os.RemoveAll(opts.Dir))
	}
	ld.prog.endSummary()
}
```