• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

go-sql-driver / mysql / 15624673805

13 Jun 2025 01:56AM UTC coverage: 82.906% (-0.004%) from 82.91%
15624673805

push

github

web-flow
[1.9] fix PING on compressed connections (#1723)

Add missing mc.syncSequence()

Fix #1718

15 of 18 new or added lines in 2 files covered. (83.33%)

7 existing lines in 3 files now uncovered.

3264 of 3937 relevant lines covered (82.91%)

2489344.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.39
/packets.go
1
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2
//
3
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4
//
5
// This Source Code Form is subject to the terms of the Mozilla Public
6
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7
// You can obtain one at http://mozilla.org/MPL/2.0/.
8

9
package mysql
10

11
import (
12
        "bytes"
13
        "crypto/tls"
14
        "database/sql/driver"
15
        "encoding/binary"
16
        "encoding/json"
17
        "fmt"
18
        "io"
19
        "math"
20
        "os"
21
        "strconv"
22
        "time"
23
)
24

25
// MySQL client/server protocol documentation:
26
// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
27
// https://mariadb.com/kb/en/clientserver-protocol/
28

29
// read n bytes from mc.buf
30
func (mc *mysqlConn) readNext(n int) ([]byte, error) {
33,928,237✔
31
        if mc.buf.len() < n {
34,031,902✔
32
                err := mc.buf.fill(n, mc.readWithTimeout)
103,665✔
33
                if err != nil {
105,920✔
34
                        return nil, err
2,255✔
35
                }
2,255✔
36
        }
37
        return mc.buf.readNext(n), nil
33,925,982✔
38
}
39

40
// Read packet to buffer 'data'
41
func (mc *mysqlConn) readPacket() ([]byte, error) {
25,392,102✔
42
        var prevData []byte
25,392,102✔
43
        invalidSequence := false
25,392,102✔
44

25,392,102✔
45
        readNext := mc.readNext
25,392,102✔
46
        if mc.compress {
33,839,088✔
47
                readNext = mc.compIO.readNext
8,446,986✔
48
        }
8,446,986✔
49

50
        for {
50,784,620✔
51
                // read packet header
25,392,518✔
52
                data, err := readNext(4)
25,392,518✔
53
                if err != nil {
25,394,773✔
54
                        mc.close()
2,255✔
55
                        if cerr := mc.canceled.Value(); cerr != nil {
4,446✔
56
                                return nil, cerr
2,191✔
57
                        }
2,191✔
58
                        mc.log(err)
64✔
59
                        return nil, ErrInvalidConn
64✔
60
                }
61

62
                // packet length [24 bit]
63
                pktLen := getUint24(data[:3])
25,390,263✔
64
                seq := data[3]
25,390,263✔
65

25,390,263✔
66
                // check packet sync [8 bit]
25,390,263✔
67
                if seq != mc.sequence {
25,390,327✔
68
                        mc.log(fmt.Sprintf("[warn] unexpected sequence nr: expected %v, got %v", mc.sequence, seq))
64✔
69
                        // MySQL and MariaDB doesn't check packet nr in compressed packet.
64✔
70
                        if !mc.compress {
128✔
71
                                // For large packets, we stop reading as soon as sync error.
64✔
72
                                if len(prevData) > 0 {
64✔
73
                                        mc.close()
×
74
                                        return nil, ErrPktSyncMul
×
75
                                }
×
76
                                invalidSequence = true
64✔
77
                        }
78
                }
79
                mc.sequence = seq + 1
25,390,263✔
80

25,390,263✔
81
                // packets with length 0 terminate a previous packet which is a
25,390,263✔
82
                // multiple of (2^24)-1 bytes long
25,390,263✔
83
                if pktLen == 0 {
25,390,359✔
84
                        // there was no previous packet
96✔
85
                        if prevData == nil {
128✔
86
                                mc.log(ErrMalformPkt)
32✔
87
                                mc.close()
32✔
88
                                return nil, ErrInvalidConn
32✔
89
                        }
32✔
90
                        return prevData, nil
64✔
91
                }
92

93
                // read packet body [pktLen bytes]
94
                data, err = readNext(pktLen)
25,390,167✔
95
                if err != nil {
25,390,167✔
96
                        mc.close()
×
97
                        if cerr := mc.canceled.Value(); cerr != nil {
×
98
                                return nil, cerr
×
99
                        }
×
100
                        mc.log(err)
×
101
                        return nil, ErrInvalidConn
×
102
                }
103

104
                // return data if this was the last packet
105
                if pktLen < maxPacketSize {
50,779,918✔
106
                        // zero allocations for non-split packets
25,389,751✔
107
                        if prevData != nil {
25,389,975✔
108
                                data = append(prevData, data...)
224✔
109
                        }
224✔
110
                        if invalidSequence {
25,389,815✔
111
                                mc.close()
64✔
112
                                // return sync error only for regular packet.
64✔
113
                                // error packets may have wrong sequence number.
64✔
114
                                if data[0] != iERR {
128✔
115
                                        return nil, ErrPktSync
64✔
116
                                }
64✔
117
                        }
118
                        return data, nil
25,389,687✔
119
                }
120

121
                prevData = append(prevData, data...)
416✔
122
        }
123
}
124

125
// Write packet buffer 'data'
126
func (mc *mysqlConn) writePacket(data []byte) error {
102,107✔
127
        pktLen := len(data) - 4
102,107✔
128
        if pktLen > mc.maxAllowedPacket {
102,107✔
129
                return ErrPktTooLarge
×
130
        }
×
131

132
        writeFunc := mc.writeWithTimeout
102,107✔
133
        if mc.compress {
126,681✔
134
                writeFunc = mc.compIO.writePackets
24,574✔
135
        }
24,574✔
136

137
        for {
204,502✔
138
                size := min(maxPacketSize, pktLen)
102,395✔
139
                putUint24(data[:3], size)
102,395✔
140
                data[3] = mc.sequence
102,395✔
141

102,395✔
142
                // Write packet
102,395✔
143
                if debug {
102,395✔
NEW
144
                        fmt.Fprintf(os.Stderr, "writePacket: size=%v seq=%v\n", size, mc.sequence)
×
145
                }
×
146

147
                n, err := writeFunc(data[:4+size])
102,395✔
148
                if err != nil {
102,524✔
149
                        mc.cleanup()
129✔
150
                        if cerr := mc.canceled.Value(); cerr != nil {
182✔
151
                                return cerr
53✔
152
                        }
53✔
153
                        if n == 0 && pktLen == len(data)-4 {
120✔
154
                                // only for the first loop iteration when nothing was written yet
44✔
155
                                mc.log(err)
44✔
156
                                return errBadConnNoWrite
44✔
157
                        } else {
76✔
158
                                return err
32✔
159
                        }
32✔
160
                }
161
                if n != 4+size {
102,266✔
162
                        // io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
×
163
                        // The io.ErrShortWrite error is used to indicate that this rule has not been followed.
×
164
                        mc.cleanup()
×
165
                        return io.ErrShortWrite
×
166
                }
×
167

168
                mc.sequence++
102,266✔
169
                if size != maxPacketSize {
204,244✔
170
                        return nil
101,978✔
171
                }
101,978✔
172
                pktLen -= size
288✔
173
                data = data[size:]
288✔
174
        }
175
}
176

177
/******************************************************************************
178
*                           Initialization Process                            *
179
******************************************************************************/
180

181
// Handshake Initialization Packet
182
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
183
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
17,014✔
184
        data, err = mc.readPacket()
17,014✔
185
        if err != nil {
17,098✔
186
                return
84✔
187
        }
84✔
188

189
        if data[0] == iERR {
17,018✔
190
                return nil, "", mc.handleErrorPacket(data)
88✔
191
        }
88✔
192

193
        // protocol version [1 byte]
194
        if data[0] < minProtocolVersion {
16,842✔
195
                return nil, "", fmt.Errorf(
×
196
                        "unsupported protocol version %d. Version %d or higher is required",
×
197
                        data[0],
×
198
                        minProtocolVersion,
×
199
                )
×
200
        }
×
201

202
        // server version [null terminated string]
203
        // connection id [4 bytes]
204
        pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
16,842✔
205

16,842✔
206
        // first part of the password cipher [8 bytes]
16,842✔
207
        authData := data[pos : pos+8]
16,842✔
208

16,842✔
209
        // (filler) always 0x00 [1 byte]
16,842✔
210
        pos += 8 + 1
16,842✔
211

16,842✔
212
        // capability flags (lower 2 bytes) [2 bytes]
16,842✔
213
        mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
16,842✔
214
        if mc.flags&clientProtocol41 == 0 {
16,842✔
215
                return nil, "", ErrOldProtocol
×
216
        }
×
217
        if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
16,842✔
218
                if mc.cfg.AllowFallbackToPlaintext {
×
219
                        mc.cfg.TLS = nil
×
220
                } else {
×
221
                        return nil, "", ErrNoTLS
×
222
                }
×
223
        }
224
        pos += 2
16,842✔
225

16,842✔
226
        if len(data) > pos {
33,684✔
227
                // character set [1 byte]
16,842✔
228
                // status flags [2 bytes]
16,842✔
229
                pos += 3
16,842✔
230
                // capability flags (upper 2 bytes) [2 bytes]
16,842✔
231
                mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
16,842✔
232
                pos += 2
16,842✔
233
                // length of auth-plugin-data [1 byte]
16,842✔
234
                // reserved (all [00]) [10 bytes]
16,842✔
235
                pos += 11
16,842✔
236

16,842✔
237
                // second part of the password cipher [minimum 13 bytes],
16,842✔
238
                // where len=MAX(13, length of auth-plugin-data - 8)
16,842✔
239
                //
16,842✔
240
                // The web documentation is ambiguous about the length. However,
16,842✔
241
                // according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
16,842✔
242
                // the 13th byte is "\0 byte, terminating the second part of
16,842✔
243
                // a scramble". So the second part of the password cipher is
16,842✔
244
                // a NULL terminated string that's at least 13 bytes with the
16,842✔
245
                // last byte being NULL.
16,842✔
246
                //
16,842✔
247
                // The official Python library uses the fixed length 12
16,842✔
248
                // which seems to work but technically could have a hidden bug.
16,842✔
249
                authData = append(authData, data[pos:pos+12]...)
16,842✔
250
                pos += 13
16,842✔
251

16,842✔
252
                // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
16,842✔
253
                // \NUL otherwise
16,842✔
254
                if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
33,652✔
255
                        plugin = string(data[pos : pos+end])
16,810✔
256
                } else {
16,842✔
257
                        plugin = string(data[pos:])
32✔
258
                }
32✔
259

260
                // make a memory safe copy of the cipher slice
261
                var b [20]byte
16,842✔
262
                copy(b[:], authData)
16,842✔
263
                return b[:], plugin, nil
16,842✔
264
        }
265

266
        // make a memory safe copy of the cipher slice
267
        var b [8]byte
×
268
        copy(b[:], authData)
×
269
        return b[:], plugin, nil
×
270
}
271

272
// Client Authentication Packet
273
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
274
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
17,258✔
275
        // Adjust client flags based on server support
17,258✔
276
        clientFlags := clientProtocol41 |
17,258✔
277
                clientSecureConn |
17,258✔
278
                clientLongPassword |
17,258✔
279
                clientTransactions |
17,258✔
280
                clientLocalFiles |
17,258✔
281
                clientPluginAuth |
17,258✔
282
                clientMultiResults |
17,258✔
283
                mc.flags&clientConnectAttrs |
17,258✔
284
                mc.flags&clientLongFlag
17,258✔
285

17,258✔
286
        sendConnectAttrs := mc.flags&clientConnectAttrs != 0
17,258✔
287

17,258✔
288
        if mc.cfg.ClientFoundRows {
17,290✔
289
                clientFlags |= clientFoundRows
32✔
290
        }
32✔
291
        if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
22,069✔
292
                clientFlags |= clientCompress
4,811✔
293
        }
4,811✔
294
        // To enable TLS / SSL
295
        if mc.cfg.TLS != nil {
22,352✔
296
                clientFlags |= clientSSL
5,094✔
297
        }
5,094✔
298

299
        if mc.cfg.MultiStatements {
17,546✔
300
                clientFlags |= clientMultiStatements
288✔
301
        }
288✔
302

303
        // encode length of the auth plugin data
304
        var authRespLEIBuf [9]byte
17,258✔
305
        authRespLen := len(authResp)
17,258✔
306
        authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
17,258✔
307
        if len(authRespLEI) > 1 {
17,290✔
308
                // if the length can not be written in 1 byte, it must be written as a
32✔
309
                // length encoded integer
32✔
310
                clientFlags |= clientPluginAuthLenEncClientData
32✔
311
        }
32✔
312

313
        pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
17,258✔
314

17,258✔
315
        // To specify a db name
17,258✔
316
        if n := len(mc.cfg.DBName); n > 0 {
34,068✔
317
                clientFlags |= clientConnectWithDB
16,810✔
318
                pktLen += n + 1
16,810✔
319
        }
16,810✔
320

321
        // encode length of the connection attributes
322
        var connAttrsLEI []byte
17,258✔
323
        if sendConnectAttrs {
34,068✔
324
                var connAttrsLEIBuf [9]byte
16,810✔
325
                connAttrsLen := len(mc.connector.encodedAttributes)
16,810✔
326
                connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
16,810✔
327
                pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
16,810✔
328
        }
16,810✔
329

330
        // Calculate packet length and get buffer with that size
331
        data, err := mc.buf.takeBuffer(pktLen + 4)
17,258✔
332
        if err != nil {
17,258✔
333
                mc.cleanup()
×
334
                return err
×
335
        }
×
336

337
        // ClientFlags [32 bit]
338
        binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags))
17,258✔
339

17,258✔
340
        // MaxPacketSize [32 bit] (none)
17,258✔
341
        binary.LittleEndian.PutUint32(data[8:], 0)
17,258✔
342

17,258✔
343
        // Collation ID [1 byte]
17,258✔
344
        data[12] = defaultCollationID
17,258✔
345
        if cname := mc.cfg.Collation; cname != "" {
17,642✔
346
                colID, ok := collations[cname]
384✔
347
                if ok {
768✔
348
                        data[12] = colID
384✔
349
                } else if len(mc.cfg.charsets) > 0 {
384✔
350
                        // When cfg.charset is set, the collation is set by `SET NAMES <charset> COLLATE <collation>`.
×
351
                        return fmt.Errorf("unknown collation: %q", cname)
×
352
                }
×
353
        }
354

355
        // Filler [23 bytes] (all 0x00)
356
        pos := 13
17,258✔
357
        for ; pos < 13+23; pos++ {
414,192✔
358
                data[pos] = 0
396,934✔
359
        }
396,934✔
360

361
        // SSL Connection Request Packet
362
        // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
363
        if mc.cfg.TLS != nil {
22,352✔
364
                // Send TLS / SSL request packet
5,094✔
365
                if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
5,143✔
366
                        return err
49✔
367
                }
49✔
368

369
                // Switch to TLS
370
                tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
5,045✔
371
                if err := tlsConn.Handshake(); err != nil {
7,251✔
372
                        if cerr := mc.canceled.Value(); cerr != nil {
4,412✔
373
                                return cerr
2,206✔
374
                        }
2,206✔
375
                        return err
×
376
                }
377
                mc.netConn = tlsConn
2,839✔
378
        }
379

380
        // User [null terminated string]
381
        if len(mc.cfg.User) > 0 {
30,006✔
382
                pos += copy(data[pos:], mc.cfg.User)
15,003✔
383
        }
15,003✔
384
        data[pos] = 0x00
15,003✔
385
        pos++
15,003✔
386

15,003✔
387
        // Auth Data [length encoded integer]
15,003✔
388
        pos += copy(data[pos:], authRespLEI)
15,003✔
389
        pos += copy(data[pos:], authResp)
15,003✔
390

15,003✔
391
        // Databasename [null terminated string]
15,003✔
392
        if len(mc.cfg.DBName) > 0 {
29,558✔
393
                pos += copy(data[pos:], mc.cfg.DBName)
14,555✔
394
                data[pos] = 0x00
14,555✔
395
                pos++
14,555✔
396
        }
14,555✔
397

398
        pos += copy(data[pos:], plugin)
15,003✔
399
        data[pos] = 0x00
15,003✔
400
        pos++
15,003✔
401

15,003✔
402
        // Connection Attributes
15,003✔
403
        if sendConnectAttrs {
29,558✔
404
                pos += copy(data[pos:], connAttrsLEI)
14,555✔
405
                pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
14,555✔
406
        }
14,555✔
407

408
        // Send Auth packet
409
        return mc.writePacket(data[:pos])
15,003✔
410
}
411

412
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
413
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
867✔
414
        pktLen := 4 + len(authData)
867✔
415
        data, err := mc.buf.takeBuffer(pktLen)
867✔
416
        if err != nil {
867✔
417
                mc.cleanup()
×
418
                return err
×
419
        }
×
420

421
        // Add the auth data [EOF]
422
        copy(data[4:], authData)
867✔
423
        return mc.writePacket(data)
867✔
424
}
425

426
/******************************************************************************
427
*                             Command Packets                                 *
428
******************************************************************************/
429

430
func (mc *mysqlConn) writeCommandPacket(command byte) error {
12,416✔
431
        // Reset Packet Sequence
12,416✔
432
        mc.resetSequence()
12,416✔
433

12,416✔
434
        data, err := mc.buf.takeSmallBuffer(4 + 1)
12,416✔
435
        if err != nil {
12,416✔
UNCOV
436
                return err
×
UNCOV
437
        }
×
438

439
        // Add command byte
440
        data[4] = command
12,416✔
441

12,416✔
442
        // Send CMD packet
12,416✔
443
        err = mc.writePacket(data)
12,416✔
444
        mc.syncSequence()
12,416✔
445
        return err
12,416✔
446
}
447

448
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
50,412✔
449
        // Reset Packet Sequence
50,412✔
450
        mc.resetSequence()
50,412✔
451

50,412✔
452
        pktLen := 1 + len(arg)
50,412✔
453
        data, err := mc.buf.takeBuffer(pktLen + 4)
50,412✔
454
        if err != nil {
50,412✔
455
                return err
×
456
        }
×
457

458
        // Add command byte
459
        data[4] = command
50,412✔
460

50,412✔
461
        // Add arg
50,412✔
462
        copy(data[5:], arg)
50,412✔
463

50,412✔
464
        // Send CMD packet
50,412✔
465
        err = mc.writePacket(data)
50,412✔
466
        mc.syncSequence()
50,412✔
467
        return err
50,412✔
468
}
469

470
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
8,336✔
471
        // Reset Packet Sequence
8,336✔
472
        mc.resetSequence()
8,336✔
473

8,336✔
474
        data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
8,336✔
475
        if err != nil {
8,336✔
476
                return err
×
477
        }
×
478

479
        // Add command byte
480
        data[4] = command
8,336✔
481

8,336✔
482
        // Add arg [32 bit]
8,336✔
483
        binary.LittleEndian.PutUint32(data[5:], arg)
8,336✔
484

8,336✔
485
        // Send CMD packet
8,336✔
486
        err = mc.writePacket(data)
8,336✔
487
        mc.syncSequence()
8,336✔
488
        return err
8,336✔
489
}
490

491
/******************************************************************************
492
*                              Result Packets                                 *
493
******************************************************************************/
494

495
func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
16,216✔
496
        data, err := mc.readPacket()
16,216✔
497
        if err != nil {
18,070✔
498
                return nil, "", err
1,854✔
499
        }
1,854✔
500

501
        // packet indicator
502
        switch data[0] {
14,362✔
503

504
        case iOK:
8,884✔
505
                // resultUnchanged, since auth happens before any queries or
8,884✔
506
                // commands have been executed.
8,884✔
507
                return nil, "", mc.resultUnchanged().handleOkPacket(data)
8,884✔
508

509
        case iAuthMoreData:
4,732✔
510
                return data[1:], "", err
4,732✔
511

512
        case iEOF:
672✔
513
                if len(data) == 1 {
768✔
514
                        // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
96✔
515
                        return nil, "mysql_old_password", nil
96✔
516
                }
96✔
517
                pluginEndIndex := bytes.IndexByte(data, 0x00)
576✔
518
                if pluginEndIndex < 0 {
576✔
519
                        return nil, "", ErrMalformPkt
×
520
                }
×
521
                plugin := string(data[1:pluginEndIndex])
576✔
522
                authData := data[pluginEndIndex+1:]
576✔
523
                if len(authData) > 0 && authData[len(authData)-1] == 0 {
1,056✔
524
                        authData = authData[:len(authData)-1]
480✔
525
                }
480✔
526
                return authData, plugin, nil
576✔
527

528
        default: // Error otherwise
74✔
529
                return nil, "", mc.handleErrorPacket(data)
74✔
530
        }
531
}
532

533
// Returns error if Packet is not a 'Result OK'-Packet
534
func (mc *okHandler) readResultOK() error {
7,241✔
535
        data, err := mc.conn().readPacket()
7,241✔
536
        if err != nil {
7,354✔
537
                return err
113✔
538
        }
113✔
539

540
        if data[0] == iOK {
14,232✔
541
                return mc.handleOkPacket(data)
7,104✔
542
        }
7,104✔
543
        return mc.conn().handleErrorPacket(data)
24✔
544
}
545

546
// Result Set Header Packet
547
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
548
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
51,756✔
549
        // handleOkPacket replaces both values; other cases leave the values unchanged.
51,756✔
550
        mc.result.affectedRows = append(mc.result.affectedRows, 0)
51,756✔
551
        mc.result.insertIds = append(mc.result.insertIds, 0)
51,756✔
552

51,756✔
553
        data, err := mc.conn().readPacket()
51,756✔
554
        if err != nil {
51,896✔
555
                return 0, err
140✔
556
        }
140✔
557

558
        switch data[0] {
51,616✔
559
        case iOK:
28,556✔
560
                return 0, mc.handleOkPacket(data)
28,556✔
561

562
        case iERR:
546✔
563
                return 0, mc.conn().handleErrorPacket(data)
546✔
564

565
        case iLocalInFile:
480✔
566
                return 0, mc.handleInFileRequest(string(data[1:]))
480✔
567
        }
568

569
        // column count
570
        // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
571
        num, _, _ := readLengthEncodedInteger(data)
22,034✔
572
        // ignore remaining data in the packet. see #1478.
22,034✔
573
        return int(num), nil
22,034✔
574
}
575

576
// Error Packet
577
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
578
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
732✔
579
        if data[0] != iERR {
732✔
580
                return ErrMalformPkt
×
581
        }
×
582

583
        // 0xff [1 byte]
584

585
        // Error Number [16 bit uint]
586
        errno := binary.LittleEndian.Uint16(data[1:3])
732✔
587

732✔
588
        // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
732✔
589
        // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
732✔
590
        // 1836: ER_READ_ONLY_MODE
732✔
591
        if (errno == 1792 || errno == 1290 || errno == 1836) && mc.cfg.RejectReadOnly {
828✔
592
                // Oops; we are connected to a read-only connection, and won't be able
96✔
593
                // to issue any write statements. Since RejectReadOnly is configured,
96✔
594
                // we throw away this connection hoping this one would have write
96✔
595
                // permission. This is specifically for a possible race condition
96✔
596
                // during failover (e.g. on AWS Aurora). See README.md for more.
96✔
597
                //
96✔
598
                // We explicitly close the connection before returning
96✔
599
                // driver.ErrBadConn to ensure that `database/sql` purges this
96✔
600
                // connection and initiates a new one for next statement next time.
96✔
601
                mc.Close()
96✔
602
                return driver.ErrBadConn
96✔
603
        }
96✔
604

605
        me := &MySQLError{Number: errno}
636✔
606

636✔
607
        pos := 3
636✔
608

636✔
609
        // SQL State [optional: # + 5bytes string]
636✔
610
        if data[3] == 0x23 {
1,184✔
611
                copy(me.SQLState[:], data[4:4+5])
548✔
612
                pos = 9
548✔
613
        }
548✔
614

615
        // Error Message [string]
616
        me.Message = string(data[pos:])
636✔
617

636✔
618
        return me
636✔
619
}
620

621
func readStatus(b []byte) statusFlag {
79,314✔
622
        return statusFlag(b[0]) | statusFlag(b[1])<<8
79,314✔
623
}
79,314✔
624

625
// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't
626
// need to be cleared first (e.g. during authentication, or while additional
627
// resultsets are being fetched.)
628
func (mc *mysqlConn) resultUnchanged() *okHandler {
14,288✔
629
        return (*okHandler)(mc)
14,288✔
630
}
14,288✔
631

632
// okHandler represents the state of the connection when mysqlConn.result has
633
// been prepared for processing of OK packets.
634
//
635
// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
636
// callpaths must either:
637
//
638
// 1. first clear it using clearResult(), or
639
// 2. confirm that they don't need to (by calling resultUnchanged()).
640
//
641
// Both return an instance of type *okHandler.
642
type okHandler mysqlConn
643

644
// Exposes the underlying type's methods.
645
func (mc *okHandler) conn() *mysqlConn {
60,911✔
646
        return (*mysqlConn)(mc)
60,911✔
647
}
60,911✔
648

649
// clearResult clears the connection's stored affectedRows and insertIds
650
// fields.
651
//
652
// It returns a handler that can process OK responses.
653
func (mc *mysqlConn) clearResult() *okHandler {
92,309✔
654
        mc.result = mysqlResult{}
92,309✔
655
        return (*okHandler)(mc)
92,309✔
656
}
92,309✔
657

658
// Ok Packet
659
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
660
func (mc *okHandler) handleOkPacket(data []byte) error {
44,544✔
661
        var n, m int
44,544✔
662
        var affectedRows, insertId uint64
44,544✔
663

44,544✔
664
        // 0x00 [1 byte]
44,544✔
665

44,544✔
666
        // Affected rows [Length Coded Binary]
44,544✔
667
        affectedRows, _, n = readLengthEncodedInteger(data[1:])
44,544✔
668

44,544✔
669
        // Insert id [Length Coded Binary]
44,544✔
670
        insertId, _, m = readLengthEncodedInteger(data[1+n:])
44,544✔
671

44,544✔
672
        // Update for the current statement result (only used by
44,544✔
673
        // readResultSetHeaderPacket).
44,544✔
674
        if len(mc.result.affectedRows) > 0 {
73,388✔
675
                mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
28,844✔
676
        }
28,844✔
677
        if len(mc.result.insertIds) > 0 {
73,388✔
678
                mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
28,844✔
679
        }
28,844✔
680

681
        // server_status [2 bytes]
682
        mc.status = readStatus(data[1+n+m : 1+n+m+2])
44,544✔
683
        if mc.status&statusMoreResultsExists != 0 {
44,896✔
684
                return nil
352✔
685
        }
352✔
686

687
        // warning count [2 bytes]
688

689
        return nil
44,192✔
690
}
691

692
// Read Packets as Field Packets until EOF-Packet or an Error appears
693
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
694
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
21,938✔
695
        columns := make([]mysqlField, count)
21,938✔
696

21,938✔
697
        for i := 0; ; i++ {
12,656,182✔
698
                data, err := mc.readPacket()
12,634,244✔
699
                if err != nil {
12,634,244✔
700
                        return nil, err
×
701
                }
×
702

703
                // EOF Packet
704
                if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
12,656,182✔
705
                        if i == count {
43,876✔
706
                                return columns, nil
21,938✔
707
                        }
21,938✔
708
                        return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
×
709
                }
710

711
                // Catalog
712
                pos, err := skipLengthEncodedString(data)
12,612,306✔
713
                if err != nil {
12,612,306✔
714
                        return nil, err
×
715
                }
×
716

717
                // Database [len coded string]
718
                n, err := skipLengthEncodedString(data[pos:])
12,612,306✔
719
                if err != nil {
12,612,306✔
720
                        return nil, err
×
721
                }
×
722
                pos += n
12,612,306✔
723

12,612,306✔
724
                // Table [len coded string]
12,612,306✔
725
                if mc.cfg.ColumnsWithAlias {
12,612,370✔
726
                        tableName, _, n, err := readLengthEncodedString(data[pos:])
64✔
727
                        if err != nil {
64✔
728
                                return nil, err
×
729
                        }
×
730
                        pos += n
64✔
731
                        columns[i].tableName = string(tableName)
64✔
732
                } else {
12,612,242✔
733
                        n, err = skipLengthEncodedString(data[pos:])
12,612,242✔
734
                        if err != nil {
12,612,242✔
735
                                return nil, err
×
736
                        }
×
737
                        pos += n
12,612,242✔
738
                }
739

740
                // Original table [len coded string]
741
                n, err = skipLengthEncodedString(data[pos:])
12,612,306✔
742
                if err != nil {
12,612,306✔
743
                        return nil, err
×
744
                }
×
745
                pos += n
12,612,306✔
746

12,612,306✔
747
                // Name [len coded string]
12,612,306✔
748
                name, _, n, err := readLengthEncodedString(data[pos:])
12,612,306✔
749
                if err != nil {
12,612,306✔
750
                        return nil, err
×
751
                }
×
752
                columns[i].name = string(name)
12,612,306✔
753
                pos += n
12,612,306✔
754

12,612,306✔
755
                // Original name [len coded string]
12,612,306✔
756
                n, err = skipLengthEncodedString(data[pos:])
12,612,306✔
757
                if err != nil {
12,612,306✔
758
                        return nil, err
×
759
                }
×
760
                pos += n
12,612,306✔
761

12,612,306✔
762
                // Filler [uint8]
12,612,306✔
763
                pos++
12,612,306✔
764

12,612,306✔
765
                // Charset [charset, collation uint8]
12,612,306✔
766
                columns[i].charSet = data[pos]
12,612,306✔
767
                pos += 2
12,612,306✔
768

12,612,306✔
769
                // Length [uint32]
12,612,306✔
770
                columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
12,612,306✔
771
                pos += 4
12,612,306✔
772

12,612,306✔
773
                // Field type [uint8]
12,612,306✔
774
                columns[i].fieldType = fieldType(data[pos])
12,612,306✔
775
                pos++
12,612,306✔
776

12,612,306✔
777
                // Flags [uint16]
12,612,306✔
778
                columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
12,612,306✔
779
                pos += 2
12,612,306✔
780

12,612,306✔
781
                // Decimals [uint8]
12,612,306✔
782
                columns[i].decimals = data[pos]
12,612,306✔
783
                //pos++
784

785
                // Default value [len coded binary]
786
                //if pos < len(data) {
787
                //        defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
788
                //}
789
        }
790
}
791

792
// Read Packets as Field Packets until EOF-Packet or an Error appears
793
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
794
func (rows *textRows) readRow(dest []driver.Value) error {
16,450✔
795
        mc := rows.mc
16,450✔
796

16,450✔
797
        if rows.rs.done {
16,450✔
798
                return io.EOF
×
799
        }
×
800

801
        data, err := mc.readPacket()
16,450✔
802
        if err != nil {
16,450✔
803
                return err
×
804
        }
×
805

806
        // EOF Packet
807
        if data[0] == iEOF && len(data) == 5 {
17,506✔
808
                // server_status [2 bytes]
1,056✔
809
                rows.mc.status = readStatus(data[3:])
1,056✔
810
                rows.rs.done = true
1,056✔
811
                if !rows.HasNextResultSet() {
1,984✔
812
                        rows.mc = nil
928✔
813
                }
928✔
814
                return io.EOF
1,056✔
815
        }
816
        if data[0] == iERR {
15,394✔
817
                rows.mc = nil
×
818
                return mc.handleErrorPacket(data)
×
819
        }
×
820

821
        // RowSet Packet
822
        var (
15,394✔
823
                n      int
15,394✔
824
                isNull bool
15,394✔
825
                pos    int = 0
15,394✔
826
        )
15,394✔
827

15,394✔
828
        for i := range dest {
47,236✔
829
                // Read bytes and convert to string
31,842✔
830
                var buf []byte
31,842✔
831
                buf, isNull, n, err = readLengthEncodedString(data[pos:])
31,842✔
832
                pos += n
31,842✔
833

31,842✔
834
                if err != nil {
31,842✔
835
                        return err
×
836
                }
×
837

838
                if isNull {
33,812✔
839
                        dest[i] = nil
1,970✔
840
                        continue
1,970✔
841
                }
842

843
                switch rows.rs.columns[i].fieldType {
29,872✔
844
                case fieldTypeTimestamp,
845
                        fieldTypeDateTime,
846
                        fieldTypeDate,
847
                        fieldTypeNewDate:
3,940✔
848
                        if mc.parseTime {
6,630✔
849
                                dest[i], err = parseDateTime(buf, mc.cfg.Loc)
2,690✔
850
                        } else {
3,940✔
851
                                dest[i] = buf
1,250✔
852
                        }
1,250✔
853

854
                case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
6,936✔
855
                        dest[i], err = strconv.ParseInt(string(buf), 10, 64)
6,936✔
856

857
                case fieldTypeLongLong:
1,812✔
858
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
2,420✔
859
                                dest[i], err = strconv.ParseUint(string(buf), 10, 64)
608✔
860
                        } else {
1,812✔
861
                                dest[i], err = strconv.ParseInt(string(buf), 10, 64)
1,204✔
862
                        }
1,204✔
863

864
                case fieldTypeFloat:
736✔
865
                        var d float64
736✔
866
                        d, err = strconv.ParseFloat(string(buf), 32)
736✔
867
                        dest[i] = float32(d)
736✔
868

869
                case fieldTypeDouble:
544✔
870
                        dest[i], err = strconv.ParseFloat(string(buf), 64)
544✔
871

872
                default:
15,904✔
873
                        dest[i] = buf
15,904✔
874
                }
875
                if err != nil {
29,872✔
876
                        return err
×
877
                }
×
878
        }
879

880
        return nil
15,394✔
881
}
882

883
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
884
func (mc *mysqlConn) readUntilEOF() error {
33,042✔
885
        for {
12,665,252✔
886
                data, err := mc.readPacket()
12,632,210✔
887
                if err != nil {
12,632,210✔
888
                        return err
×
889
                }
×
890

891
                switch data[0] {
12,632,210✔
892
                case iERR:
×
893
                        return mc.handleErrorPacket(data)
×
894
                case iEOF:
33,042✔
895
                        if len(data) == 5 {
66,084✔
896
                                mc.status = readStatus(data[3:])
33,042✔
897
                        }
33,042✔
898
                        return nil
33,042✔
899
                }
900
        }
901
}
902

903
/******************************************************************************
904
*                           Prepared Statements                               *
905
******************************************************************************/
906

907
// Prepare Result Packets
908
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
909
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
8,464✔
910
        data, err := stmt.mc.readPacket()
8,464✔
911
        if err == nil {
16,928✔
912
                // packet indicator [1 byte]
8,464✔
913
                if data[0] != iOK {
8,464✔
914
                        return 0, stmt.mc.handleErrorPacket(data)
×
915
                }
×
916

917
                // statement id [4 bytes]
918
                stmt.id = binary.LittleEndian.Uint32(data[1:5])
8,464✔
919

8,464✔
920
                // Column count [16 bit uint]
8,464✔
921
                columnCount := binary.LittleEndian.Uint16(data[5:7])
8,464✔
922

8,464✔
923
                // Param count [16 bit uint]
8,464✔
924
                stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
8,464✔
925

8,464✔
926
                // Reserved [8 bit]
8,464✔
927

8,464✔
928
                // Warning count [16 bit uint]
8,464✔
929

8,464✔
930
                return columnCount, nil
8,464✔
931
        }
932
        return 0, err
×
933
}
934

935
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
936
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
64✔
937
        maxLen := stmt.mc.maxAllowedPacket - 1
64✔
938
        pktLen := maxLen
64✔
939

64✔
940
        // After the header (bytes 0-3) follows before the data:
64✔
941
        // 1 byte command
64✔
942
        // 4 bytes stmtID
64✔
943
        // 2 bytes paramID
64✔
944
        const dataOffset = 1 + 4 + 2
64✔
945

64✔
946
        // Cannot use the write buffer since
64✔
947
        // a) the buffer is too small
64✔
948
        // b) it is in use
64✔
949
        data := make([]byte, 4+1+4+2+len(arg))
64✔
950

64✔
951
        copy(data[4+dataOffset:], arg)
64✔
952

64✔
953
        for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
128✔
954
                if dataOffset+argLen < maxLen {
128✔
955
                        pktLen = dataOffset + argLen
64✔
956
                }
64✔
957

958
                // Add command byte [1 byte]
959
                data[4] = comStmtSendLongData
64✔
960

64✔
961
                // Add stmtID [32 bit]
64✔
962
                binary.LittleEndian.PutUint32(data[5:], stmt.id)
64✔
963

64✔
964
                // Add paramID [16 bit]
64✔
965
                binary.LittleEndian.PutUint16(data[9:], uint16(paramID))
64✔
966

64✔
967
                // Send CMD packet
64✔
968
                err := stmt.mc.writePacket(data[:4+pktLen])
64✔
969
                // Every COM_LONG_DATA packet reset Packet Sequence
64✔
970
                stmt.mc.resetSequence()
64✔
971
                if err == nil {
128✔
972
                        data = data[pktLen-dataOffset:]
64✔
973
                        continue
64✔
974
                }
975
                return err
×
976
        }
977

978
        return nil
64✔
979
}
980

981
// writeExecutePacket sends a COM_STMT_EXECUTE packet that runs stmt with args.
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
//
// Payload layout built below (bytes 0-3 are not written here; presumably they
// receive the packet header inside writePacket):
//
//	command          [1 byte]   data[4]     = comStmtExecute
//	statement_id     [4 bytes]  data[5:9]
//	flags            [1 byte]   data[9]     = 0 (CURSOR_TYPE_NO_CURSOR)
//	iteration_count  [4 bytes]  data[10:14] = 1
//	-- only when len(args) > 0 --
//	NULL bitmap      [(len(args)+7)/8 bytes]
//	new_params_bound [1 byte]   = 1
//	parameter types  [2 bytes per arg]
//	parameter values [variable; NULL and long-data args contribute nothing]
//
// Accepted argument types are nil, int64, uint64, float64, bool, []byte (a
// nil []byte is sent as NULL), string, time.Time and json.RawMessage (sent as
// []byte). Any other type yields a "cannot convert type" error.
//
// []byte and string values of at least longDataSize bytes are streamed first
// with writeCommandLongData. Only their type is recorded in this packet.
//
// It returns an error when len(args) != stmt.paramCount, when no buffer can
// be obtained, when a long-data write or the final write fails, or when a
// value cannot be encoded. mc.syncSequence() runs after the final writePacket
// whether or not that write succeeded.
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
	if len(args) != stmt.paramCount {
		return fmt.Errorf(
			"argument count mismatch (got: %d; has: %d)",
			len(args),
			stmt.paramCount,
		)
	}

	// Fixed part: header [4] + command [1] + stmt id [4] + flags [1] +
	// iteration count [4].
	const minPktLen = 4 + 1 + 4 + 1 + 4
	mc := stmt.mc

	// Determine threshold dynamically to avoid packet size shortage.
	// maxAllowedPacket is split evenly across the parameters plus one share,
	// with a floor of 64 bytes. Longer []byte/string values are sent as long
	// data instead of being inlined.
	longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
	if longDataSize < 64 {
		longDataSize = 64
	}

	// Reset packet-sequence
	mc.resetSequence()

	var data []byte
	var err error

	if len(args) == 0 {
		// Only the fixed part is needed.
		data, err = mc.buf.takeBuffer(minPktLen)
	} else {
		data, err = mc.buf.takeCompleteBuffer()
		// In this case the len(data) == cap(data) which is used to optimise the flow below.
	}
	if err != nil {
		return err
	}

	// command [1 byte]
	data[4] = comStmtExecute

	// statement_id [4 bytes]
	binary.LittleEndian.PutUint32(data[5:], stmt.id)

	// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
	data[9] = 0x00

	// iteration_count (uint32(1)) [4 bytes]
	binary.LittleEndian.PutUint32(data[10:], 1)

	if len(args) > 0 {
		pos := minPktLen

		var nullMask []byte
		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
			// buffer has to be extended but we don't know by how much so
			// we depend on append after all data with known sizes fit.
			// We stop at that because we deal with a lot of columns here
			// which makes the required allocation size hard to guess.
			tmp := make([]byte, pos+maskLen+typesLen)
			copy(tmp[:pos], data[:pos])
			data = tmp
			nullMask = data[pos : pos+maskLen]
			// No need to clean nullMask as make ensures that.
			pos += maskLen
		} else {
			// The reused buffer may hold stale bytes, so clear the bitmap.
			nullMask = data[pos : pos+maskLen]
			for i := range nullMask {
				nullMask[i] = 0
			}
			pos += maskLen
		}

		// newParameterBoundFlag 1 [1 byte]
		data[pos] = 0x01
		pos++

		// type of each parameter [len(args)*2 bytes]
		// paramTypes aliases data; only its first 2*len(args) bytes are written.
		paramTypes := data[pos:]
		pos += len(args) * 2

		// value of each parameter [n bytes]
		// paramValues is a zero-length view over the rest of data's capacity.
		// Appends write straight into data's backing array until that capacity
		// is exhausted, after which append reallocates. The cap comparison
		// below detects the reallocation. In the tmp branch above the capacity
		// is zero, so any appended value triggers it.
		paramValues := data[pos:pos]
		valuesCap := cap(paramValues)

		for i, arg := range args {
			// build NULL-bitmap
			if arg == nil {
				nullMask[i/8] |= 1 << (uint(i) & 7)
				paramTypes[i+i] = byte(fieldTypeNULL)
				paramTypes[i+i+1] = 0x00
				continue
			}

			// json.RawMessage is sent exactly like a []byte.
			if v, ok := arg.(json.RawMessage); ok {
				arg = []byte(v)
			}
			// cache types and values
			switch v := arg.(type) {
			case int64:
				paramTypes[i+i] = byte(fieldTypeLongLong)
				paramTypes[i+i+1] = 0x00
				paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))

			case uint64:
				paramTypes[i+i] = byte(fieldTypeLongLong)
				paramTypes[i+i+1] = 0x80 // type is unsigned
				paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))

			case float64:
				paramTypes[i+i] = byte(fieldTypeDouble)
				paramTypes[i+i+1] = 0x00
				paramValues = binary.LittleEndian.AppendUint64(paramValues, math.Float64bits(v))

			case bool:
				paramTypes[i+i] = byte(fieldTypeTiny)
				paramTypes[i+i+1] = 0x00

				if v {
					paramValues = append(paramValues, 0x01)
				} else {
					paramValues = append(paramValues, 0x00)
				}

			case []byte:
				// Common case (non-nil value) first
				if v != nil {
					paramTypes[i+i] = byte(fieldTypeString)
					paramTypes[i+i+1] = 0x00

					if len(v) < longDataSize {
						paramValues = appendLengthEncodedInteger(paramValues,
							uint64(len(v)),
						)
						paramValues = append(paramValues, v...)
					} else {
						// Too large to inline: stream it now; nothing is
						// appended to paramValues for this argument.
						if err := stmt.writeCommandLongData(i, v); err != nil {
							return err
						}
					}
					continue
				}

				// Handle []byte(nil) as a NULL value
				nullMask[i/8] |= 1 << (uint(i) & 7)
				paramTypes[i+i] = byte(fieldTypeNULL)
				paramTypes[i+i+1] = 0x00

			case string:
				paramTypes[i+i] = byte(fieldTypeString)
				paramTypes[i+i+1] = 0x00

				if len(v) < longDataSize {
					paramValues = appendLengthEncodedInteger(paramValues,
						uint64(len(v)),
					)
					paramValues = append(paramValues, v...)
				} else {
					// Too large to inline: stream it now as long data.
					if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
						return err
					}
				}

			case time.Time:
				// Sent as a length-encoded string; the zero time is sent as
				// the literal "0000-00-00".
				paramTypes[i+i] = byte(fieldTypeString)
				paramTypes[i+i+1] = 0x00

				var a [64]byte
				var b = a[:0]

				if v.IsZero() {
					b = append(b, "0000-00-00"...)
				} else {
					b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
					if err != nil {
						return err
					}
				}

				paramValues = appendLengthEncodedInteger(paramValues,
					uint64(len(b)),
				)
				paramValues = append(paramValues, b...)

			default:
				return fmt.Errorf("cannot convert type: %T", arg)
			}
		}

		// Check if param values exceeded the available buffer
		// In that case we must build the data packet with the new values buffer
		if valuesCap != cap(paramValues) {
			data = append(data[:pos], paramValues...)
			mc.buf.store(data) // allow this buffer to be reused
		}

		pos += len(paramValues)
		data = data[:pos]
	}

	err = mc.writePacket(data)
	// Called even when the write failed. NOTE(review): presumably this keeps
	// the (compressed-protocol) sequence state in step with writePacket;
	// confirm against syncSequence.
	mc.syncSequence()
	return err
}
1183

1184
// For each remaining resultset in the stream, discards its rows and updates
1185
// mc.affectedRows and mc.insertIds.
1186
func (mc *okHandler) discardResults() error {
46,366✔
1187
        for mc.status&statusMoreResultsExists != 0 {
46,622✔
1188
                resLen, err := mc.readResultSetHeaderPacket()
256✔
1189
                if err != nil {
320✔
1190
                        return err
64✔
1191
                }
64✔
1192
                if resLen > 0 {
192✔
1193
                        // columns
×
1194
                        if err := mc.conn().readUntilEOF(); err != nil {
×
1195
                                return err
×
1196
                        }
×
1197
                        // rows
1198
                        if err := mc.conn().readUntilEOF(); err != nil {
×
1199
                                return err
×
1200
                        }
×
1201
                }
1202
        }
1203
        return nil
46,302✔
1204
}
1205

1206
// readRow reads one binary-protocol (prepared statement) result row into dest.
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
//
// It returns:
//   - io.EOF when the EOF packet is read (the server status is stored and the
//     result set is marked done);
//   - the decoded server error for any other non-OK packet;
//   - nil after every column of dest has been decoded.
//
// The row starts with a 0x00 header and a NULL bitmap whose first two bits are
// unused, followed by the non-NULL values in column order. Value conversion:
//   - integer types become int64; an unsigned BIGINT above math.MaxInt64 is
//     returned via uint64ToString;
//   - FLOAT becomes float32 and DOUBLE becomes float64;
//   - string-like types are returned as read by readLengthEncodedString;
//   - TIME is formatted with formatBinaryTime;
//   - other date/time types are parsed with parseBinaryDateTime when
//     parseTime is set, and formatted with formatBinaryDateTime otherwise.
func (rows *binaryRows) readRow(dest []driver.Value) error {
	data, err := rows.mc.readPacket()
	if err != nil {
		return err
	}

	// packet indicator [1 byte]
	if data[0] != iOK {
		// EOF Packet
		if data[0] == iEOF && len(data) == 5 {
			// Server status flags start at offset 3.
			rows.mc.status = readStatus(data[3:])
			rows.rs.done = true
			if !rows.HasNextResultSet() {
				rows.mc = nil
			}
			return io.EOF
		}
		// Detach the connection from rows before decoding the error.
		mc := rows.mc
		rows.mc = nil

		// Error otherwise
		return mc.handleErrorPacket(data)
	}

	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
	// The bitmap is offset by 2 bits, hence the +2 here and in the lookup below.
	pos := 1 + (len(dest)+7+2)>>3
	nullMask := data[1:pos]

	for i := range dest {
		// Field is NULL
		// (byte >> bit-pos) % 2 == 1
		if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
			dest[i] = nil
			continue
		}

		// Decode the value according to the column's field type; pos advances
		// past exactly the bytes each value occupies.
		switch rows.rs.columns[i].fieldType {
		case fieldTypeNULL:
			dest[i] = nil
			continue

		// Numeric Types
		case fieldTypeTiny:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(data[pos])
			} else {
				dest[i] = int64(int8(data[pos]))
			}
			pos++
			continue

		case fieldTypeShort, fieldTypeYear:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
			} else {
				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
			}
			pos += 2
			continue

		case fieldTypeInt24, fieldTypeLong:
			// INT24 is also transferred in 4 bytes here.
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
			} else {
				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
			}
			pos += 4
			continue

		case fieldTypeLongLong:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				val := binary.LittleEndian.Uint64(data[pos : pos+8])
				// Values that do not fit in int64 are returned via
				// uint64ToString instead of wrapping around.
				if val > math.MaxInt64 {
					dest[i] = uint64ToString(val)
				} else {
					dest[i] = int64(val)
				}
			} else {
				dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
			}
			pos += 8
			continue

		case fieldTypeFloat:
			dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
			pos += 4
			continue

		case fieldTypeDouble:
			dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
			pos += 8
			continue

		// Length coded Binary Strings
		case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
			fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
			fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
			fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
			fieldTypeVector:
			var isNull bool
			var n int
			dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
			pos += n
			if err == nil {
				if !isNull {
					continue
				} else {
					// Store an untyped nil rather than a typed nil slice.
					dest[i] = nil
					continue
				}
			}
			return err

		case
			fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
			fieldTypeTime,                         // Time [-][H]HH:MM:SS[.fractal]
			fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]

			// num is the byte length of the encoded date/time value that follows.
			num, isNull, n := readLengthEncodedInteger(data[pos:])
			pos += n

			switch {
			case isNull:
				dest[i] = nil
				continue
			case rows.rs.columns[i].fieldType == fieldTypeTime:
				// database/sql does not support an equivalent to TIME, return a string
				// dstlen is the formatted width: HH:MM:SS (8), plus '.' and
				// the fractional digits; decimals 0x1f is treated like 0.
				var dstlen uint8
				switch decimals := rows.rs.columns[i].decimals; decimals {
				case 0x00, 0x1f:
					dstlen = 8
				case 1, 2, 3, 4, 5, 6:
					dstlen = 8 + 1 + decimals
				default:
					return fmt.Errorf(
						"protocol error, illegal decimals value %d",
						rows.rs.columns[i].decimals,
					)
				}
				dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
			case rows.mc.parseTime:
				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
			default:
				// dstlen: YYYY-MM-DD (10) for DATE, otherwise
				// YYYY-MM-DD HH:MM:SS (19) plus '.' and the fractional digits.
				var dstlen uint8
				if rows.rs.columns[i].fieldType == fieldTypeDate {
					dstlen = 10
				} else {
					switch decimals := rows.rs.columns[i].decimals; decimals {
					case 0x00, 0x1f:
						dstlen = 19
					case 1, 2, 3, 4, 5, 6:
						dstlen = 19 + 1 + decimals
					default:
						return fmt.Errorf(
							"protocol error, illegal decimals value %d",
							rows.rs.columns[i].decimals,
						)
					}
				}
				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
			}

			if err == nil {
				pos += int(num)
				continue
			} else {
				return err
			}

		// Please report if this happens!
		default:
			return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
		}
	}

	return nil
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc