• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tarantool / go-tarantool / 6933294594

20 Nov 2023 04:56PM UTC coverage: 79.76% (+0.08%) from 79.68%
6933294594

Pull #349

github

askalt
docs: update according to the changes
Pull Request #349: connection: support connection via an existing socket fd

255 of 309 new or added lines in 7 files covered. (82.52%)

24 existing lines in 3 files now uncovered.

5647 of 7080 relevant lines covered (79.76%)

11113.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.6
/dial.go
1
package tarantool
2

3
import (
4
        "bufio"
5
        "bytes"
6
        "context"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "net"
11
        "os"
12
        "strings"
13
        "time"
14

15
        "github.com/tarantool/go-iproto"
16
        "github.com/vmihailenco/msgpack/v5"
17
)
18

19
// bufSize is the size in bytes (128 KiB) of the bufio reader and writer that
// the dialers wrap around each freshly established connection.
const bufSize = 128 << 10
20

21
// Greeting is a message sent by Tarantool on connect.
//
// Only the version part is kept here; the authentication salt read from the
// same greeting is handed to the dialer separately and is not stored.
type Greeting struct {
	// Version holds the first 64 bytes of the greeting exactly as read
	// (no trimming is applied).
	Version string
}
25

26
// writeFlusher is the interface that groups the basic Write and Flush methods.
//
// It is the sink used to send handshake and authentication requests; both the
// *bufio.Writer created by the dialers and a Conn satisfy it.
type writeFlusher interface {
	io.Writer
	Flush() error
}
31

32
// Conn is a generic stream-oriented network connection to a Tarantool
// instance.
//
// A Conn returned by the dialers in this file has already exchanged the
// greeting and protocol information with the server (and authenticated, when
// the dialer does so).
type Conn interface {
	// Read reads data from the connection.
	Read(b []byte) (int, error)
	// Write writes data to the connection. Writes may be buffered for
	// better performance control from a client side; use Flush to send
	// buffered data.
	Write(b []byte) (int, error)
	// Flush writes any buffered data.
	Flush() error
	// Close closes the connection.
	// Any blocked Read or Flush operations will be unblocked and return
	// errors.
	Close() error
	// Greeting returns server greeting.
	Greeting() Greeting
	// ProtocolInfo returns server protocol info.
	ProtocolInfo() ProtocolInfo
	// Addr returns the connection address.
	Addr() net.Addr
}
53

54
// DialOpts is a way to configure a Dial method to create a new Conn.
type DialOpts struct {
	// IoTimeout is a timeout per a network read/write.
	//
	// The dialers in this file hand it to deadlineIO, which wraps the raw
	// connection; how a zero value is interpreted is decided there.
	IoTimeout time.Duration
}
59

60
// Dialer is the interface that wraps a method to connect to a Tarantool
// instance. The main idea is to provide a ready-to-work connection with
// basic preparation, successful authentication and additional checks.
//
// You can provide your own implementation to Connect() call if
// some functionality is not implemented in the connector. See NetDialer.Dial()
// implementation as example.
type Dialer interface {
	// Dial connects to a Tarantool instance with the specified options.
	Dial(ctx context.Context, opts DialOpts) (Conn, error)
}
72

73
// tntConn is the Conn implementation produced by the dialers in this file.
type tntConn struct {
	net      net.Conn     // underlying connection; Close and Addr delegate to it
	reader   io.Reader    // buffered reader over net, used by Read
	writer   writeFlusher // buffered writer over net, used by Write and Flush
	greeting Greeting     // filled in from the server greeting by rawDial
	protocol ProtocolInfo // server protocol info; Auth may be set by the dialer
}
80

81
// rawDial does basic dial operations:
82
// reads greeting, identifies a protocol and validates it.
83
func rawDial(conn *tntConn, requiredProto ProtocolInfo) (string, error) {
710✔
84
        version, salt, err := readGreeting(conn.reader)
710✔
85
        if err != nil {
716✔
86
                return "", fmt.Errorf("failed to read greeting: %w", err)
6✔
87
        }
6✔
88
        conn.greeting.Version = version
704✔
89

704✔
90
        if conn.protocol, err = identify(conn.writer, conn.reader); err != nil {
710✔
91
                return "", fmt.Errorf("failed to identify: %w", err)
6✔
92
        }
6✔
93

94
        if err = checkProtocolInfo(requiredProto, conn.protocol); err != nil {
702✔
95
                return "", fmt.Errorf("invalid server protocol: %w", err)
4✔
96
        }
4✔
97
        return salt, err
694✔
98
}
99

100
// NetDialer is a basic Dialer implementation over a plain TCP or Unix socket.
//
// When User is set, Dial always authenticates with ChapSha1Auth.
type NetDialer struct {
	// Address is an address to connect.
	// It could be specified in following ways:
	//
	// - TCP connections (tcp://192.168.1.1:3013, tcp://my.host:3013,
	// tcp:192.168.1.1:3013, tcp:my.host:3013, 192.168.1.1:3013, my.host:3013)
	//
	// - Unix socket, first '/' or '.' indicates Unix socket
	// (unix:///abs/path/tnt.sock, unix:path/tnt.sock, /abs/path/tnt.sock,
	// ./rel/path/tnt.sock, unix/:path/tnt.sock)
	Address string
	// User is the user name for logging in to Tarantool. If it is empty,
	// Dial skips authentication.
	User string
	// Password is the user password for logging in to Tarantool.
	Password string
	// RequiredProtocolInfo contains minimal protocol version and
	// list of protocol features that should be supported by
	// Tarantool server. By default, there are no restrictions.
	RequiredProtocolInfo ProtocolInfo
}
121

122
// Dial makes NetDialer satisfy the Dialer interface.
//
// It connects to d.Address over TCP or a Unix socket and runs the handshake.
// If d.User is set, it then authenticates with CHAP-SHA1. ctx only bounds
// the socket connect. Handshake I/O afterwards goes through deadlineIO with
// opts.IoTimeout. The socket is closed whenever a step after the connect
// fails.
func (d NetDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
	network, address := parseAddress(d.Address)

	var netDialer net.Dialer
	raw, err := netDialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	deadlined := &deadlineIO{to: opts.IoTimeout, c: raw}
	conn := &tntConn{
		net:    raw,
		reader: bufio.NewReaderSize(deadlined, bufSize),
		writer: bufio.NewWriterSize(deadlined, bufSize),
	}

	salt, err := rawDial(conn, d.RequiredProtocolInfo)
	if err != nil {
		conn.net.Close()
		return nil, err
	}

	// Without a user there is nothing to log in with: hand the connection
	// back as negotiated.
	if d.User == "" {
		return conn, nil
	}

	conn.protocol.Auth = ChapSha1Auth
	err = authenticate(conn, ChapSha1Auth, d.User, d.Password, salt)
	if err != nil {
		conn.net.Close()
		return nil, fmt.Errorf("failed to authenticate: %w", err)
	}
	return conn, nil
}
156

157
// OpenSslDialer allows to use SSL transport for connection.
type OpenSslDialer struct {
	// Address is an address to connect.
	// It could be specified in following ways:
	//
	// - TCP connections (tcp://192.168.1.1:3013, tcp://my.host:3013,
	// tcp:192.168.1.1:3013, tcp:my.host:3013, 192.168.1.1:3013, my.host:3013)
	//
	// - Unix socket, first '/' or '.' indicates Unix socket
	// (unix:///abs/path/tnt.sock, unix:path/tnt.sock, /abs/path/tnt.sock,
	// ./rel/path/tnt.sock, unix/:path/tnt.sock)
	Address string
	// Auth is an authentication method. If it is AutoAuth, Dial uses the
	// method reported by the server, or ChapSha1Auth if the server reports
	// AutoAuth as well.
	Auth Auth
	// User is the user name for logging in to Tarantool. If it is empty,
	// Dial skips authentication.
	User string
	// Password is the user password for logging in to Tarantool.
	Password string
	// RequiredProtocolInfo contains minimal protocol version and
	// list of protocol features that should be supported by
	// Tarantool server. By default, there are no restrictions.
	RequiredProtocolInfo ProtocolInfo
	// SslKeyFile is a path to a private SSL key file.
	SslKeyFile string
	// SslCertFile is a path to an SSL certificate file.
	SslCertFile string
	// SslCaFile is a path to a trusted certificate authorities (CA) file.
	SslCaFile string
	// SslCiphers is a colon-separated (:) list of SSL cipher suites the connection
	// can use.
	//
	// We don't provide a list of supported ciphers. This is what OpenSSL
	// does. The only limitation is usage of TLSv1.2 (because other protocol
	// versions don't seem to support the GOST cipher). To add additional
	// ciphers (GOST cipher), you must configure OpenSSL.
	//
	// See also
	//
	// * https://www.openssl.org/docs/man1.1.1/man1/ciphers.html
	SslCiphers string
	// SslPassword is a password for decrypting the private SSL key file.
	// The priority is as follows: try to decrypt with SslPassword, then
	// try SslPasswordFile.
	SslPassword string
	// SslPasswordFile is a path to the list of passwords for decrypting
	// the private SSL key file. The connection tries every line from the
	// file as a password.
	SslPasswordFile string
}
206

207
// Dial makes OpenSslDialer satisfy the Dialer interface.
//
// It opens an SSL connection to d.Address using the configured key,
// certificate, CA, cipher and password settings, and runs the handshake. If
// d.User is set, it then authenticates.
//
// With d.Auth == AutoAuth, the method reported by the server is used, or
// ChapSha1Auth if the server reported AutoAuth too. The socket is closed
// whenever a step after the connect fails.
func (d OpenSslDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
	var (
		raw net.Conn
		err error
	)

	network, address := parseAddress(d.Address)
	raw, err = sslDialContext(ctx, network, address, sslOpts{
		KeyFile:      d.SslKeyFile,
		CertFile:     d.SslCertFile,
		CaFile:       d.SslCaFile,
		Ciphers:      d.SslCiphers,
		Password:     d.SslPassword,
		PasswordFile: d.SslPasswordFile,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	deadlined := &deadlineIO{to: opts.IoTimeout, c: raw}
	conn := &tntConn{
		net:    raw,
		reader: bufio.NewReaderSize(deadlined, bufSize),
		writer: bufio.NewWriterSize(deadlined, bufSize),
	}

	salt, err := rawDial(conn, d.RequiredProtocolInfo)
	if err != nil {
		conn.net.Close()
		return nil, err
	}

	if d.User == "" {
		return conn, nil
	}

	method := d.Auth
	if method == AutoAuth {
		// Prefer whatever the server advertised during identification.
		method = conn.protocol.Auth
		if method == AutoAuth {
			method = ChapSha1Auth
		}
	}
	conn.protocol.Auth = method

	if err = authenticate(conn, method, d.User, d.Password, salt); err != nil {
		conn.net.Close()
		return nil, fmt.Errorf("failed to authenticate: %w", err)
	}
	return conn, nil
}
255

256
// FdDialer allows to use an existing socket fd for connection.
//
// Dial only performs the handshake (greeting and protocol checks); it does
// not authenticate.
type FdDialer struct {
	// Fd is a socket file descriptor.
	//
	// NOTE(review): Dial wraps Fd with os.NewFile and never closes that
	// wrapper, while the connection itself is built by net.FileConn. Whether
	// the caller keeps ownership of Fd after Dial (and across repeated Dial
	// calls) is not settled by this code. Confirm before relying on it.
	Fd uintptr
	// RequiredProtocolInfo contains minimal protocol version and
	// list of protocol features that should be supported by
	// Tarantool server. By default, there are no restrictions.
	RequiredProtocolInfo ProtocolInfo
}
265

266
type fdAddr struct {
267
        Fd uintptr
268
}
269

NEW
270
func (a fdAddr) Network() string {
×
NEW
271
        return "fd"
×
NEW
272
}
×
273

NEW
274
func (a fdAddr) String() string {
×
NEW
275
        return fmt.Sprintf("fd://%d", a.Fd)
×
NEW
276
}
×
277

278
// fdConn is a net.Conn built from a file descriptor. It overrides RemoteAddr
// so that the descriptor, not the peer address, is reported.
type fdConn struct {
	net.Conn
	// Addr is the value returned by RemoteAddr.
	Addr fdAddr
}
282

283
func (c *fdConn) RemoteAddr() net.Addr {
2✔
284
        return c.Addr
2✔
285
}
2✔
286

287
// Dial makes FdDialer satisfy the Dialer interface.
288
func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
10✔
289
        file := os.NewFile(d.Fd, "")
10✔
290
        c, err := net.FileConn(file)
10✔
291
        if err != nil {
10✔
NEW
292
                return nil, fmt.Errorf("failed to dial: %w", err)
×
NEW
293
        }
×
294

295
        conn := new(tntConn)
10✔
296
        conn.net = &fdConn{Conn: c, Addr: fdAddr{Fd: d.Fd}}
10✔
297

10✔
298
        dc := &deadlineIO{to: opts.IoTimeout, c: conn.net}
10✔
299
        conn.reader = bufio.NewReaderSize(dc, bufSize)
10✔
300
        conn.writer = bufio.NewWriterSize(dc, bufSize)
10✔
301

10✔
302
        _, err = rawDial(conn, d.RequiredProtocolInfo)
10✔
303
        if err != nil {
14✔
304
                conn.net.Close()
4✔
305
                return nil, err
4✔
306
        }
4✔
307

308
        return conn, nil
6✔
309
}
310

311
// Addr makes tntConn satisfy the Conn interface.
312
func (c *tntConn) Addr() net.Addr {
646✔
313
        return c.net.RemoteAddr()
646✔
314
}
646✔
315

316
// Read makes tntConn satisfy the Conn interface.
317
func (c *tntConn) Read(p []byte) (int, error) {
8,050✔
318
        return c.reader.Read(p)
8,050✔
319
}
8,050✔
320

321
// Write makes tntConn satisfy the Conn interface.
322
func (c *tntConn) Write(p []byte) (int, error) {
4,285✔
323
        if l, err := c.writer.Write(p); err != nil {
4,285✔
324
                return l, err
×
325
        } else if l != len(p) {
4,285✔
326
                return l, errors.New("wrong length written")
×
327
        } else {
4,285✔
328
                return l, nil
4,285✔
329
        }
4,285✔
330
}
331

332
// Flush makes tntConn satisfy the Conn interface.
333
func (c *tntConn) Flush() error {
3,728✔
334
        return c.writer.Flush()
3,728✔
335
}
3,728✔
336

337
// Close makes tntConn satisfy the Conn interface.
338
func (c *tntConn) Close() error {
686✔
339
        return c.net.Close()
686✔
340
}
686✔
341

342
// Greeting makes tntConn satisfy the Conn interface.
343
func (c *tntConn) Greeting() Greeting {
690✔
344
        return c.greeting
690✔
345
}
690✔
346

347
// ProtocolInfo makes tntConn satisfy the Conn interface.
348
func (c *tntConn) ProtocolInfo() ProtocolInfo {
2,816✔
349
        return c.protocol
2,816✔
350
}
2,816✔
351

352
// parseAddress split address into network and address parts.
353
func parseAddress(address string) (string, string) {
866✔
354
        network := "tcp"
866✔
355
        addrLen := len(address)
866✔
356

866✔
357
        if addrLen > 0 && (address[0] == '.' || address[0] == '/') {
866✔
358
                network = "unix"
×
359
        } else if addrLen >= 7 && address[0:7] == "unix://" {
866✔
360
                network = "unix"
×
361
                address = address[7:]
×
362
        } else if addrLen >= 5 && address[0:5] == "unix:" {
866✔
363
                network = "unix"
×
364
                address = address[5:]
×
365
        } else if addrLen >= 6 && address[0:6] == "unix/:" {
866✔
366
                network = "unix"
×
367
                address = address[6:]
×
368
        } else if addrLen >= 6 && address[0:6] == "tcp://" {
866✔
369
                address = address[6:]
×
370
        } else if addrLen >= 4 && address[0:4] == "tcp:" {
866✔
371
                address = address[4:]
×
372
        }
×
373

374
        return network, address
866✔
375
}
376

377
// readGreeting reads the fixed 128-byte greeting from reader.
//
// It returns the version (the first 64 bytes, untrimmed) and the salt (the
// 44 bytes that open the second half). If the full greeting cannot be read,
// both strings are empty and the error from io.ReadFull is returned.
func readGreeting(reader io.Reader) (string, string, error) {
	const (
		greetingSize = 128
		versionEnd   = 64
		saltEnd      = versionEnd + 44
	)

	var greeting [greetingSize]byte
	if _, err := io.ReadFull(reader, greeting[:]); err != nil {
		return "", "", err
	}

	version := bytes.NewBuffer(greeting[:versionEnd]).String()
	salt := bytes.NewBuffer(greeting[versionEnd:saltEnd]).String()
	return version, salt, nil
}
390

391
// identify sends info about client protocol, receives info
392
// about server protocol in response and stores it in the connection.
393
func identify(w writeFlusher, r io.Reader) (ProtocolInfo, error) {
704✔
394
        var info ProtocolInfo
704✔
395

704✔
396
        req := NewIdRequest(clientProtocolInfo)
704✔
397
        if err := writeRequest(w, req); err != nil {
704✔
398
                return info, err
×
399
        }
×
400

401
        resp, err := readResponse(r)
704✔
402
        if err != nil {
716✔
403
                if iproto.Error(resp.Code) == iproto.ER_UNKNOWN_REQUEST_TYPE {
18✔
404
                        // IPROTO_ID requests are not supported by server.
6✔
405
                        return info, nil
6✔
406
                }
6✔
407

408
                return info, err
6✔
409
        }
410

411
        if len(resp.Data) == 0 {
692✔
412
                return info, errors.New("unexpected response: no data")
×
413
        }
×
414

415
        info, ok := resp.Data[0].(ProtocolInfo)
692✔
416
        if !ok {
692✔
417
                return info, errors.New("unexpected response: wrong data")
×
418
        }
×
419

420
        return info, nil
692✔
421
}
422

423
// checkProtocolInfo verifies that actual satisfies required. The server
// version must be at least required.Version, and every feature listed in
// required.Features must appear in actual.Features. The returned error names
// every missing feature.
func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
	if actual.Version < required.Version {
		return fmt.Errorf("protocol version %d is not supported",
			required.Version)
	}

	// A plain linear scan on purpose: feature lists are tiny, so building a
	// map would cost more than it saves
	// (https://stackoverflow.com/a/52710077/11646599).
	var missed []string
	for _, want := range required.Features {
		supported := false
		for _, have := range actual.Features {
			if have == want {
				supported = true
				break
			}
		}
		if !supported {
			missed = append(missed, want.String())
		}
	}

	if len(missed) == 0 {
		return nil
	}
	if len(missed) == 1 {
		return fmt.Errorf("protocol feature %s is not supported", missed[0])
	}
	return fmt.Errorf("protocol features %s are not supported",
		strings.Join(missed, ", "))
}
456

457
// authenticate logs user in over c with the given method and waits for the
// server's reply.
//
// salt is the value from the server greeting; only ChapSha1Auth uses it.
// Methods other than ChapSha1Auth and PapSha256Auth are rejected before
// anything is sent.
func authenticate(c Conn, auth Auth, user string, pass string, salt string) error {
	var (
		req Request
		err error
	)

	switch auth {
	case ChapSha1Auth:
		if req, err = newChapSha1AuthRequest(user, pass, salt); err != nil {
			return err
		}
	case PapSha256Auth:
		req = newPapSha256AuthRequest(user, pass)
	default:
		return errors.New("unsupported method " + auth.String())
	}

	if err := writeRequest(c, req); err != nil {
		return err
	}
	_, err = readResponse(c)
	return err
}
482

483
// writeRequest encodes req into a single packet, writes it to w and flushes.
// 0 is passed as pack's request-number argument and ignoreStreamId as the
// stream argument (presumably "no stream"; confirm against pack).
//
// Each failure is wrapped with the stage that failed: pack, write or flush.
func writeRequest(w writeFlusher, req Request) error {
	var packet smallWBuf
	enc := msgpack.NewEncoder(&packet)
	if err := pack(&packet, enc, 0, req, ignoreStreamId, nil); err != nil {
		return fmt.Errorf("pack error: %w", err)
	}
	if _, err := w.Write(packet.b); err != nil {
		return fmt.Errorf("write error: %w", err)
	}
	if err := w.Flush(); err != nil {
		return fmt.Errorf("flush error: %w", err)
	}
	return nil
}
499

500
// readResponse reads one packet from r (using a packetLengthBytes scratch
// buffer for its length) and decodes it into a Response.
//
// Errors of type Error produced while decoding the body are returned as is,
// together with the decoded response, so callers can inspect resp.Code.
// Other failures are wrapped with the stage that failed.
func readResponse(r io.Reader) (Response, error) {
	var lenbuf [packetLengthBytes]byte
	packet, err := read(r, lenbuf[:])
	if err != nil {
		return Response{}, fmt.Errorf("read error: %w", err)
	}

	resp := Response{buf: smallBuf{b: packet}}
	if err = resp.decodeHeader(msgpack.NewDecoder(&smallBuf{})); err != nil {
		return resp, fmt.Errorf("decode response header error: %w", err)
	}

	if err = resp.decodeBody(); err != nil {
		if _, isTntErr := err.(Error); isTntErr {
			return resp, err
		}
		return resp, fmt.Errorf("decode response body error: %w", err)
	}
	return resp, nil
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc