• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

go-sql-driver / mysql / 14675550227

26 Apr 2025 12:08AM UTC coverage: 82.41% (-0.6%) from 82.961%
14675550227

push

github

web-flow
MariaDB Metadata skipping and DEPRECATE_EOF (#1708)

[MariaDB metadata skipping](https://mariadb.com/kb/en/mariadb-protocol-differences-with-mysql/#prepare-statement-skipping-metadata).

With this change, the MariaDB server won't resend result-set metadata when it has not changed, saving the client from re-parsing it and reducing network traffic.

This feature relies on these changes:
* extended capabilities support 
* EOF packet deprecation, which required revising the current implementation

A benchmark BenchmarkReceiveMetadata has been added to show the difference.

155 of 194 new or added lines in 5 files covered. (79.9%)

6 existing lines in 2 files now uncovered.

3303 of 4008 relevant lines covered (82.41%)

2391188.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.54
/packets.go
1
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2
//
3
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4
//
5
// This Source Code Form is subject to the terms of the Mozilla Public
6
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7
// You can obtain one at http://mozilla.org/MPL/2.0/.
8

9
package mysql
10

11
import (
12
        "bytes"
13
        "crypto/tls"
14
        "database/sql/driver"
15
        "encoding/binary"
16
        "encoding/json"
17
        "fmt"
18
        "io"
19
        "math"
20
        "strconv"
21
        "time"
22
)
23

24
// MySQL client/server protocol documentation.
25
// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
26
// https://mariadb.com/kb/en/clientserver-protocol/
27

28
// read n bytes from mc.buf
29
func (mc *mysqlConn) readNext(n int) ([]byte, error) {
29,944,505✔
30
        if mc.buf.len() < n {
30,047,798✔
31
                err := mc.buf.fill(n, mc.readWithTimeout)
103,293✔
32
                if err != nil {
105,500✔
33
                        return nil, err
2,207✔
34
                }
2,207✔
35
        }
36
        return mc.buf.readNext(n), nil
29,942,298✔
37
}
38

39
// Read packet to buffer 'data'
40
func (mc *mysqlConn) readPacket() ([]byte, error) {
22,403,297✔
41
        var prevData []byte
22,403,297✔
42
        invalidSequence := false
22,403,297✔
43

22,403,297✔
44
        readNext := mc.readNext
22,403,297✔
45
        if mc.compress {
29,852,834✔
46
                readNext = mc.compIO.readNext
7,449,537✔
47
        }
7,449,537✔
48

49
        for {
44,807,010✔
50
                // read packet header
22,403,713✔
51
                data, err := readNext(4)
22,403,713✔
52
                if err != nil {
22,405,920✔
53
                        mc.close()
2,207✔
54
                        if cerr := mc.canceled.Value(); cerr != nil {
4,350✔
55
                                return nil, cerr
2,143✔
56
                        }
2,143✔
57
                        mc.log(err)
64✔
58
                        return nil, ErrInvalidConn
64✔
59
                }
60

61
                // packet length [24 bit]
62
                pktLen := getUint24(data[:3])
22,401,506✔
63
                seq := data[3]
22,401,506✔
64

22,401,506✔
65
                if mc.compress {
29,851,139✔
66
                        // MySQL and MariaDB doesn't check packet nr in compressed packet.
7,449,633✔
67
                        if debug && seq != mc.compressSequence {
7,449,633✔
68
                                fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v",
×
69
                                        mc.compressSequence, seq)
×
70
                        }
×
71
                        mc.compressSequence = seq + 1
7,449,633✔
72
                } else {
14,951,873✔
73
                        // check packet sync [8 bit]
14,951,873✔
74
                        if seq != mc.sequence {
14,951,937✔
75
                                mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq))
64✔
76
                                // For large packets, we stop reading as soon as sync error.
64✔
77
                                if len(prevData) > 0 {
64✔
78
                                        mc.close()
×
79
                                        return nil, ErrPktSyncMul
×
80
                                }
×
81
                                invalidSequence = true
64✔
82
                        }
83
                        mc.sequence++
14,951,873✔
84
                }
85

86
                // packets with length 0 terminate a previous packet which is a
87
                // multiple of (2^24)-1 bytes long
88
                if pktLen == 0 {
22,401,602✔
89
                        // there was no previous packet
96✔
90
                        if prevData == nil {
128✔
91
                                mc.log(ErrMalformPkt)
32✔
92
                                mc.close()
32✔
93
                                return nil, ErrInvalidConn
32✔
94
                        }
32✔
95
                        return prevData, nil
64✔
96
                }
97

98
                // read packet body [pktLen bytes]
99
                data, err = readNext(pktLen)
22,401,410✔
100
                if err != nil {
22,401,410✔
101
                        mc.close()
×
102
                        if cerr := mc.canceled.Value(); cerr != nil {
×
103
                                return nil, cerr
×
104
                        }
×
105
                        mc.log(err)
×
106
                        return nil, ErrInvalidConn
×
107
                }
108

109
                // return data if this was the last packet
110
                if pktLen < maxPacketSize {
44,802,404✔
111
                        // zero allocations for non-split packets
22,400,994✔
112
                        if prevData != nil {
22,401,218✔
113
                                data = append(prevData, data...)
224✔
114
                        }
224✔
115
                        if invalidSequence {
22,401,058✔
116
                                mc.close()
64✔
117
                                // return sync error only for regular packet.
64✔
118
                                // error packets may have wrong sequence number.
64✔
119
                                if data[0] != iERR {
128✔
120
                                        return nil, ErrPktSync
64✔
121
                                }
64✔
122
                        }
123
                        return data, nil
22,400,930✔
124
                }
125

126
                prevData = append(prevData, data...)
416✔
127
        }
128
}
129

130
// Write packet buffer 'data'
131
func (mc *mysqlConn) writePacket(data []byte) error {
102,666✔
132
        pktLen := len(data) - 4
102,666✔
133
        if pktLen > mc.maxAllowedPacket {
102,666✔
134
                return ErrPktTooLarge
×
135
        }
×
136

137
        writeFunc := mc.writeWithTimeout
102,666✔
138
        if mc.compress {
127,238✔
139
                writeFunc = mc.compIO.writePackets
24,572✔
140
        }
24,572✔
141

142
        for {
205,620✔
143
                size := min(maxPacketSize, pktLen)
102,954✔
144
                putUint24(data[:3], size)
102,954✔
145
                data[3] = mc.sequence
102,954✔
146

102,954✔
147
                // Write packet
102,954✔
148
                if debug {
102,954✔
149
                        fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence)
×
150
                }
×
151

152
                n, err := writeFunc(data[:4+size])
102,954✔
153
                if err != nil {
103,097✔
154
                        mc.cleanup()
143✔
155
                        if cerr := mc.canceled.Value(); cerr != nil {
210✔
156
                                return cerr
67✔
157
                        }
67✔
158
                        if n == 0 && pktLen == len(data)-4 {
120✔
159
                                // only for the first loop iteration when nothing was written yet
44✔
160
                                mc.log(err)
44✔
161
                                return errBadConnNoWrite
44✔
162
                        } else {
76✔
163
                                return err
32✔
164
                        }
32✔
165
                }
166
                if n != 4+size {
102,811✔
167
                        // io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
×
168
                        // The io.ErrShortWrite error is used to indicate that this rule has not been followed.
×
169
                        mc.cleanup()
×
170
                        return io.ErrShortWrite
×
171
                }
×
172

173
                mc.sequence++
102,811✔
174
                if size != maxPacketSize {
205,334✔
175
                        return nil
102,523✔
176
                }
102,523✔
177
                pktLen -= size
288✔
178
                data = data[size:]
288✔
179
        }
180
}
181

182
/******************************************************************************
183
*                           Initialization Process                            *
184
******************************************************************************/
185

186
// Handshake Initialization Packet
187
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
188
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
189
func (mc *mysqlConn) readHandshakePacket() (data []byte, capabilities capabilityFlag, extendedCapabilities extendedCapabilityFlag, plugin string, err error) {
17,090✔
190
        data, err = mc.readPacket()
17,090✔
191
        if err != nil {
17,176✔
192
                return
86✔
193
        }
86✔
194

195
        if data[0] == iERR {
17,099✔
196
                err = mc.handleErrorPacket(data)
95✔
197
                return
95✔
198
        }
95✔
199

200
        // protocol version [1 byte]
201
        if data[0] < minProtocolVersion {
16,909✔
NEW
202
                return nil, 0, 0, "", fmt.Errorf(
×
203
                        "unsupported protocol version %d. Version %d or higher is required",
×
204
                        data[0],
×
205
                        minProtocolVersion,
×
206
                )
×
207
        }
×
208

209
        // server version [null terminated string]
210
        // connection id [4 bytes]
211
        pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
16,909✔
212

16,909✔
213
        // first part of the password cipher [8 bytes]
16,909✔
214
        authData := data[pos : pos+8]
16,909✔
215

16,909✔
216
        // (filler) always 0x00 [1 byte]
16,909✔
217
        pos += 8 + 1
16,909✔
218

16,909✔
219
        // capability flags (lower 2 bytes) [2 bytes]
16,909✔
220
        capabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
16,909✔
221
        if capabilities&clientProtocol41 == 0 {
16,909✔
NEW
222
                return nil, capabilities, 0, "", ErrOldProtocol
×
UNCOV
223
        }
×
224
        if capabilities&clientSSL == 0 && mc.cfg.TLS != nil {
16,909✔
225
                if mc.cfg.AllowFallbackToPlaintext {
×
226
                        mc.cfg.TLS = nil
×
227
                } else {
×
NEW
228
                        return nil, capabilities, 0, "", ErrNoTLS
×
229
                }
×
230
        }
231
        pos += 2
16,909✔
232

16,909✔
233
        if len(data) > pos {
33,818✔
234
                // character set [1 byte]
16,909✔
235
                // status flags [2 bytes]
16,909✔
236
                pos += 3
16,909✔
237
                // capability flags (upper 2 bytes) [2 bytes]
16,909✔
238
                capabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
16,909✔
239
                pos += 2
16,909✔
240
                // length of auth-plugin-data [1 byte]
16,909✔
241
                // reserved (all [00]) [6 bytes]
16,909✔
242
                pos += 7
16,909✔
243
                if capabilities&clientMySQL == 0 {
26,434✔
244
                        // MariaDB server extended flag
9,525✔
245
                        extendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4]))
9,525✔
246
                }
9,525✔
247
                pos += 4
16,909✔
248

16,909✔
249
                // second part of the password cipher [minimum 13 bytes],
16,909✔
250
                // where len=MAX(13, length of auth-plugin-data - 8)
16,909✔
251
                //
16,909✔
252
                // The web documentation is ambiguous about the length. However,
16,909✔
253
                // according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
16,909✔
254
                // the 13th byte is "\0 byte, terminating the second part of
16,909✔
255
                // a scramble". So the second part of the password cipher is
16,909✔
256
                // a NULL terminated string that's at least 13 bytes with the
16,909✔
257
                // last byte being NULL.
16,909✔
258
                //
16,909✔
259
                // The official Python library uses the fixed length 12
16,909✔
260
                // which seems to work but technically could have a hidden bug.
16,909✔
261
                authData = append(authData, data[pos:pos+12]...)
16,909✔
262
                pos += 13
16,909✔
263

16,909✔
264
                // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
16,909✔
265
                // \NUL otherwise
16,909✔
266
                if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
33,786✔
267
                        plugin = string(data[pos : pos+end])
16,877✔
268
                } else {
16,909✔
269
                        plugin = string(data[pos:])
32✔
270
                }
32✔
271

272
                // make a memory safe copy of the cipher slice
273
                var b [20]byte
16,909✔
274
                copy(b[:], authData)
16,909✔
275
                return b[:], capabilities, extendedCapabilities, plugin, nil
16,909✔
276
        }
277

278
        // make a memory safe copy of the cipher slice
279
        var b [8]byte
×
280
        copy(b[:], authData)
×
NEW
281
        return b[:], capabilities, 0, plugin, nil
×
282
}
283

284
// initCapabilities initializes the capabilities based on server support and configuration
285
func (mc *mysqlConn) initCapabilities(serverCapabilities capabilityFlag, serverExtCapabilities extendedCapabilityFlag, cfg *Config) {
16,877✔
286
        clientCapabilities :=
16,877✔
287
                clientMySQL |
16,877✔
288
                        clientLongFlag |
16,877✔
289
                        clientProtocol41 |
16,877✔
290
                        clientSecureConn |
16,877✔
291
                        clientTransactions |
16,877✔
292
                        clientPluginAuthLenEncClientData |
16,877✔
293
                        clientLocalFiles |
16,877✔
294
                        clientPluginAuth |
16,877✔
295
                        clientMultiResults |
16,877✔
296
                        clientConnectAttrs |
16,877✔
297
                        clientDeprecateEOF
16,877✔
298

16,877✔
299
        if cfg.ClientFoundRows {
16,909✔
300
                clientCapabilities |= clientFoundRows
32✔
301
        }
32✔
302
        if cfg.compress {
21,762✔
303
                clientCapabilities |= clientCompress
4,885✔
304
        }
4,885✔
305
        // To enable TLS / SSL
306
        if mc.cfg.TLS != nil {
22,029✔
307
                clientCapabilities |= clientSSL
5,152✔
308
        }
5,152✔
309

310
        if mc.cfg.MultiStatements {
17,165✔
311
                clientCapabilities |= clientMultiStatements
288✔
312
        }
288✔
313
        if n := len(cfg.DBName); n > 0 {
33,754✔
314
                clientCapabilities |= clientConnectWithDB
16,877✔
315
        }
16,877✔
316

317
        // only keep client capabilities that server have
318
        mc.capabilities = clientCapabilities & serverCapabilities
16,877✔
319

16,877✔
320
        // set MariaDB extended clientCacheMetadata capability if server support it
16,877✔
321
        mc.extCapabilities = clientCacheMetadata & serverExtCapabilities
16,877✔
322
}
323

324
// Client Authentication Packet
325
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
326
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
17,325✔
327
        // packet header  4
17,325✔
328
        // capabilities   4
17,325✔
329
        // maxPacketSize  4
17,325✔
330
        // collation id   1
17,325✔
331
        // filler        23
17,325✔
332
        data, err := mc.buf.takeSmallBuffer(4*3 + 24)
17,325✔
333
        if err != nil {
17,325✔
334
                mc.cleanup()
×
335
                return err
×
336
        }
×
337
        _ = data[4*3+23] // boundery check
17,325✔
338

17,325✔
339
        // clientCapabilities [32 bit]
17,325✔
340
        binary.LittleEndian.PutUint32(data[4:], uint32(mc.capabilities))
17,325✔
341

17,325✔
342
        // MaxPacketSize [32 bit] (none)
17,325✔
343
        binary.LittleEndian.PutUint32(data[8:], 0)
17,325✔
344

17,325✔
345
        // Collation ID [1 byte]
17,325✔
346
        data[12] = defaultCollationID
17,325✔
347
        if cname := mc.cfg.Collation; cname != "" {
17,709✔
348
                colID, ok := collations[cname]
384✔
349
                if ok {
768✔
350
                        data[12] = colID
384✔
351
                } else if len(mc.cfg.charsets) > 0 {
384✔
352
                        // When cfg.charset is set, the collation is set by `SET NAMES <charset> COLLATE <collation>`.
×
353
                        return fmt.Errorf("unknown collation: %q", cname)
×
354
                }
×
355
        }
356

357
        // Filler [23 bytes] (all 0x00)
358
        // or filler 19bytes + mariadb extCapabilities
359
        pos := 13
17,325✔
360
        if mc.capabilities&clientMySQL == 0 {
27,298✔
361
                for ; pos < 13+19; pos++ {
199,460✔
362
                        data[pos] = 0
189,487✔
363
                }
189,487✔
364
                // MariaDB Extended Capabilities
365
                binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.extCapabilities))
9,973✔
366
        } else {
7,352✔
367
                for ; pos < 13+23; pos++ {
176,448✔
368
                        data[pos] = 0
169,096✔
369
                }
169,096✔
370
        }
371

372
        // SSL Connection Request Packet
373
        // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html
374
        // https://mariadb.com/kb/en/connection/#sslrequest-packet
375
        if mc.cfg.TLS != nil {
22,477✔
376
                // Send TLS / SSL request packet
5,152✔
377
                if err := mc.writePacket(data); err != nil {
5,206✔
378
                        return err
54✔
379
                }
54✔
380

381
                // Switch to TLS
382
                tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
5,098✔
383
                if err := tlsConn.Handshake(); err != nil {
7,207✔
384
                        if cerr := mc.canceled.Value(); cerr != nil {
4,218✔
385
                                return cerr
2,109✔
386
                        }
2,109✔
387
                        return err
×
388
                }
389
                mc.netConn = tlsConn
2,989✔
390
        }
391

392
        // User [null terminated string]
393
        if len(mc.cfg.User) > 0 {
30,324✔
394
                data = append(data, mc.cfg.User...)
15,162✔
395
        }
15,162✔
396
        data = append(data, 0)
15,162✔
397

15,162✔
398
        // Auth Data [length encoded integer]
15,162✔
399
        data = appendLengthEncodedInteger(data, uint64(len(authResp)))
15,162✔
400
        data = append(data, authResp...)
15,162✔
401

15,162✔
402
        // Database name [null terminated string]
15,162✔
403
        if mc.capabilities&clientConnectWithDB != 0 {
29,876✔
404
                data = append(data, mc.cfg.DBName...)
14,714✔
405
                data = append(data, 0)
14,714✔
406
        }
14,714✔
407

408
        data = append(data, plugin...)
15,162✔
409
        data = append(data, 0)
15,162✔
410

15,162✔
411
        // Connection Attributes
15,162✔
412
        if mc.capabilities&clientConnectAttrs != 0 {
29,876✔
413
                connAttrsLen := len(mc.connector.encodedAttributes)
14,714✔
414
                data = appendLengthEncodedInteger(data, uint64(connAttrsLen))
14,714✔
415
                data = append(data, mc.connector.encodedAttributes...)
14,714✔
416
        }
14,714✔
417

418
        // Send Auth packet
419
        return mc.writePacket(data)
15,162✔
420
}
421

422
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
423
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
867✔
424
        pktLen := 4 + len(authData)
867✔
425
        data, err := mc.buf.takeBuffer(pktLen)
867✔
426
        if err != nil {
867✔
427
                mc.cleanup()
×
428
                return err
×
429
        }
×
430

431
        // Add the auth data [EOF]
432
        copy(data[4:], authData)
867✔
433
        return mc.writePacket(data)
867✔
434
}
435

436
/******************************************************************************
437
*                             Command Packets                                 *
438
******************************************************************************/
439

440
func (mc *mysqlConn) writeCommandPacket(command byte) error {
12,772✔
441
        // Reset Packet Sequence
12,772✔
442
        mc.resetSequence()
12,772✔
443

12,772✔
444
        data, err := mc.buf.takeSmallBuffer(4 + 1)
12,772✔
445
        if err != nil {
12,772✔
UNCOV
446
                return err
×
UNCOV
447
        }
×
448

449
        // Add command byte
450
        data[4] = command
12,772✔
451

12,772✔
452
        // Send CMD packet
12,772✔
453
        return mc.writePacket(data)
12,772✔
454
}
455

456
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
50,398✔
457
        // Reset Packet Sequence
50,398✔
458
        mc.resetSequence()
50,398✔
459

50,398✔
460
        pktLen := 1 + len(arg)
50,398✔
461
        data, err := mc.buf.takeBuffer(pktLen + 4)
50,398✔
462
        if err != nil {
50,398✔
463
                return err
×
464
        }
×
465

466
        // Add command byte
467
        data[4] = command
50,398✔
468

50,398✔
469
        // Add arg
50,398✔
470
        copy(data[5:], arg)
50,398✔
471

50,398✔
472
        // Send CMD packet
50,398✔
473
        err = mc.writePacket(data)
50,398✔
474
        mc.syncSequence()
50,398✔
475
        return err
50,398✔
476
}
477

478
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
8,336✔
479
        // Reset Packet Sequence
8,336✔
480
        mc.resetSequence()
8,336✔
481

8,336✔
482
        data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
8,336✔
483
        if err != nil {
8,336✔
484
                return err
×
485
        }
×
486

487
        // Add command byte
488
        data[4] = command
8,336✔
489

8,336✔
490
        // Add arg [32 bit]
8,336✔
491
        binary.LittleEndian.PutUint32(data[5:], arg)
8,336✔
492

8,336✔
493
        // Send CMD packet
8,336✔
494
        return mc.writePacket(data)
8,336✔
495
}
496

497
/******************************************************************************
498
*                              Result Packets                                 *
499
******************************************************************************/
500

501
func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
16,372✔
502
        data, err := mc.readPacket()
16,372✔
503
        if err != nil {
18,185✔
504
                return nil, "", err
1,813✔
505
        }
1,813✔
506

507
        // packet indicator
508
        switch data[0] {
14,559✔
509

510
        case iOK:
9,100✔
511
                // resultUnchanged, since auth happens before any queries or
9,100✔
512
                // commands have been executed.
9,100✔
513
                return nil, "", mc.resultUnchanged().handleOkPacket(data)
9,100✔
514

515
        case iAuthMoreData:
4,713✔
516
                return data[1:], "", err
4,713✔
517

518
        case iEOF:
672✔
519
                if len(data) == 1 {
768✔
520
                        // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
96✔
521
                        return nil, "mysql_old_password", nil
96✔
522
                }
96✔
523
                pluginEndIndex := bytes.IndexByte(data, 0x00)
576✔
524
                if pluginEndIndex < 0 {
576✔
525
                        return nil, "", ErrMalformPkt
×
526
                }
×
527
                plugin := string(data[1:pluginEndIndex])
576✔
528
                authData := data[pluginEndIndex+1:]
576✔
529
                if len(authData) > 0 && authData[len(authData)-1] == 0 {
1,056✔
530
                        authData = authData[:len(authData)-1]
480✔
531
                }
480✔
532
                return authData, plugin, nil
576✔
533

534
        default: // Error otherwise
74✔
535
                return nil, "", mc.handleErrorPacket(data)
74✔
536
        }
537
}
538

539
// Returns error if Packet is not a 'Result OK'-Packet
540
func (mc *okHandler) readResultOK() error {
7,371✔
541
        data, err := mc.conn().readPacket()
7,371✔
542
        if err != nil {
7,475✔
543
                return err
104✔
544
        }
104✔
545

546
        if data[0] == iOK {
14,510✔
547
                return mc.handleOkPacket(data)
7,243✔
548
        }
7,243✔
549
        return mc.conn().handleErrorPacket(data)
24✔
550
}
551

552
// Result Set Header Packet
553
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
554
func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) {
51,742✔
555
        // handleOkPacket replaces both values; other cases leave the values unchanged.
51,742✔
556
        mc.result.affectedRows = append(mc.result.affectedRows, 0)
51,742✔
557
        mc.result.insertIds = append(mc.result.insertIds, 0)
51,742✔
558

51,742✔
559
        data, err := mc.conn().readPacket()
51,742✔
560
        if err != nil {
51,882✔
561
                return 0, false, err
140✔
562
        }
140✔
563

564
        switch data[0] {
51,602✔
565
        case iOK:
28,542✔
566
                return 0, false, mc.handleOkPacket(data)
28,542✔
567

568
        case iERR:
546✔
569
                return 0, false, mc.conn().handleErrorPacket(data)
546✔
570

571
        case iLocalInFile:
480✔
572
                return 0, false, mc.handleInFileRequest(string(data[1:]))
480✔
573
        }
574

575
        // column count
576
        // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
577
        // https://mariadb.com/kb/en/result-set-packets/#column-count-packet
578
        num, _, len := readLengthEncodedInteger(data)
22,034✔
579

22,034✔
580
        if mc.extCapabilities&clientCacheMetadata != 0 {
32,984✔
581
                return int(num), data[len] == 0x01, nil
10,950✔
582
        }
10,950✔
583
        // ignore remaining data in the packet. see #1478.
584
        return int(num), true, nil
11,084✔
585
}
586

587
// Error Packet
588
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
589
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
739✔
590
        if data[0] != iERR {
739✔
591
                return ErrMalformPkt
×
592
        }
×
593

594
        // 0xff [1 byte]
595

596
        // Error Number [16 bit uint]
597
        errno := binary.LittleEndian.Uint16(data[1:3])
739✔
598

739✔
599
        // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
739✔
600
        // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
739✔
601
        // 1836: ER_READ_ONLY_MODE
739✔
602
        if (errno == 1792 || errno == 1290 || errno == 1836) && mc.cfg.RejectReadOnly {
835✔
603
                // Oops; we are connected to a read-only connection, and won't be able
96✔
604
                // to issue any write statements. Since RejectReadOnly is configured,
96✔
605
                // we throw away this connection hoping this one would have write
96✔
606
                // permission. This is specifically for a possible race condition
96✔
607
                // during failover (e.g. on AWS Aurora). See README.md for more.
96✔
608
                //
96✔
609
                // We explicitly close the connection before returning
96✔
610
                // driver.ErrBadConn to ensure that `database/sql` purges this
96✔
611
                // connection and initiates a new one for next statement next time.
96✔
612
                mc.Close()
96✔
613
                return driver.ErrBadConn
96✔
614
        }
96✔
615

616
        me := &MySQLError{Number: errno}
643✔
617

643✔
618
        pos := 3
643✔
619

643✔
620
        // SQL State [optional: # + 5bytes string]
643✔
621
        if data[3] == 0x23 {
1,191✔
622
                copy(me.SQLState[:], data[4:4+5])
548✔
623
                pos = 9
548✔
624
        }
548✔
625

626
        // Error Message [string]
627
        me.Message = string(data[pos:])
643✔
628

643✔
629
        return me
643✔
630
}
631

632
func readStatus(b []byte) statusFlag {
64,935✔
633
        return statusFlag(b[0]) | statusFlag(b[1])<<8
64,935✔
634
}
64,935✔
635

636
// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't
637
// need to be cleared first (e.g. during authentication, or while additional
638
// resultsets are being fetched.)
639
func (mc *mysqlConn) resultUnchanged() *okHandler {
14,485✔
640
        return (*okHandler)(mc)
14,485✔
641
}
14,485✔
642

643
// okHandler represents the state of the connection when mysqlConn.result has
// been prepared for processing of OK packets.
//
// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
// callpaths must either:
//
// 1. first clear it using clearResult(), or
// 2. confirm that they don't need to (by calling resultUnchanged()).
//
// Both return an instance of type *okHandler.
//
// okHandler has mysqlConn as its underlying type, so a *mysqlConn and an
// *okHandler can be converted into each other directly; the conversion only
// changes which method set is visible, it does not copy the connection.
type okHandler mysqlConn
654

655
// Exposes the underlying type's methods.
656
func (mc *okHandler) conn() *mysqlConn {
61,027✔
657
        return (*mysqlConn)(mc)
61,027✔
658
}
61,027✔
659

660
// clearResult clears the connection's stored affectedRows and insertIds
661
// fields.
662
//
663
// It returns a handler that can process OK responses.
664
func (mc *mysqlConn) clearResult() *okHandler {
92,596✔
665
        mc.result = mysqlResult{}
92,596✔
666
        return (*okHandler)(mc)
92,596✔
667
}
92,596✔
668

669
// Ok Packet
670
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
671
func (mc *okHandler) handleOkPacket(data []byte) error {
44,885✔
672
        var n, m int
44,885✔
673
        var affectedRows, insertId uint64
44,885✔
674

44,885✔
675
        // 0x00 [1 byte]
44,885✔
676

44,885✔
677
        // Affected rows [Length Coded Binary]
44,885✔
678
        affectedRows, _, n = readLengthEncodedInteger(data[1:])
44,885✔
679

44,885✔
680
        // Insert id [Length Coded Binary]
44,885✔
681
        insertId, _, m = readLengthEncodedInteger(data[1+n:])
44,885✔
682

44,885✔
683
        // Update for the current statement result (only used by
44,885✔
684
        // readResultSetHeaderPacket).
44,885✔
685
        if len(mc.result.affectedRows) > 0 {
73,715✔
686
                mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
28,830✔
687
        }
28,830✔
688
        if len(mc.result.insertIds) > 0 {
73,715✔
689
                mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
28,830✔
690
        }
28,830✔
691

692
        // server_status [2 bytes]
693
        mc.status = readStatus(data[1+n+m : 1+n+m+2])
44,885✔
694
        if mc.status&statusMoreResultsExists != 0 {
45,237✔
695
                return nil
352✔
696
        }
352✔
697

698
        // warning count [2 bytes]
699

700
        return nil
44,533✔
701
}
702

703
// Read Packets as Field Packets until EOF-Packet or an Error appears
704
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
705
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
21,830✔
706
        columns := make([]mysqlField, count)
21,830✔
707

21,830✔
708
        for i := range count {
12,634,115✔
709
                data, err := mc.readPacket()
12,612,285✔
710
                if err != nil {
12,612,285✔
711
                        return nil, err
×
712
                }
×
713

714
                // Catalog
715
                pos, err := skipLengthEncodedString(data)
12,612,285✔
716
                if err != nil {
12,612,285✔
717
                        return nil, err
×
718
                }
×
719

720
                // Database [len coded string]
721
                n, err := skipLengthEncodedString(data[pos:])
12,612,285✔
722
                if err != nil {
12,612,285✔
723
                        return nil, err
×
724
                }
×
725
                pos += n
12,612,285✔
726

12,612,285✔
727
                // Table [len coded string]
12,612,285✔
728
                if mc.cfg.ColumnsWithAlias {
12,612,349✔
729
                        tableName, _, n, err := readLengthEncodedString(data[pos:])
64✔
730
                        if err != nil {
64✔
731
                                return nil, err
×
732
                        }
×
733
                        pos += n
64✔
734
                        columns[i].tableName = string(tableName)
64✔
735
                } else {
12,612,221✔
736
                        n, err = skipLengthEncodedString(data[pos:])
12,612,221✔
737
                        if err != nil {
12,612,221✔
738
                                return nil, err
×
739
                        }
×
740
                        pos += n
12,612,221✔
741
                }
742

743
                // Original table [len coded string]
744
                n, err = skipLengthEncodedString(data[pos:])
12,612,285✔
745
                if err != nil {
12,612,285✔
746
                        return nil, err
×
747
                }
×
748
                pos += n
12,612,285✔
749

12,612,285✔
750
                // Name [len coded string]
12,612,285✔
751
                name, _, n, err := readLengthEncodedString(data[pos:])
12,612,285✔
752
                if err != nil {
12,612,285✔
753
                        return nil, err
×
754
                }
×
755
                columns[i].name = string(name)
12,612,285✔
756
                pos += n
12,612,285✔
757

12,612,285✔
758
                // Original name [len coded string]
12,612,285✔
759
                n, err = skipLengthEncodedString(data[pos:])
12,612,285✔
760
                if err != nil {
12,612,285✔
761
                        return nil, err
×
762
                }
×
763
                pos += n
12,612,285✔
764

12,612,285✔
765
                // Filler [uint8]
12,612,285✔
766
                pos++
12,612,285✔
767

12,612,285✔
768
                // Charset [charset, collation uint8]
12,612,285✔
769
                columns[i].charSet = data[pos]
12,612,285✔
770
                pos += 2
12,612,285✔
771

12,612,285✔
772
                // Length [uint32]
12,612,285✔
773
                columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
12,612,285✔
774
                pos += 4
12,612,285✔
775

12,612,285✔
776
                // Field type [uint8]
12,612,285✔
777
                columns[i].fieldType = fieldType(data[pos])
12,612,285✔
778
                pos++
12,612,285✔
779

12,612,285✔
780
                // Flags [uint16]
12,612,285✔
781
                columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
12,612,285✔
782
                pos += 2
12,612,285✔
783

12,612,285✔
784
                // Decimals [uint8]
12,612,285✔
785
                columns[i].decimals = data[pos]
12,612,285✔
786
        }
787

788
        // skip EOF packet if client does not support deprecateEOF
789
        if err := mc.skipEof(); err != nil {
21,830✔
NEW
790
                return nil, err
×
UNCOV
791
        }
×
792
        return columns, nil
21,830✔
793
}
794

795
// Read Packets as Field Packets until EOF-Packet or an Error appears
796
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
797
func (rows *textRows) readRow(dest []driver.Value) error {
16,354✔
798
        mc := rows.mc
16,354✔
799

16,354✔
800
        if rows.rs.done {
16,354✔
801
                return io.EOF
×
802
        }
×
803

804
        data, err := mc.readPacket()
16,354✔
805
        if err != nil {
16,354✔
806
                return err
×
807
        }
×
808

809
        // EOF Packet
810
        // text row packets may starts with LengthEncodedString.
811
        // In such case, 0xFE can mean string larger than 0xffffff.
812
        // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le
813
        if data[0] == iEOF && len(data) <= 0xffffff {
17,362✔
814
                if mc.capabilities&clientDeprecateEOF == 0 {
1,008✔
NEW
815
                        // Deprecated EOF packet
×
NEW
816
                        // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_eof_packet.html
×
NEW
817
                        mc.status = readStatus(data[3:])
×
818
                } else {
1,008✔
819
                        // Ok Packet with an 0xFE header
1,008✔
820
                        _, _, n := readLengthEncodedInteger(data[1:])   // affected_rows
1,008✔
821
                        _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id
1,008✔
822
                        mc.status = readStatus(data[1+n+m:])
1,008✔
823
                }
1,008✔
824
                rows.rs.done = true
1,008✔
825
                if !rows.HasNextResultSet() {
1,888✔
826
                        rows.mc = nil
880✔
827
                }
880✔
828
                return io.EOF
1,008✔
829
        }
830
        if data[0] == iERR {
15,346✔
831
                rows.mc = nil
×
832
                return mc.handleErrorPacket(data)
×
833
        }
×
834

835
        // RowSet Packet
836
        var (
15,346✔
837
                n      int
15,346✔
838
                isNull bool
15,346✔
839
                pos    int = 0
15,346✔
840
        )
15,346✔
841

15,346✔
842
        for i := range dest {
47,092✔
843
                // Read bytes and convert to string
31,746✔
844
                var buf []byte
31,746✔
845
                buf, isNull, n, err = readLengthEncodedString(data[pos:])
31,746✔
846
                pos += n
31,746✔
847

31,746✔
848
                if err != nil {
31,746✔
849
                        return err
×
850
                }
×
851

852
                if isNull {
33,716✔
853
                        dest[i] = nil
1,970✔
854
                        continue
1,970✔
855
                }
856

857
                switch rows.rs.columns[i].fieldType {
29,776✔
858
                case fieldTypeTimestamp,
859
                        fieldTypeDateTime,
860
                        fieldTypeDate,
861
                        fieldTypeNewDate:
3,940✔
862
                        if mc.parseTime {
6,630✔
863
                                dest[i], err = parseDateTime(buf, mc.cfg.Loc)
2,690✔
864
                        } else {
3,940✔
865
                                dest[i] = buf
1,250✔
866
                        }
1,250✔
867

868
                case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
6,936✔
869
                        dest[i], err = strconv.ParseInt(string(buf), 10, 64)
6,936✔
870

871
                case fieldTypeLongLong:
1,812✔
872
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
2,420✔
873
                                dest[i], err = strconv.ParseUint(string(buf), 10, 64)
608✔
874
                        } else {
1,812✔
875
                                dest[i], err = strconv.ParseInt(string(buf), 10, 64)
1,204✔
876
                        }
1,204✔
877

878
                case fieldTypeFloat:
736✔
879
                        var d float64
736✔
880
                        d, err = strconv.ParseFloat(string(buf), 32)
736✔
881
                        dest[i] = float32(d)
736✔
882

883
                case fieldTypeDouble:
544✔
884
                        dest[i], err = strconv.ParseFloat(string(buf), 64)
544✔
885

886
                default:
15,808✔
887
                        dest[i] = buf
15,808✔
888
                }
889
                if err != nil {
29,776✔
890
                        return err
×
891
                }
×
892
        }
893

894
        return nil
15,346✔
895
}
896

897
func (mc *mysqlConn) skipPackets(n int) error {
11,468✔
898
        for i := 0; i < n; i++ {
9,657,426✔
899
                if _, err := mc.readPacket(); err != nil {
9,645,958✔
NEW
900
                        return err
×
NEW
901
                }
×
902
        }
903
        return nil
11,468✔
904
}
905

906
// skips EOF packet after n * ColumnDefinition packets when clientDeprecateEOF is not set
907
func (mc *mysqlConn) skipEof() error {
36,658✔
908
        if mc.capabilities&clientDeprecateEOF == 0 {
36,658✔
NEW
909
                if _, err := mc.readPacket(); err != nil {
×
NEW
910
                        return err
×
NEW
911
                }
×
912
        }
913
        return nil
36,658✔
914
}
915

916
func (mc *mysqlConn) skipColumns(n int) error {
11,468✔
917
        if err := mc.skipPackets(n); err != nil {
11,468✔
NEW
918
                return err
×
NEW
919
        }
×
920
        return mc.skipEof()
11,468✔
921
}
922

923
// Reads Packets until EOF-Packet or an Error appears.
924
func (mc *mysqlConn) skipRows() error {
18,370✔
925
        for {
37,524✔
926
                data, err := mc.readPacket()
19,154✔
927
                if err != nil {
19,154✔
928
                        return err
×
929
                }
×
930

931
                switch data[0] {
19,154✔
932
                case iERR:
×
933
                        return mc.handleErrorPacket(data)
×
934
                case iEOF:
18,370✔
935
                        // text row packets may starts with LengthEncodedString.
18,370✔
936
                        // In such case, 0xFE can mean string larger than 0xffffff.
18,370✔
937
                        if len(data) <= 0xffffff {
36,740✔
938
                                if mc.capabilities&clientDeprecateEOF == 0 {
18,370✔
NEW
939
                                        // EOF packet
×
NEW
940
                                        mc.status = readStatus(data[3:])
×
941
                                } else {
18,370✔
942
                                        // OK packet with an 0xFE header
18,370✔
943
                                        _, _, n := readLengthEncodedInteger(data[1:])   // affected_rows
18,370✔
944
                                        _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id
18,370✔
945
                                        mc.status = readStatus(data[1+n+m:])
18,370✔
946
                                }
18,370✔
947
                                return nil
18,370✔
948
                        }
949
                }
950
        }
951
}
952

953
/******************************************************************************
954
*                           Prepared Statements                               *
955
******************************************************************************/
956

957
// Prepare Result Packets
958
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
959
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
8,464✔
960
        data, err := stmt.mc.readPacket()
8,464✔
961
        if err == nil {
16,928✔
962
                // packet indicator [1 byte]
8,464✔
963
                if data[0] != iOK {
8,464✔
964
                        return 0, stmt.mc.handleErrorPacket(data)
×
965
                }
×
966

967
                // statement id [4 bytes]
968
                stmt.id = binary.LittleEndian.Uint32(data[1:5])
8,464✔
969

8,464✔
970
                // Column count [16 bit uint]
8,464✔
971
                columnCount := binary.LittleEndian.Uint16(data[5:7])
8,464✔
972

8,464✔
973
                // Param count [16 bit uint]
8,464✔
974
                stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
8,464✔
975

8,464✔
976
                // Reserved [8 bit]
8,464✔
977

8,464✔
978
                // Warning count [16 bit uint]
8,464✔
979

8,464✔
980
                return columnCount, nil
8,464✔
981
        }
982
        return 0, err
×
983
}
984

985
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
986
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
64✔
987
        maxLen := stmt.mc.maxAllowedPacket - 1
64✔
988
        pktLen := maxLen
64✔
989

64✔
990
        // After the header (bytes 0-3) follows before the data:
64✔
991
        // 1 byte command
64✔
992
        // 4 bytes stmtID
64✔
993
        // 2 bytes paramID
64✔
994
        const dataOffset = 1 + 4 + 2
64✔
995

64✔
996
        // Cannot use the write buffer since
64✔
997
        // a) the buffer is too small
64✔
998
        // b) it is in use
64✔
999
        data := make([]byte, 4+1+4+2+len(arg))
64✔
1000

64✔
1001
        copy(data[4+dataOffset:], arg)
64✔
1002

64✔
1003
        for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
128✔
1004
                if dataOffset+argLen < maxLen {
128✔
1005
                        pktLen = dataOffset + argLen
64✔
1006
                }
64✔
1007

1008
                stmt.mc.resetSequence()
64✔
1009
                // Add command byte [1 byte]
64✔
1010
                data[4] = comStmtSendLongData
64✔
1011

64✔
1012
                // Add stmtID [32 bit]
64✔
1013
                binary.LittleEndian.PutUint32(data[5:], stmt.id)
64✔
1014

64✔
1015
                // Add paramID [16 bit]
64✔
1016
                binary.LittleEndian.PutUint16(data[9:], uint16(paramID))
64✔
1017

64✔
1018
                // Send CMD packet
64✔
1019
                err := stmt.mc.writePacket(data[:4+pktLen])
64✔
1020
                if err == nil {
128✔
1021
                        data = data[pktLen-dataOffset:]
64✔
1022
                        continue
64✔
1023
                }
1024
                return err
×
1025
        }
1026

1027
        // Reset Packet Sequence
1028
        stmt.mc.resetSequence()
64✔
1029
        return nil
64✔
1030
}
1031

1032
// Execute Prepared Statement
1033
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
1034
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
8,880✔
1035
        if len(args) != stmt.paramCount {
8,880✔
1036
                return fmt.Errorf(
×
1037
                        "argument count mismatch (got: %d; has: %d)",
×
1038
                        len(args),
×
1039
                        stmt.paramCount,
×
1040
                )
×
1041
        }
×
1042

1043
        const minPktLen = 4 + 1 + 4 + 1 + 4
8,880✔
1044
        mc := stmt.mc
8,880✔
1045

8,880✔
1046
        // Determine threshold dynamically to avoid packet size shortage.
8,880✔
1047
        longDataSize := max(mc.maxAllowedPacket/(stmt.paramCount+1), 64)
8,880✔
1048

8,880✔
1049
        // Reset packet-sequence
8,880✔
1050
        mc.resetSequence()
8,880✔
1051

8,880✔
1052
        var data []byte
8,880✔
1053
        var err error
8,880✔
1054

8,880✔
1055
        if len(args) == 0 {
9,520✔
1056
                data, err = mc.buf.takeBuffer(minPktLen)
640✔
1057
        } else {
8,880✔
1058
                data, err = mc.buf.takeCompleteBuffer()
8,240✔
1059
                // In this case the len(data) == cap(data) which is used to optimise the flow below.
8,240✔
1060
        }
8,240✔
1061
        if err != nil {
8,880✔
1062
                return err
×
1063
        }
×
1064

1065
        // command [1 byte]
1066
        data[4] = comStmtExecute
8,880✔
1067

8,880✔
1068
        // statement_id [4 bytes]
8,880✔
1069
        binary.LittleEndian.PutUint32(data[5:], stmt.id)
8,880✔
1070

8,880✔
1071
        // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
8,880✔
1072
        data[9] = 0x00
8,880✔
1073

8,880✔
1074
        // iteration_count (uint32(1)) [4 bytes]
8,880✔
1075
        binary.LittleEndian.PutUint32(data[10:], 1)
8,880✔
1076

8,880✔
1077
        if len(args) > 0 {
17,120✔
1078
                pos := minPktLen
8,240✔
1079

8,240✔
1080
                var nullMask []byte
8,240✔
1081
                if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
8,432✔
1082
                        // buffer has to be extended but we don't know by how much so
192✔
1083
                        // we depend on append after all data with known sizes fit.
192✔
1084
                        // We stop at that because we deal with a lot of columns here
192✔
1085
                        // which makes the required allocation size hard to guess.
192✔
1086
                        tmp := make([]byte, pos+maskLen+typesLen)
192✔
1087
                        copy(tmp[:pos], data[:pos])
192✔
1088
                        data = tmp
192✔
1089
                        nullMask = data[pos : pos+maskLen]
192✔
1090
                        // No need to clean nullMask as make ensures that.
192✔
1091
                        pos += maskLen
192✔
1092
                } else {
8,240✔
1093
                        nullMask = data[pos : pos+maskLen]
8,048✔
1094
                        for i := range nullMask {
16,096✔
1095
                                nullMask[i] = 0
8,048✔
1096
                        }
8,048✔
1097
                        pos += maskLen
8,048✔
1098
                }
1099

1100
                // newParameterBoundFlag 1 [1 byte]
1101
                data[pos] = 0x01
8,240✔
1102
                pos++
8,240✔
1103

8,240✔
1104
                // type of each parameter [len(args)*2 bytes]
8,240✔
1105
                paramTypes := data[pos:]
8,240✔
1106
                pos += len(args) * 2
8,240✔
1107

8,240✔
1108
                // value of each parameter [n bytes]
8,240✔
1109
                paramValues := data[pos:pos]
8,240✔
1110
                valuesCap := cap(paramValues)
8,240✔
1111

8,240✔
1112
                for i, arg := range args {
12,599,680✔
1113
                        // build NULL-bitmap
12,591,440✔
1114
                        if arg == nil {
18,882,896✔
1115
                                nullMask[i/8] |= 1 << (uint(i) & 7)
6,291,456✔
1116
                                paramTypes[i+i] = byte(fieldTypeNULL)
6,291,456✔
1117
                                paramTypes[i+i+1] = 0x00
6,291,456✔
1118
                                continue
6,291,456✔
1119
                        }
1120

1121
                        if v, ok := arg.(json.RawMessage); ok {
6,300,048✔
1122
                                arg = []byte(v)
64✔
1123
                        }
64✔
1124
                        // cache types and values
1125
                        switch v := arg.(type) {
6,299,984✔
1126
                        case int64:
1,280✔
1127
                                paramTypes[i+i] = byte(fieldTypeLongLong)
1,280✔
1128
                                paramTypes[i+i+1] = 0x00
1,280✔
1129
                                paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
1,280✔
1130

1131
                        case uint64:
128✔
1132
                                paramTypes[i+i] = byte(fieldTypeLongLong)
128✔
1133
                                paramTypes[i+i+1] = 0x80 // type is unsigned
128✔
1134
                                paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
128✔
1135

1136
                        case float64:
64✔
1137
                                paramTypes[i+i] = byte(fieldTypeDouble)
64✔
1138
                                paramTypes[i+i+1] = 0x00
64✔
1139
                                paramValues = binary.LittleEndian.AppendUint64(paramValues, math.Float64bits(v))
64✔
1140

1141
                        case bool:
96✔
1142
                                paramTypes[i+i] = byte(fieldTypeTiny)
96✔
1143
                                paramTypes[i+i+1] = 0x00
96✔
1144

96✔
1145
                                if v {
128✔
1146
                                        paramValues = append(paramValues, 0x01)
32✔
1147
                                } else {
96✔
1148
                                        paramValues = append(paramValues, 0x00)
64✔
1149
                                }
64✔
1150

1151
                        case []byte:
224✔
1152
                                // Common case (non-nil value) first
224✔
1153
                                if v != nil {
384✔
1154
                                        paramTypes[i+i] = byte(fieldTypeString)
160✔
1155
                                        paramTypes[i+i+1] = 0x00
160✔
1156

160✔
1157
                                        if len(v) < longDataSize {
320✔
1158
                                                paramValues = appendLengthEncodedInteger(paramValues,
160✔
1159
                                                        uint64(len(v)),
160✔
1160
                                                )
160✔
1161
                                                paramValues = append(paramValues, v...)
160✔
1162
                                        } else {
160✔
1163
                                                if err := stmt.writeCommandLongData(i, v); err != nil {
×
1164
                                                        return err
×
1165
                                                }
×
1166
                                        }
1167
                                        continue
160✔
1168
                                }
1169

1170
                                // Handle []byte(nil) as a NULL value
1171
                                nullMask[i/8] |= 1 << (uint(i) & 7)
64✔
1172
                                paramTypes[i+i] = byte(fieldTypeNULL)
64✔
1173
                                paramTypes[i+i+1] = 0x00
64✔
1174

1175
                        case string:
6,295,912✔
1176
                                paramTypes[i+i] = byte(fieldTypeString)
6,295,912✔
1177
                                paramTypes[i+i+1] = 0x00
6,295,912✔
1178

6,295,912✔
1179
                                if len(v) < longDataSize {
12,591,760✔
1180
                                        paramValues = appendLengthEncodedInteger(paramValues,
6,295,848✔
1181
                                                uint64(len(v)),
6,295,848✔
1182
                                        )
6,295,848✔
1183
                                        paramValues = append(paramValues, v...)
6,295,848✔
1184
                                } else {
6,295,912✔
1185
                                        if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
64✔
1186
                                                return err
×
1187
                                        }
×
1188
                                }
1189

1190
                        case time.Time:
2,280✔
1191
                                paramTypes[i+i] = byte(fieldTypeString)
2,280✔
1192
                                paramTypes[i+i+1] = 0x00
2,280✔
1193

2,280✔
1194
                                var a [64]byte
2,280✔
1195
                                var b = a[:0]
2,280✔
1196

2,280✔
1197
                                if v.IsZero() {
3,152✔
1198
                                        b = append(b, "0000-00-00"...)
872✔
1199
                                } else {
2,280✔
1200
                                        b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
1,408✔
1201
                                        if err != nil {
1,408✔
1202
                                                return err
×
1203
                                        }
×
1204
                                }
1205

1206
                                paramValues = appendLengthEncodedInteger(paramValues,
2,280✔
1207
                                        uint64(len(b)),
2,280✔
1208
                                )
2,280✔
1209
                                paramValues = append(paramValues, b...)
2,280✔
1210

1211
                        default:
×
1212
                                return fmt.Errorf("cannot convert type: %T", arg)
×
1213
                        }
1214
                }
1215

1216
                // Check if param values exceeded the available buffer
1217
                // In that case we must build the data packet with the new values buffer
1218
                if valuesCap != cap(paramValues) {
8,336✔
1219
                        data = append(data[:pos], paramValues...)
96✔
1220
                        mc.buf.store(data) // allow this buffer to be reused
96✔
1221
                }
96✔
1222

1223
                pos += len(paramValues)
8,240✔
1224
                data = data[:pos]
8,240✔
1225
        }
1226

1227
        err = mc.writePacket(data)
8,880✔
1228
        mc.syncSequence()
8,880✔
1229
        return err
8,880✔
1230
}
1231

1232
// For each remaining resultset in the stream, discards its rows and updates
1233
// mc.affectedRows and mc.insertIds.
1234
func (mc *okHandler) discardResults() error {
46,352✔
1235
        for mc.status&statusMoreResultsExists != 0 {
46,608✔
1236
                resLen, _, err := mc.readResultSetHeaderPacket()
256✔
1237
                if err != nil {
320✔
1238
                        return err
64✔
1239
                }
64✔
1240
                if resLen > 0 {
192✔
1241
                        // columns
×
NEW
1242
                        if err := mc.conn().skipColumns(resLen); err != nil {
×
1243
                                return err
×
1244
                        }
×
1245
                        // rows
NEW
1246
                        if err := mc.conn().skipRows(); err != nil {
×
1247
                                return err
×
1248
                        }
×
1249
                }
1250
        }
1251
        return nil
46,288✔
1252
}
1253

1254
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
1255
func (rows *binaryRows) readRow(dest []driver.Value) error {
7,664✔
1256
        data, err := rows.mc.readPacket()
7,664✔
1257
        if err != nil {
7,664✔
1258
                return err
×
1259
        }
×
1260

1261
        // packet indicator [1 byte]
1262
        if data[0] != iOK {
8,336✔
1263
                // EOF/OK Packet
672✔
1264
                if data[0] == iEOF {
1,344✔
1265
                        if rows.mc.capabilities&clientDeprecateEOF == 0 {
672✔
NEW
1266
                                // EOF packet
×
NEW
1267
                                rows.mc.status = readStatus(data[3:])
×
1268
                        } else {
672✔
1269
                                // OK Packet with an 0xFE header
672✔
1270
                                _, _, n := readLengthEncodedInteger(data[1:])
672✔
1271
                                _, _, m := readLengthEncodedInteger(data[1+n:])
672✔
1272
                                rows.mc.status = readStatus(data[1+n+m:])
672✔
1273
                        }
672✔
1274
                        rows.rs.done = true
672✔
1275
                        if !rows.HasNextResultSet() {
1,056✔
1276
                                rows.mc = nil
384✔
1277
                        }
384✔
1278
                        return io.EOF
672✔
1279
                }
1280
                mc := rows.mc
×
1281
                rows.mc = nil
×
1282

×
1283
                // Error otherwise
×
1284
                return mc.handleErrorPacket(data)
×
1285
        }
1286

1287
        // NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
1288
        pos := 1 + (len(dest)+7+2)>>3
6,992✔
1289
        nullMask := data[1:pos]
6,992✔
1290

6,992✔
1291
        for i := range dest {
15,264✔
1292
                // Field is NULL
8,272✔
1293
                // (byte >> bit-pos) % 2 == 1
8,272✔
1294
                if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
8,464✔
1295
                        dest[i] = nil
192✔
1296
                        continue
192✔
1297
                }
1298

1299
                // Convert to byte-coded string
1300
                switch rows.rs.columns[i].fieldType {
8,080✔
1301
                case fieldTypeNULL:
×
1302
                        dest[i] = nil
×
1303
                        continue
×
1304

1305
                // Numeric Types
1306
                case fieldTypeTiny:
64✔
1307
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
64✔
1308
                                dest[i] = int64(data[pos])
×
1309
                        } else {
64✔
1310
                                dest[i] = int64(int8(data[pos]))
64✔
1311
                        }
64✔
1312
                        pos++
64✔
1313
                        continue
64✔
1314

1315
                case fieldTypeShort, fieldTypeYear:
32✔
1316
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
32✔
1317
                                dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
×
1318
                        } else {
32✔
1319
                                dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
32✔
1320
                        }
32✔
1321
                        pos += 2
32✔
1322
                        continue
32✔
1323

1324
                case fieldTypeInt24, fieldTypeLong:
1,014✔
1325
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
1,046✔
1326
                                dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
32✔
1327
                        } else {
1,014✔
1328
                                dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
982✔
1329
                        }
982✔
1330
                        pos += 4
1,014✔
1331
                        continue
1,014✔
1332

1333
                case fieldTypeLongLong:
1,002✔
1334
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
1,130✔
1335
                                val := binary.LittleEndian.Uint64(data[pos : pos+8])
128✔
1336
                                if val > math.MaxInt64 {
192✔
1337
                                        dest[i] = uint64ToString(val)
64✔
1338
                                } else {
128✔
1339
                                        dest[i] = int64(val)
64✔
1340
                                }
64✔
1341
                        } else {
874✔
1342
                                dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
874✔
1343
                        }
874✔
1344
                        pos += 8
1,002✔
1345
                        continue
1,002✔
1346

1347
                case fieldTypeFloat:
64✔
1348
                        dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
64✔
1349
                        pos += 4
64✔
1350
                        continue
64✔
1351

1352
                case fieldTypeDouble:
64✔
1353
                        dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
64✔
1354
                        pos += 8
64✔
1355
                        continue
64✔
1356

1357
                // Length coded Binary Strings
1358
                case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
1359
                        fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
1360
                        fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
1361
                        fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
1362
                        fieldTypeVector:
256✔
1363
                        var isNull bool
256✔
1364
                        var n int
256✔
1365
                        dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
256✔
1366
                        pos += n
256✔
1367
                        if err == nil {
512✔
1368
                                if !isNull {
512✔
1369
                                        continue
256✔
1370
                                } else {
×
1371
                                        dest[i] = nil
×
1372
                                        continue
×
1373
                                }
1374
                        }
1375
                        return err
×
1376

1377
                case
1378
                        fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
1379
                        fieldTypeTime,                         // Time [-][H]HH:MM:SS[.fractal]
1380
                        fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
5,584✔
1381

5,584✔
1382
                        num, isNull, n := readLengthEncodedInteger(data[pos:])
5,584✔
1383
                        pos += n
5,584✔
1384

5,584✔
1385
                        switch {
5,584✔
1386
                        case isNull:
×
1387
                                dest[i] = nil
×
1388
                                continue
×
1389
                        case rows.rs.columns[i].fieldType == fieldTypeTime:
3,584✔
1390
                                // database/sql does not support an equivalent to TIME, return a string
3,584✔
1391
                                var dstlen uint8
3,584✔
1392
                                switch decimals := rows.rs.columns[i].decimals; decimals {
3,584✔
1393
                                case 0x00, 0x1f:
1,792✔
1394
                                        dstlen = 8
1,792✔
1395
                                case 1, 2, 3, 4, 5, 6:
1,792✔
1396
                                        dstlen = 8 + 1 + decimals
1,792✔
1397
                                default:
×
1398
                                        return fmt.Errorf(
×
1399
                                                "protocol error, illegal decimals value %d",
×
1400
                                                rows.rs.columns[i].decimals,
×
1401
                                        )
×
1402
                                }
1403
                                dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
3,584✔
1404
                        case rows.mc.parseTime:
1,000✔
1405
                                dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
1,000✔
1406
                        default:
1,000✔
1407
                                var dstlen uint8
1,000✔
1408
                                if rows.rs.columns[i].fieldType == fieldTypeDate {
1,200✔
1409
                                        dstlen = 10
200✔
1410
                                } else {
1,000✔
1411
                                        switch decimals := rows.rs.columns[i].decimals; decimals {
800✔
1412
                                        case 0x00, 0x1f:
400✔
1413
                                                dstlen = 19
400✔
1414
                                        case 1, 2, 3, 4, 5, 6:
400✔
1415
                                                dstlen = 19 + 1 + decimals
400✔
1416
                                        default:
×
1417
                                                return fmt.Errorf(
×
1418
                                                        "protocol error, illegal decimals value %d",
×
1419
                                                        rows.rs.columns[i].decimals,
×
1420
                                                )
×
1421
                                        }
1422
                                }
1423
                                dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
1,000✔
1424
                        }
1425

1426
                        if err == nil {
11,168✔
1427
                                pos += int(num)
5,584✔
1428
                                continue
5,584✔
1429
                        } else {
×
1430
                                return err
×
1431
                        }
×
1432

1433
                // Please report if this happens!
1434
                default:
×
1435
                        return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
×
1436
                }
1437
        }
1438

1439
        return nil
6,992✔
1440
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc