• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tarantool / go-tarantool / 13930913017

18 Mar 2025 06:33PM UTC coverage: 75.939% (+0.08%) from 75.863%
13930913017

Pull #435

github

maksim.konovalov
pool: Pooler interface supports GetInfo method in TopologyEditor
Pull Request #435: pool: Pooler interface supports GetInfo method in TopologyEditor

2992 of 3940 relevant lines covered (75.94%)

9863.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.07
/dial.go
1
package tarantool
2

3
import (
4
        "bufio"
5
        "bytes"
6
        "context"
7
        "errors"
8
        "fmt"
9
        "io"
10
        "net"
11
        "os"
12
        "strings"
13
        "time"
14

15
        "github.com/tarantool/go-iproto"
16
        "github.com/vmihailenco/msgpack/v5"
17
)
18

19
// bufSize is the size, in bytes, of the bufio reader and writer that wrap
// every connection created by the dialers in this file (128 KiB).
const bufSize = 128 << 10
20

21
// Greeting holds the data of the greeting message that a Tarantool instance
// sends right after a connection is established.
type Greeting struct {
	// Version is the version part of the greeting message.
	Version string
	// Salt is the server-provided salt used to authenticate a user.
	Salt string
}
28

29
// writeFlusher is the interface that groups the basic Write and Flush methods.
30
type writeFlusher interface {
31
        io.Writer
32
        Flush() error
33
}
34

35
// Conn is a generic stream-oriented network connection to a Tarantool
36
// instance.
37
type Conn interface {
38
        // Read reads data from the connection.
39
        Read(b []byte) (int, error)
40
        // Write writes data to the connection. There may be an internal buffer for
41
        // better performance control from a client side.
42
        Write(b []byte) (int, error)
43
        // Flush writes any buffered data.
44
        Flush() error
45
        // Close closes the connection.
46
        // Any blocked Read or Flush operations will be unblocked and return
47
        // errors.
48
        Close() error
49
        // Greeting returns server greeting.
50
        Greeting() Greeting
51
        // ProtocolInfo returns server protocol info.
52
        ProtocolInfo() ProtocolInfo
53
        // Addr returns the connection address.
54
        Addr() net.Addr
55
}
56

57
// DialOpts is a way to configure a Dial method to create a new Conn.
58
type DialOpts struct {
59
        // IoTimeout is a timeout per a network read/write.
60
        IoTimeout time.Duration
61
}
62

63
// Dialer is the interface that wraps a method to connect to a Tarantool
64
// instance. The main idea is to provide a ready-to-work connection with
65
// basic preparation, successful authorization and additional checks.
66
//
67
// You can provide your own implementation to Connect() call if
68
// some functionality is not implemented in the connector. See NetDialer.Dial()
69
// implementation as example.
70
type Dialer interface {
71
        // Dial connects to a Tarantool instance to the address with specified
72
        // options.
73
        Dial(ctx context.Context, opts DialOpts) (Conn, error)
74
}
75

76
type tntConn struct {
77
        net    net.Conn
78
        reader io.Reader
79
        writer writeFlusher
80
}
81

82
// Addr makes tntConn satisfy the Conn interface.
83
func (c *tntConn) Addr() net.Addr {
311✔
84
        return c.net.RemoteAddr()
311✔
85
}
311✔
86

87
// Read makes tntConn satisfy the Conn interface.
88
func (c *tntConn) Read(p []byte) (int, error) {
4,976✔
89
        return c.reader.Read(p)
4,976✔
90
}
4,976✔
91

92
// Write makes tntConn satisfy the Conn interface.
93
func (c *tntConn) Write(p []byte) (int, error) {
2,445✔
94
        if l, err := c.writer.Write(p); err != nil {
2,445✔
95
                return l, err
×
96
        } else if l != len(p) {
2,445✔
97
                return l, errors.New("wrong length written")
×
98
        } else {
2,445✔
99
                return l, nil
2,445✔
100
        }
2,445✔
101
}
102

103
// Flush makes tntConn satisfy the Conn interface.
104
func (c *tntConn) Flush() error {
2,173✔
105
        return c.writer.Flush()
2,173✔
106
}
2,173✔
107

108
// Close makes tntConn satisfy the Conn interface.
109
func (c *tntConn) Close() error {
322✔
110
        return c.net.Close()
322✔
111
}
322✔
112

113
// Greeting makes tntConn satisfy the Conn interface.
114
func (c *tntConn) Greeting() Greeting {
×
115
        return Greeting{}
×
116
}
×
117

118
// ProtocolInfo makes tntConn satisfy the Conn interface.
119
func (c *tntConn) ProtocolInfo() ProtocolInfo {
×
120
        return ProtocolInfo{}
×
121
}
×
122

123
// protocolConn is a wrapper for connections, so they contain the ProtocolInfo.
124
type protocolConn struct {
125
        Conn
126
        protocolInfo ProtocolInfo
127
}
128

129
// ProtocolInfo returns ProtocolInfo of a protocolConn.
130
func (c *protocolConn) ProtocolInfo() ProtocolInfo {
1,695✔
131
        return c.protocolInfo
1,695✔
132
}
1,695✔
133

134
// greetingConn is a wrapper for connections, so they contain the Greeting.
135
type greetingConn struct {
136
        Conn
137
        greeting Greeting
138
}
139

140
// Greeting returns Greeting of a greetingConn.
141
func (c *greetingConn) Greeting() Greeting {
953✔
142
        return c.greeting
953✔
143
}
953✔
144

145
type netDialer struct {
146
        address string
147
}
148

149
func (d netDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
392✔
150
        var err error
392✔
151
        conn := new(tntConn)
392✔
152

392✔
153
        network, address := parseAddress(d.address)
392✔
154
        dialer := net.Dialer{}
392✔
155
        conn.net, err = dialer.DialContext(ctx, network, address)
392✔
156
        if err != nil {
466✔
157
                return nil, fmt.Errorf("failed to dial: %w", err)
74✔
158
        }
74✔
159

160
        dc := &deadlineIO{to: opts.IoTimeout, c: conn.net}
318✔
161
        conn.reader = bufio.NewReaderSize(dc, bufSize)
318✔
162
        conn.writer = bufio.NewWriterSize(dc, bufSize)
318✔
163

318✔
164
        return conn, nil
318✔
165
}
166

167
// NetDialer is a basic Dialer implementation.
168
type NetDialer struct {
169
        // Address is an address to connect.
170
        // It could be specified in following ways:
171
        //
172
        // - TCP connections (tcp://192.168.1.1:3013, tcp://my.host:3013,
173
        // tcp:192.168.1.1:3013, tcp:my.host:3013, 192.168.1.1:3013, my.host:3013)
174
        //
175
        // - Unix socket, first '/' or '.' indicates Unix socket
176
        // (unix:///abs/path/tnt.sock, unix:path/tnt.sock, /abs/path/tnt.sock,
177
        // ./rel/path/tnt.sock, unix/:path/tnt.sock)
178
        Address string
179
        // Username for logging in to Tarantool.
180
        User string
181
        // User password for logging in to Tarantool.
182
        Password string
183
        // RequiredProtocol contains minimal protocol version and
184
        // list of protocol features that should be supported by
185
        // Tarantool server. By default, there are no restrictions.
186
        RequiredProtocolInfo ProtocolInfo
187
}
188

189
// Dial makes NetDialer satisfy the Dialer interface.
190
func (d NetDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
392✔
191
        dialer := AuthDialer{
392✔
192
                Dialer: ProtocolDialer{
392✔
193
                        Dialer: GreetingDialer{
392✔
194
                                Dialer: netDialer{
392✔
195
                                        address: d.Address,
392✔
196
                                },
392✔
197
                        },
392✔
198
                        RequiredProtocolInfo: d.RequiredProtocolInfo,
392✔
199
                },
392✔
200
                Auth:     ChapSha1Auth,
392✔
201
                Username: d.User,
392✔
202
                Password: d.Password,
392✔
203
        }
392✔
204

392✔
205
        return dialer.Dial(ctx, opts)
392✔
206
}
392✔
207

208
type fdAddr struct {
209
        Fd uintptr
210
}
211

212
func (a fdAddr) Network() string {
×
213
        return "fd"
×
214
}
×
215

216
func (a fdAddr) String() string {
×
217
        return fmt.Sprintf("fd://%d", a.Fd)
×
218
}
×
219

220
type fdConn struct {
221
        net.Conn
222
        Addr fdAddr
223
}
224

225
func (c *fdConn) RemoteAddr() net.Addr {
1✔
226
        return c.Addr
1✔
227
}
1✔
228

229
type fdDialer struct {
230
        fd uintptr
231
}
232

233
func (d fdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
6✔
234
        file := os.NewFile(d.fd, "")
6✔
235
        c, err := net.FileConn(file)
6✔
236
        if err != nil {
6✔
237
                return nil, fmt.Errorf("failed to dial: %w", err)
×
238
        }
×
239

240
        conn := new(tntConn)
6✔
241
        conn.net = &fdConn{Conn: c, Addr: fdAddr{Fd: d.fd}}
6✔
242

6✔
243
        dc := &deadlineIO{to: opts.IoTimeout, c: conn.net}
6✔
244
        conn.reader = bufio.NewReaderSize(dc, bufSize)
6✔
245
        conn.writer = bufio.NewWriterSize(dc, bufSize)
6✔
246

6✔
247
        return conn, nil
6✔
248
}
249

250
// FdDialer allows using an existing socket fd for connection.
251
type FdDialer struct {
252
        // Fd is a socket file descriptor.
253
        Fd uintptr
254
        // RequiredProtocol contains minimal protocol version and
255
        // list of protocol features that should be supported by
256
        // Tarantool server. By default, there are no restrictions.
257
        RequiredProtocolInfo ProtocolInfo
258
}
259

260
// Dial makes FdDialer satisfy the Dialer interface.
261
func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
6✔
262
        dialer := ProtocolDialer{
6✔
263
                Dialer: GreetingDialer{
6✔
264
                        Dialer: fdDialer{
6✔
265
                                fd: d.Fd,
6✔
266
                        },
6✔
267
                },
6✔
268
                RequiredProtocolInfo: d.RequiredProtocolInfo,
6✔
269
        }
6✔
270

6✔
271
        return dialer.Dial(ctx, opts)
6✔
272
}
6✔
273

274
// AuthDialer is a dialer-wrapper that does authentication of a user.
275
type AuthDialer struct {
276
        // Dialer is a base dialer.
277
        Dialer Dialer
278
        // Authentication options.
279
        Auth Auth
280
        // Username is a name of a user for authentication.
281
        Username string
282
        // Password is a user password for authentication.
283
        Password string
284
}
285

286
// Dial makes AuthDialer satisfy the Dialer interface.
287
func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
397✔
288
        conn, err := d.Dialer.Dial(ctx, opts)
397✔
289
        if err != nil {
475✔
290
                return conn, err
78✔
291
        }
78✔
292
        greeting := conn.Greeting()
319✔
293
        if greeting.Salt == "" {
320✔
294
                conn.Close()
1✔
295
                return nil, fmt.Errorf("failed to authenticate: " +
1✔
296
                        "an invalid connection without salt")
1✔
297
        }
1✔
298

299
        if d.Username == "" {
318✔
300
                return conn, nil
×
301
        }
×
302

303
        protocolAuth := conn.ProtocolInfo().Auth
318✔
304
        if d.Auth == AutoAuth {
319✔
305
                if protocolAuth != AutoAuth {
1✔
306
                        d.Auth = protocolAuth
×
307
                } else {
1✔
308
                        d.Auth = ChapSha1Auth
1✔
309
                }
1✔
310
        }
311

312
        if err := authenticate(conn, d.Auth, d.Username, d.Password,
318✔
313
                conn.Greeting().Salt); err != nil {
321✔
314
                conn.Close()
3✔
315
                return nil, fmt.Errorf("failed to authenticate: %w", err)
3✔
316
        }
3✔
317
        return conn, nil
315✔
318
}
319

320
// ProtocolDialer is a dialer-wrapper that reads and fills the ProtocolInfo
321
// of a connection.
322
type ProtocolDialer struct {
323
        // Dialer is a base dialer.
324
        Dialer Dialer
325
        // RequiredProtocol contains minimal protocol version and
326
        // list of protocol features that should be supported by
327
        // Tarantool server. By default, there are no restrictions.
328
        RequiredProtocolInfo ProtocolInfo
329
}
330

331
// Dial makes ProtocolDialer satisfy the Dialer interface.
332
func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
402✔
333
        conn, err := d.Dialer.Dial(ctx, opts)
402✔
334
        if err != nil {
479✔
335
                return conn, err
77✔
336
        }
77✔
337

338
        protocolConn := protocolConn{
325✔
339
                Conn:         conn,
325✔
340
                protocolInfo: d.RequiredProtocolInfo,
325✔
341
        }
325✔
342

325✔
343
        protocolConn.protocolInfo, err = identify(&protocolConn)
325✔
344
        if err != nil {
328✔
345
                protocolConn.Close()
3✔
346
                return nil, fmt.Errorf("failed to identify: %w", err)
3✔
347
        }
3✔
348

349
        err = checkProtocolInfo(d.RequiredProtocolInfo, protocolConn.protocolInfo)
322✔
350
        if err != nil {
325✔
351
                protocolConn.Close()
3✔
352
                return nil, fmt.Errorf("invalid server protocol: %w", err)
3✔
353
        }
3✔
354

355
        return &protocolConn, nil
319✔
356
}
357

358
// GreetingDialer is a dialer-wrapper that reads and fills the Greeting
359
// of a connection.
360
type GreetingDialer struct {
361
        // Dialer is a base dialer.
362
        Dialer Dialer
363
}
364

365
// Dial makes GreetingDialer satisfy the Dialer interface.
366
func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
401✔
367
        conn, err := d.Dialer.Dial(ctx, opts)
401✔
368
        if err != nil {
476✔
369
                return conn, err
75✔
370
        }
75✔
371

372
        greetingConn := greetingConn{
326✔
373
                Conn: conn,
326✔
374
        }
326✔
375
        version, salt, err := readGreeting(greetingConn)
326✔
376
        if err != nil {
329✔
377
                greetingConn.Close()
3✔
378
                return nil, fmt.Errorf("failed to read greeting: %w", err)
3✔
379
        }
3✔
380
        greetingConn.greeting = Greeting{
323✔
381
                Version: version,
323✔
382
                Salt:    salt,
323✔
383
        }
323✔
384

323✔
385
        return &greetingConn, err
323✔
386
}
387

388
// parseAddress split address into network and address parts.
389
func parseAddress(address string) (string, string) {
392✔
390
        network := "tcp"
392✔
391
        addrLen := len(address)
392✔
392

392✔
393
        if addrLen > 0 && (address[0] == '.' || address[0] == '/') {
392✔
394
                network = "unix"
×
395
        } else if addrLen >= 7 && address[0:7] == "unix://" {
392✔
396
                network = "unix"
×
397
                address = address[7:]
×
398
        } else if addrLen >= 5 && address[0:5] == "unix:" {
392✔
399
                network = "unix"
×
400
                address = address[5:]
×
401
        } else if addrLen >= 6 && address[0:6] == "unix/:" {
392✔
402
                network = "unix"
×
403
                address = address[6:]
×
404
        } else if addrLen >= 6 && address[0:6] == "tcp://" {
392✔
405
                address = address[6:]
×
406
        } else if addrLen >= 4 && address[0:4] == "tcp:" {
392✔
407
                address = address[4:]
×
408
        }
×
409

410
        return network, address
392✔
411
}
412

413
// readGreeting reads the fixed-size, 128-byte greeting message from reader.
// It returns the first 64 bytes as the version string and bytes 64 to 107
// as the salt, both untrimmed. If the read fails, both strings are empty and
// the error from io.ReadFull is returned.
func readGreeting(reader io.Reader) (string, string, error) {
	const (
		greetingSize = 128
		versionEnd   = 64
		saltEnd      = 108
	)

	raw := make([]byte, greetingSize)
	if _, err := io.ReadFull(reader, raw); err != nil {
		return "", "", err
	}

	version := bytes.NewBuffer(raw[:versionEnd]).String()
	salt := bytes.NewBuffer(raw[versionEnd:saltEnd]).String()
	return version, salt, nil
}
426

427
// identify sends an ID request carrying the client protocol info and
// returns the protocol info the server answers with.
//
// If the server rejects the request with ER_UNKNOWN_REQUEST_TYPE (it does
// not support IPROTO_ID), a zero ProtocolInfo and a nil error are returned.
func identify(conn Conn) (ProtocolInfo, error) {
	var none ProtocolInfo

	req := NewIdRequest(clientProtocolInfo)
	if err := writeRequest(conn, req); err != nil {
		return none, err
	}

	resp, err := readResponse(conn, req)
	if err != nil {
		unsupported := resp != nil &&
			resp.Header().Error == iproto.ER_UNKNOWN_REQUEST_TYPE
		if unsupported {
			// IPROTO_ID requests are not supported by server.
			return none, nil
		}
		return none, err
	}

	data, err := resp.Decode()
	if err != nil {
		return none, err
	}
	if len(data) == 0 {
		return none, errors.New("unexpected response: no data")
	}

	decoded, ok := data[0].(ProtocolInfo)
	if !ok {
		return none, errors.New("unexpected response: wrong data")
	}
	return decoded, nil
}
462

463
// checkProtocolInfo checks that the actual protocol info satisfies the
// required one: the actual protocol version must be at least the required
// version, and every required feature must be present among the actual
// features. When features are missing, the returned error lists all of
// them.
func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
	if required.Version > actual.Version {
		return fmt.Errorf("protocol version %d is not supported",
			required.Version)
	}

	// It seems that iterating over a small list is way faster
	// than building a map: https://stackoverflow.com/a/52710077/11646599
	var missed []string
	for _, requiredFeature := range required.Features {
		found := false
		for _, actualFeature := range actual.Features {
			if requiredFeature == actualFeature {
				found = true
				break // no need to scan the rest once matched
			}
		}
		if !found {
			missed = append(missed, requiredFeature.String())
		}
	}

	switch len(missed) {
	case 0:
		return nil
	case 1:
		return fmt.Errorf("protocol feature %s is not supported", missed[0])
	default:
		joined := strings.Join(missed, ", ")
		return fmt.Errorf("protocol features %s are not supported", joined)
	}
}
496

497
// authenticate sends an authentication request for user over c and waits
// for the reply. ChapSha1Auth uses salt to build the request; PapSha256Auth
// sends the password as given to newPapSha256AuthRequest. Any other method
// is rejected without touching the connection.
func authenticate(c Conn, auth Auth, user string, pass string, salt string) error {
	var req Request

	switch auth {
	case ChapSha1Auth:
		chapReq, err := newChapSha1AuthRequest(user, pass, salt)
		if err != nil {
			return err
		}
		req = chapReq
	case PapSha256Auth:
		req = newPapSha256AuthRequest(user, pass)
	default:
		return errors.New("unsupported method " + auth.String())
	}

	if err := writeRequest(c, req); err != nil {
		return err
	}
	_, err := readResponse(c, req)
	return err
}
522

523
// writeRequest encodes req into a packet, writes the packet to w and
// flushes w. Each failing stage is reported with its own error prefix.
func writeRequest(w writeFlusher, req Request) error {
	var packet smallWBuf
	enc := msgpack.NewEncoder(&packet)

	if err := pack(&packet, enc, 0, req, ignoreStreamId, nil); err != nil {
		return fmt.Errorf("pack error: %w", err)
	}
	if _, err := w.Write(packet.b); err != nil {
		return fmt.Errorf("write error: %w", err)
	}
	if err := w.Flush(); err != nil {
		return fmt.Errorf("flush error: %w", err)
	}
	return nil
}
539

540
// readResponse reads one response packet from r, decodes its header, builds
// the response for req and decodes the body.
//
// If body decoding fails, the response is still returned together with the
// error: an error of type Error is returned as is, and any other error is
// wrapped.
func readResponse(r io.Reader, req Request) (Response, error) {
	var lenbuf [packetLengthBytes]byte
	respBytes, err := read(r, lenbuf[:])
	if err != nil {
		return nil, fmt.Errorf("read error: %w", err)
	}

	body := smallBuf{b: respBytes}
	header, _, err := decodeHeader(msgpack.NewDecoder(&smallBuf{}), &body)
	if err != nil {
		return nil, fmt.Errorf("decode response header error: %w", err)
	}

	resp, err := req.Response(header, &body)
	if err != nil {
		return nil, fmt.Errorf("creating response error: %w", err)
	}

	if _, err = resp.Decode(); err != nil {
		if _, isErrorType := err.(Error); isErrorType {
			return resp, err
		}
		return resp, fmt.Errorf("decode response body error: %w", err)
	}
	return resp, nil
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc