• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

go-sql-driver / mysql / 14665995848

25 Apr 2025 01:39PM UTC coverage: 82.772% (-0.2%) from 82.961%
14665995848

Pull #1696

github

rusher
use named constants for caching_sha2_password
Pull Request #1696: Authentication

321 of 387 new or added lines in 9 files covered. (82.95%)

7 existing lines in 3 files now uncovered.

3315 of 4005 relevant lines covered (82.77%)

2440698.09 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.16
/packets.go
1
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2
//
3
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4
//
5
// This Source Code Form is subject to the terms of the Mozilla Public
6
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7
// You can obtain one at http://mozilla.org/MPL/2.0/.
8

9
package mysql
10

11
import (
12
        "bytes"
13
        "crypto/tls"
14
        "database/sql/driver"
15
        "encoding/binary"
16
        "encoding/json"
17
        "fmt"
18
        "io"
19
        "math"
20
        "strconv"
21
        "time"
22
)
23

24
// MySQL client/server protocol documentations.
25
// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
26
// https://mariadb.com/kb/en/clientserver-protocol/
27

28
// read n bytes from mc.buf
29
func (mc *mysqlConn) readNext(n int) ([]byte, error) {
33,928,048✔
30
        if mc.buf.len() < n {
34,031,485✔
31
                err := mc.buf.fill(n, mc.readWithTimeout)
103,437✔
32
                if err != nil {
105,557✔
33
                        return nil, err
2,120✔
34
                }
2,120✔
35
        }
36
        return mc.buf.readNext(n), nil
33,925,928✔
37
}
38

39
// Read packet to buffer 'data'
40
func (mc *mysqlConn) readPacket() ([]byte, error) {
25,392,282✔
41
        var prevData []byte
25,392,282✔
42
        invalidSequence := false
25,392,282✔
43

25,392,282✔
44
        readNext := mc.readNext
25,392,282✔
45
        if mc.compress {
33,839,256✔
46
                readNext = mc.compIO.readNext
8,446,974✔
47
        }
8,446,974✔
48

49
        for {
50,784,980✔
50
                // read packet header
25,392,698✔
51
                data, err := readNext(4)
25,392,698✔
52
                if err != nil {
25,394,818✔
53
                        mc.close()
2,120✔
54
                        if cerr := mc.canceled.Value(); cerr != nil {
4,176✔
55
                                return nil, cerr
2,056✔
56
                        }
2,056✔
57
                        mc.log(err)
64✔
58
                        return nil, ErrInvalidConn
64✔
59
                }
60

61
                // packet length [24 bit]
62
                pktLen := getUint24(data[:3])
25,390,578✔
63
                seq := data[3]
25,390,578✔
64

25,390,578✔
65
                if mc.compress {
33,837,648✔
66
                        // MySQL and MariaDB doesn't check packet nr in compressed packet.
8,447,070✔
67
                        if debug && seq != mc.compressSequence {
8,447,070✔
68
                                fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v",
×
69
                                        mc.compressSequence, seq)
×
70
                        }
×
71
                        mc.compressSequence = seq + 1
8,447,070✔
72
                } else {
16,943,508✔
73
                        // check packet sync [8 bit]
16,943,508✔
74
                        if seq != mc.sequence {
16,943,572✔
75
                                mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq))
64✔
76
                                // For large packets, we stop reading as soon as sync error.
64✔
77
                                if len(prevData) > 0 {
64✔
78
                                        mc.close()
×
79
                                        return nil, ErrPktSyncMul
×
80
                                }
×
81
                                invalidSequence = true
64✔
82
                        }
83
                        mc.sequence++
16,943,508✔
84
                }
85

86
                // packets with length 0 terminate a previous packet which is a
87
                // multiple of (2^24)-1 bytes long
88
                if pktLen == 0 {
25,390,674✔
89
                        // there was no previous packet
96✔
90
                        if prevData == nil {
128✔
91
                                mc.log(ErrMalformPkt)
32✔
92
                                mc.close()
32✔
93
                                return nil, ErrInvalidConn
32✔
94
                        }
32✔
95
                        return prevData, nil
64✔
96
                }
97

98
                // read packet body [pktLen bytes]
99
                data, err = readNext(pktLen)
25,390,482✔
100
                if err != nil {
25,390,482✔
101
                        mc.close()
×
102
                        if cerr := mc.canceled.Value(); cerr != nil {
×
103
                                return nil, cerr
×
104
                        }
×
105
                        mc.log(err)
×
106
                        return nil, ErrInvalidConn
×
107
                }
108

109
                // return data if this was the last packet
110
                if pktLen < maxPacketSize {
50,780,548✔
111
                        // zero allocations for non-split packets
25,390,066✔
112
                        if prevData != nil {
25,390,290✔
113
                                data = append(prevData, data...)
224✔
114
                        }
224✔
115
                        if invalidSequence {
25,390,130✔
116
                                mc.close()
64✔
117
                                // return sync error only for regular packet.
64✔
118
                                // error packets may have wrong sequence number.
64✔
119
                                if data[0] != iERR {
128✔
120
                                        return nil, ErrPktSync
64✔
121
                                }
64✔
122
                        }
123
                        return data, nil
25,390,002✔
124
                }
125

126
                prevData = append(prevData, data...)
416✔
127
        }
128
}
129

130
// Write packet buffer 'data'
131
func (mc *mysqlConn) writePacket(data []byte) error {
102,322✔
132
        pktLen := len(data) - 4
102,322✔
133
        if pktLen > mc.maxAllowedPacket {
102,322✔
134
                return ErrPktTooLarge
×
135
        }
×
136

137
        writeFunc := mc.writeWithTimeout
102,322✔
138
        if mc.compress {
126,963✔
139
                writeFunc = mc.compIO.writePackets
24,641✔
140
        }
24,641✔
141

142
        for {
204,932✔
143
                size := min(maxPacketSize, pktLen)
102,610✔
144
                putUint24(data[:3], size)
102,610✔
145
                data[3] = mc.sequence
102,610✔
146

102,610✔
147
                // Write packet
102,610✔
148
                if debug {
102,610✔
149
                        fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence)
×
150
                }
×
151

152
                n, err := writeFunc(data[:4+size])
102,610✔
153
                if err != nil {
102,742✔
154
                        mc.cleanup()
132✔
155
                        if cerr := mc.canceled.Value(); cerr != nil {
188✔
156
                                return cerr
56✔
157
                        }
56✔
158
                        if n == 0 && pktLen == len(data)-4 {
120✔
159
                                // only for the first loop iteration when nothing was written yet
44✔
160
                                mc.log(err)
44✔
161
                                return errBadConnNoWrite
44✔
162
                        } else {
76✔
163
                                return err
32✔
164
                        }
32✔
165
                }
166
                if n != 4+size {
102,478✔
167
                        // io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
×
168
                        // The io.ErrShortWrite error is used to indicate that this rule has not been followed.
×
169
                        mc.cleanup()
×
170
                        return io.ErrShortWrite
×
171
                }
×
172

173
                mc.sequence++
102,478✔
174
                if size != maxPacketSize {
204,668✔
175
                        return nil
102,190✔
176
                }
102,190✔
177
                pktLen -= size
288✔
178
                data = data[size:]
288✔
179
        }
180
}
181

182
/******************************************************************************
183
*                           Initialization Process                            *
184
******************************************************************************/
185

186
// Handshake Initialization Packet
187
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
188
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
17,065✔
189
        data, err = mc.readPacket()
17,065✔
190
        if err != nil {
17,155✔
191
                return
90✔
192
        }
90✔
193

194
        if data[0] == iERR {
17,074✔
195
                return nil, "", mc.handleErrorPacket(data)
99✔
196
        }
99✔
197

198
        // protocol version [1 byte]
199
        if data[0] < minProtocolVersion {
16,876✔
200
                return nil, "", fmt.Errorf(
×
201
                        "unsupported protocol version %d. Version %d or higher is required",
×
202
                        data[0],
×
203
                        minProtocolVersion,
×
204
                )
×
205
        }
×
206

207
        // server version [null terminated string]
208
        // connection id [4 bytes]
209
        pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
16,876✔
210

16,876✔
211
        // first part of the password cipher [8 bytes]
16,876✔
212
        authData := data[pos : pos+8]
16,876✔
213

16,876✔
214
        // (filler) always 0x00 [1 byte]
16,876✔
215
        pos += 8 + 1
16,876✔
216

16,876✔
217
        // capability flags (lower 2 bytes) [2 bytes]
16,876✔
218
        mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
16,876✔
219
        if mc.flags&clientProtocol41 == 0 {
16,876✔
220
                return nil, "", ErrOldProtocol
×
221
        }
×
222
        if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
16,876✔
223
                if mc.cfg.AllowFallbackToPlaintext {
×
224
                        mc.cfg.TLS = nil
×
225
                } else {
×
226
                        return nil, "", ErrNoTLS
×
227
                }
×
228
        }
229
        pos += 2
16,876✔
230

16,876✔
231
        if len(data) > pos {
33,752✔
232
                // character set [1 byte]
16,876✔
233
                // status flags [2 bytes]
16,876✔
234
                pos += 3
16,876✔
235
                // capability flags (upper 2 bytes) [2 bytes]
16,876✔
236
                mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
16,876✔
237
                pos += 2
16,876✔
238
                // length of auth-plugin-data [1 byte]
16,876✔
239
                // reserved (all [00]) [10 bytes]
16,876✔
240
                pos += 11
16,876✔
241

16,876✔
242
                // second part of the password cipher [minimum 13 bytes],
16,876✔
243
                // where len=MAX(13, length of auth-plugin-data - 8)
16,876✔
244
                //
16,876✔
245
                // The web documentation is ambiguous about the length. However,
16,876✔
246
                // according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
16,876✔
247
                // the 13th byte is "\0 byte, terminating the second part of
16,876✔
248
                // a scramble". So the second part of the password cipher is
16,876✔
249
                // a NULL terminated string that's at least 13 bytes with the
16,876✔
250
                // last byte being NULL.
16,876✔
251
                //
16,876✔
252
                // The official Python library uses the fixed length 12
16,876✔
253
                // which seems to work but technically could have a hidden bug.
16,876✔
254
                authData = append(authData, data[pos:pos+12]...)
16,876✔
255
                pos += 13
16,876✔
256

16,876✔
257
                // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
16,876✔
258
                // \NUL otherwise
16,876✔
259
                if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
33,720✔
260
                        plugin = string(data[pos : pos+end])
16,844✔
261
                } else {
16,876✔
262
                        plugin = string(data[pos:])
32✔
263
                }
32✔
264

265
                // make a memory safe copy of the cipher slice
266
                var b [20]byte
16,876✔
267
                copy(b[:], authData)
16,876✔
268
                return b[:], plugin, nil
16,876✔
269
        }
270

271
        // make a memory safe copy of the cipher slice
272
        var b [8]byte
×
273
        copy(b[:], authData)
×
274
        return b[:], plugin, nil
×
275
}
276

277
// Client Authentication Packet
278
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
279
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
17,292✔
280
        // Adjust client flags based on server support
17,292✔
281
        clientFlags := clientProtocol41 |
17,292✔
282
                clientSecureConn |
17,292✔
283
                clientLongPassword |
17,292✔
284
                clientTransactions |
17,292✔
285
                clientLocalFiles |
17,292✔
286
                clientPluginAuth |
17,292✔
287
                clientMultiResults |
17,292✔
288
                mc.flags&clientConnectAttrs |
17,292✔
289
                mc.flags&clientLongFlag
17,292✔
290

17,292✔
291
        sendConnectAttrs := mc.flags&clientConnectAttrs != 0
17,292✔
292

17,292✔
293
        if mc.cfg.ClientFoundRows {
17,324✔
294
                clientFlags |= clientFoundRows
32✔
295
        }
32✔
296
        if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
22,157✔
297
                clientFlags |= clientCompress
4,865✔
298
        }
4,865✔
299
        // To enable TLS / SSL
300
        if mc.cfg.TLS != nil {
22,426✔
301
                clientFlags |= clientSSL
5,134✔
302
        }
5,134✔
303

304
        if mc.cfg.MultiStatements {
17,580✔
305
                clientFlags |= clientMultiStatements
288✔
306
        }
288✔
307

308
        // encode length of the auth plugin data
309
        var authRespLEIBuf [9]byte
17,292✔
310
        authRespLen := len(authResp)
17,292✔
311
        authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
17,292✔
312
        if len(authRespLEI) > 1 {
17,324✔
313
                // if the length can not be written in 1 byte, it must be written as a
32✔
314
                // length encoded integer
32✔
315
                clientFlags |= clientPluginAuthLenEncClientData
32✔
316
        }
32✔
317

318
        pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
17,292✔
319

17,292✔
320
        // To specify a db name
17,292✔
321
        if n := len(mc.cfg.DBName); n > 0 {
34,136✔
322
                clientFlags |= clientConnectWithDB
16,844✔
323
                pktLen += n + 1
16,844✔
324
        }
16,844✔
325

326
        // encode length of the connection attributes
327
        var connAttrsLEI []byte
17,292✔
328
        if sendConnectAttrs {
34,136✔
329
                var connAttrsLEIBuf [9]byte
16,844✔
330
                connAttrsLen := len(mc.connector.encodedAttributes)
16,844✔
331
                connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
16,844✔
332
                pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
16,844✔
333
        }
16,844✔
334

335
        // Calculate packet length and get buffer with that size
336
        data, err := mc.buf.takeBuffer(pktLen + 4)
17,292✔
337
        if err != nil {
17,292✔
338
                mc.cleanup()
×
339
                return err
×
340
        }
×
341

342
        // ClientFlags [32 bit]
343
        binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags))
17,292✔
344

17,292✔
345
        // MaxPacketSize [32 bit] (none)
17,292✔
346
        binary.LittleEndian.PutUint32(data[8:], 0)
17,292✔
347

17,292✔
348
        // Collation ID [1 byte]
17,292✔
349
        data[12] = defaultCollationID
17,292✔
350
        if cname := mc.cfg.Collation; cname != "" {
17,676✔
351
                colID, ok := collations[cname]
384✔
352
                if ok {
768✔
353
                        data[12] = colID
384✔
354
                } else if len(mc.cfg.charsets) > 0 {
384✔
355
                        // When cfg.charset is set, the collation is set by `SET NAMES <charset> COLLATE <collation>`.
×
356
                        return fmt.Errorf("unknown collation: %q", cname)
×
357
                }
×
358
        }
359

360
        // Filler [23 bytes] (all 0x00)
361
        pos := 13
17,292✔
362
        for ; pos < 13+23; pos++ {
415,008✔
363
                data[pos] = 0
397,716✔
364
        }
397,716✔
365

366
        // SSL Connection Request Packet
367
        // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
368
        if mc.cfg.TLS != nil {
22,426✔
369
                // Send TLS / SSL request packet
5,134✔
370
                if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
5,184✔
371
                        return err
50✔
372
                }
50✔
373

374
                // Switch to TLS
375
                tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
5,084✔
376
                if err := tlsConn.Handshake(); err != nil {
7,380✔
377
                        if cerr := mc.canceled.Value(); cerr != nil {
4,592✔
378
                                return cerr
2,296✔
379
                        }
2,296✔
380
                        return err
×
381
                }
382
                mc.netConn = tlsConn
2,788✔
383
        }
384

385
        // User [null terminated string]
386
        if len(mc.cfg.User) > 0 {
29,892✔
387
                pos += copy(data[pos:], mc.cfg.User)
14,946✔
388
        }
14,946✔
389
        data[pos] = 0x00
14,946✔
390
        pos++
14,946✔
391

14,946✔
392
        // Auth Data [length encoded integer]
14,946✔
393
        pos += copy(data[pos:], authRespLEI)
14,946✔
394
        pos += copy(data[pos:], authResp)
14,946✔
395

14,946✔
396
        // Databasename [null terminated string]
14,946✔
397
        if len(mc.cfg.DBName) > 0 {
29,444✔
398
                pos += copy(data[pos:], mc.cfg.DBName)
14,498✔
399
                data[pos] = 0x00
14,498✔
400
                pos++
14,498✔
401
        }
14,498✔
402

403
        pos += copy(data[pos:], plugin)
14,946✔
404
        data[pos] = 0x00
14,946✔
405
        pos++
14,946✔
406

14,946✔
407
        // Connection Attributes
14,946✔
408
        if sendConnectAttrs {
29,444✔
409
                pos += copy(data[pos:], connAttrsLEI)
14,498✔
410
                pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
14,498✔
411
        }
14,498✔
412

413
        // Send Auth packet
414
        return mc.writePacket(data[:pos])
14,946✔
415
}
416

417
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
418
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
1,027✔
419
        pktLen := 4 + len(authData)
1,027✔
420
        data, err := mc.buf.takeBuffer(pktLen)
1,027✔
421
        if err != nil {
1,027✔
422
                mc.cleanup()
×
423
                return err
×
424
        }
×
425

426
        // Add the auth data [EOF]
427
        copy(data[4:], authData)
1,027✔
428
        return mc.writePacket(data)
1,027✔
429
}
430

431
/******************************************************************************
432
*                             Command Packets                                 *
433
******************************************************************************/
434

435
func (mc *mysqlConn) writeCommandPacket(command byte) error {
12,513✔
436
        // Reset Packet Sequence
12,513✔
437
        mc.resetSequence()
12,513✔
438

12,513✔
439
        data, err := mc.buf.takeSmallBuffer(4 + 1)
12,513✔
440
        if err != nil {
12,514✔
441
                return err
1✔
442
        }
1✔
443

444
        // Add command byte
445
        data[4] = command
12,512✔
446

12,512✔
447
        // Send CMD packet
12,512✔
448
        return mc.writePacket(data)
12,512✔
449
}
450

451
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
50,388✔
452
        // Reset Packet Sequence
50,388✔
453
        mc.resetSequence()
50,388✔
454

50,388✔
455
        pktLen := 1 + len(arg)
50,388✔
456
        data, err := mc.buf.takeBuffer(pktLen + 4)
50,388✔
457
        if err != nil {
50,388✔
458
                return err
×
459
        }
×
460

461
        // Add command byte
462
        data[4] = command
50,388✔
463

50,388✔
464
        // Add arg
50,388✔
465
        copy(data[5:], arg)
50,388✔
466

50,388✔
467
        // Send CMD packet
50,388✔
468
        err = mc.writePacket(data)
50,388✔
469
        mc.syncSequence()
50,388✔
470
        return err
50,388✔
471
}
472

473
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
8,336✔
474
        // Reset Packet Sequence
8,336✔
475
        mc.resetSequence()
8,336✔
476

8,336✔
477
        data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
8,336✔
478
        if err != nil {
8,336✔
479
                return err
×
480
        }
×
481

482
        // Add command byte
483
        data[4] = command
8,336✔
484

8,336✔
485
        // Add arg [32 bit]
8,336✔
486
        binary.LittleEndian.PutUint32(data[5:], arg)
8,336✔
487

8,336✔
488
        // Send CMD packet
8,336✔
489
        return mc.writePacket(data)
8,336✔
490
}
491

492
/******************************************************************************
493
*                              Result Packets                                 *
494
******************************************************************************/
495

496
// Returns error if Packet is not a 'Result OK'-Packet
497
func (mc *okHandler) readResultOK() error {
2,522✔
498
        data, err := mc.conn().readPacket()
2,522✔
499
        if err != nil {
2,584✔
500
                return err
62✔
501
        }
62✔
502

503
        if data[0] == iOK {
4,920✔
504
                return mc.handleOkPacket(data)
2,460✔
505
        }
2,460✔
UNCOV
506
        return mc.conn().handleErrorPacket(data)
×
507
}
508

509
// Result Set Header Packet
510
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
511
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
51,732✔
512
        // handleOkPacket replaces both values; other cases leave the values unchanged.
51,732✔
513
        mc.result.affectedRows = append(mc.result.affectedRows, 0)
51,732✔
514
        mc.result.insertIds = append(mc.result.insertIds, 0)
51,732✔
515

51,732✔
516
        data, err := mc.conn().readPacket()
51,732✔
517
        if err != nil {
51,872✔
518
                return 0, err
140✔
519
        }
140✔
520

521
        switch data[0] {
51,592✔
522
        case iOK:
28,534✔
523
                return 0, mc.handleOkPacket(data)
28,534✔
524

525
        case iERR:
546✔
526
                return 0, mc.conn().handleErrorPacket(data)
546✔
527

528
        case iLocalInFile:
480✔
529
                return 0, mc.handleInFileRequest(string(data[1:]))
480✔
530
        }
531

532
        // column count
533
        // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
534
        num, _, _ := readLengthEncodedInteger(data)
22,032✔
535
        // ignore remaining data in the packet. see #1478.
22,032✔
536
        return int(num), nil
22,032✔
537
}
538

539
// Error Packet
540
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
541
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
743✔
542
        if data[0] != iERR {
743✔
543
                return ErrMalformPkt
×
544
        }
×
545

546
        // 0xff [1 byte]
547

548
        // Error Number [16 bit uint]
549
        errno := binary.LittleEndian.Uint16(data[1:3])
743✔
550

743✔
551
        // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
743✔
552
        // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
743✔
553
        // 1836: ER_READ_ONLY_MODE
743✔
554
        if (errno == 1792 || errno == 1290 || errno == 1836) && mc.cfg.RejectReadOnly {
839✔
555
                // Oops; we are connected to a read-only connection, and won't be able
96✔
556
                // to issue any write statements. Since RejectReadOnly is configured,
96✔
557
                // we throw away this connection hoping this one would have write
96✔
558
                // permission. This is specifically for a possible race condition
96✔
559
                // during failover (e.g. on AWS Aurora). See README.md for more.
96✔
560
                //
96✔
561
                // We explicitly close the connection before returning
96✔
562
                // driver.ErrBadConn to ensure that `database/sql` purges this
96✔
563
                // connection and initiates a new one for next statement next time.
96✔
564
                mc.Close()
96✔
565
                return driver.ErrBadConn
96✔
566
        }
96✔
567

568
        me := &MySQLError{Number: errno}
647✔
569

647✔
570
        pos := 3
647✔
571

647✔
572
        // SQL State [optional: # + 5bytes string]
647✔
573
        if data[3] == 0x23 {
1,195✔
574
                copy(me.SQLState[:], data[4:4+5])
548✔
575
                pos = 9
548✔
576
        }
548✔
577

578
        // Error Message [string]
579
        me.Message = string(data[pos:])
647✔
580

647✔
581
        return me
647✔
582
}
583

584
func readStatus(b []byte) statusFlag {
79,497✔
585
        return statusFlag(b[0]) | statusFlag(b[1])<<8
79,497✔
586
}
79,497✔
587

588
// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't
589
// need to be cleared first (e.g. during authentication, or while additional
590
// resultsets are being fetched.)
591
func (mc *mysqlConn) resultUnchanged() *okHandler {
14,362✔
592
        return (*okHandler)(mc)
14,362✔
593
}
14,362✔
594

595
// okHandler represents the state of the connection when mysqlConn.result has
596
// been prepared for processing of OK packets.
597
//
598
// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
599
// callpaths must either:
600
//
601
// 1. first clear it using clearResult(), or
602
// 2. confirm that they don't need to (by calling resultUnchanged()).
603
//
604
// Both return an instance of type *okHandler.
605
type okHandler mysqlConn
606

607
// Exposes the underlying type's methods.
608
func (mc *okHandler) conn() *mysqlConn {
56,144✔
609
        return (*mysqlConn)(mc)
56,144✔
610
}
56,144✔
611

612
// clearResult clears the connection's stored affectedRows and insertIds
613
// fields.
614
//
615
// It returns a handler that can process OK responses.
616
func (mc *mysqlConn) clearResult() *okHandler {
92,212✔
617
        mc.result = mysqlResult{}
92,212✔
618
        return (*okHandler)(mc)
92,212✔
619
}
92,212✔
620

621
// Ok Packet
622
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
623
func (mc *okHandler) handleOkPacket(data []byte) error {
44,684✔
624
        var n, m int
44,684✔
625
        var affectedRows, insertId uint64
44,684✔
626

44,684✔
627
        // 0x00 [1 byte]
44,684✔
628

44,684✔
629
        // Affected rows [Length Coded Binary]
44,684✔
630
        affectedRows, _, n = readLengthEncodedInteger(data[1:])
44,684✔
631

44,684✔
632
        // Insert id [Length Coded Binary]
44,684✔
633
        insertId, _, m = readLengthEncodedInteger(data[1+n:])
44,684✔
634

44,684✔
635
        // Update for the current statement result (only used by
44,684✔
636
        // readResultSetHeaderPacket).
44,684✔
637
        if len(mc.result.affectedRows) > 0 {
73,506✔
638
                mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
28,822✔
639
        }
28,822✔
640
        if len(mc.result.insertIds) > 0 {
73,506✔
641
                mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
28,822✔
642
        }
28,822✔
643

644
        // server_status [2 bytes]
645
        mc.status = readStatus(data[1+n+m : 1+n+m+2])
44,684✔
646
        if mc.status&statusMoreResultsExists != 0 {
45,036✔
647
                return nil
352✔
648
        }
352✔
649

650
        // warning count [2 bytes]
651

652
        return nil
44,332✔
653
}
654

655
// Read Packets as Field Packets until EOF-Packet or an Error appears
656
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
657
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
21,891✔
658
        columns := make([]mysqlField, count)
21,891✔
659

21,891✔
660
        for i := 0; ; i++ {
12,655,996✔
661
                data, err := mc.readPacket()
12,634,105✔
662
                if err != nil {
12,634,105✔
663
                        return nil, err
×
664
                }
×
665

666
                // EOF Packet
667
                if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
12,655,996✔
668
                        if i == count {
43,782✔
669
                                return columns, nil
21,891✔
670
                        }
21,891✔
671
                        return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
×
672
                }
673

674
                // Catalog
675
                pos, err := skipLengthEncodedString(data)
12,612,214✔
676
                if err != nil {
12,612,214✔
677
                        return nil, err
×
678
                }
×
679

680
                // Database [len coded string]
681
                n, err := skipLengthEncodedString(data[pos:])
12,612,214✔
682
                if err != nil {
12,612,214✔
683
                        return nil, err
×
684
                }
×
685
                pos += n
12,612,214✔
686

12,612,214✔
687
                // Table [len coded string]
12,612,214✔
688
                if mc.cfg.ColumnsWithAlias {
12,612,278✔
689
                        tableName, _, n, err := readLengthEncodedString(data[pos:])
64✔
690
                        if err != nil {
64✔
691
                                return nil, err
×
692
                        }
×
693
                        pos += n
64✔
694
                        columns[i].tableName = string(tableName)
64✔
695
                } else {
12,612,150✔
696
                        n, err = skipLengthEncodedString(data[pos:])
12,612,150✔
697
                        if err != nil {
12,612,150✔
698
                                return nil, err
×
699
                        }
×
700
                        pos += n
12,612,150✔
701
                }
702

703
                // Original table [len coded string]
704
                n, err = skipLengthEncodedString(data[pos:])
12,612,214✔
705
                if err != nil {
12,612,214✔
706
                        return nil, err
×
707
                }
×
708
                pos += n
12,612,214✔
709

12,612,214✔
710
                // Name [len coded string]
12,612,214✔
711
                name, _, n, err := readLengthEncodedString(data[pos:])
12,612,214✔
712
                if err != nil {
12,612,214✔
713
                        return nil, err
×
714
                }
×
715
                columns[i].name = string(name)
12,612,214✔
716
                pos += n
12,612,214✔
717

12,612,214✔
718
                // Original name [len coded string]
12,612,214✔
719
                n, err = skipLengthEncodedString(data[pos:])
12,612,214✔
720
                if err != nil {
12,612,214✔
721
                        return nil, err
×
722
                }
×
723
                pos += n
12,612,214✔
724

12,612,214✔
725
                // Filler [uint8]
12,612,214✔
726
                pos++
12,612,214✔
727

12,612,214✔
728
                // Charset [charset, collation uint8]
12,612,214✔
729
                columns[i].charSet = data[pos]
12,612,214✔
730
                pos += 2
12,612,214✔
731

12,612,214✔
732
                // Length [uint32]
12,612,214✔
733
                columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
12,612,214✔
734
                pos += 4
12,612,214✔
735

12,612,214✔
736
                // Field type [uint8]
12,612,214✔
737
                columns[i].fieldType = fieldType(data[pos])
12,612,214✔
738
                pos++
12,612,214✔
739

12,612,214✔
740
                // Flags [uint16]
12,612,214✔
741
                columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
12,612,214✔
742
                pos += 2
12,612,214✔
743

12,612,214✔
744
                // Decimals [uint8]
12,612,214✔
745
                columns[i].decimals = data[pos]
12,612,214✔
746
                //pos++
747

748
                // Default value [len coded binary]
749
                //if pos < len(data) {
750
                //        defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
751
                //}
752
        }
753
}
754

755
// Read Packets as Field Packets until EOF-Packet or an Error appears
756
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
757
func (rows *textRows) readRow(dest []driver.Value) error {
16,358✔
758
        mc := rows.mc
16,358✔
759

16,358✔
760
        if rows.rs.done {
16,358✔
761
                return io.EOF
×
762
        }
×
763

764
        data, err := mc.readPacket()
16,358✔
765
        if err != nil {
16,358✔
766
                return err
×
767
        }
×
768

769
        // EOF Packet
770
        if data[0] == iEOF && len(data) == 5 {
17,369✔
771
                // server_status [2 bytes]
1,011✔
772
                rows.mc.status = readStatus(data[3:])
1,011✔
773
                rows.rs.done = true
1,011✔
774
                if !rows.HasNextResultSet() {
1,894✔
775
                        rows.mc = nil
883✔
776
                }
883✔
777
                return io.EOF
1,011✔
778
        }
779
        if data[0] == iERR {
15,347✔
780
                rows.mc = nil
×
781
                return mc.handleErrorPacket(data)
×
782
        }
×
783

784
        // RowSet Packet
785
        var (
15,347✔
786
                n      int
15,347✔
787
                isNull bool
15,347✔
788
                pos    int = 0
15,347✔
789
        )
15,347✔
790

15,347✔
791
        for i := range dest {
47,097✔
792
                // Read bytes and convert to string
31,750✔
793
                var buf []byte
31,750✔
794
                buf, isNull, n, err = readLengthEncodedString(data[pos:])
31,750✔
795
                pos += n
31,750✔
796

31,750✔
797
                if err != nil {
31,750✔
798
                        return err
×
799
                }
×
800

801
                if isNull {
33,720✔
802
                        dest[i] = nil
1,970✔
803
                        continue
1,970✔
804
                }
805

806
                switch rows.rs.columns[i].fieldType {
29,780✔
807
                case fieldTypeTimestamp,
808
                        fieldTypeDateTime,
809
                        fieldTypeDate,
810
                        fieldTypeNewDate:
3,940✔
811
                        if mc.parseTime {
6,630✔
812
                                dest[i], err = parseDateTime(buf, mc.cfg.Loc)
2,690✔
813
                        } else {
3,940✔
814
                                dest[i] = buf
1,250✔
815
                        }
1,250✔
816

817
                case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
6,936✔
818
                        dest[i], err = strconv.ParseInt(string(buf), 10, 64)
6,936✔
819

820
                case fieldTypeLongLong:
1,810✔
821
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
2,418✔
822
                                dest[i], err = strconv.ParseUint(string(buf), 10, 64)
608✔
823
                        } else {
1,810✔
824
                                dest[i], err = strconv.ParseInt(string(buf), 10, 64)
1,202✔
825
                        }
1,202✔
826

827
                case fieldTypeFloat:
736✔
828
                        var d float64
736✔
829
                        d, err = strconv.ParseFloat(string(buf), 32)
736✔
830
                        dest[i] = float32(d)
736✔
831

832
                case fieldTypeDouble:
544✔
833
                        dest[i], err = strconv.ParseFloat(string(buf), 64)
544✔
834

835
                default:
15,814✔
836
                        dest[i] = buf
15,814✔
837
                }
838
                if err != nil {
29,780✔
839
                        return err
×
840
                }
×
841
        }
842

843
        return nil
15,347✔
844
}
845

846
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
847
func (mc *mysqlConn) readUntilEOF() error {
33,130✔
848
        for {
12,665,563✔
849
                data, err := mc.readPacket()
12,632,433✔
850
                if err != nil {
12,632,433✔
851
                        return err
×
852
                }
×
853

854
                switch data[0] {
12,632,433✔
855
                case iERR:
×
856
                        return mc.handleErrorPacket(data)
×
857
                case iEOF:
33,130✔
858
                        if len(data) == 5 {
66,260✔
859
                                mc.status = readStatus(data[3:])
33,130✔
860
                        }
33,130✔
861
                        return nil
33,130✔
862
                }
863
        }
864
}
865

866
/******************************************************************************
867
*                           Prepared Statements                               *
868
******************************************************************************/
869

870
// Prepare Result Packets
871
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
872
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
8,464✔
873
        data, err := stmt.mc.readPacket()
8,464✔
874
        if err == nil {
16,928✔
875
                // packet indicator [1 byte]
8,464✔
876
                if data[0] != iOK {
8,464✔
877
                        return 0, stmt.mc.handleErrorPacket(data)
×
878
                }
×
879

880
                // statement id [4 bytes]
881
                stmt.id = binary.LittleEndian.Uint32(data[1:5])
8,464✔
882

8,464✔
883
                // Column count [16 bit uint]
8,464✔
884
                columnCount := binary.LittleEndian.Uint16(data[5:7])
8,464✔
885

8,464✔
886
                // Param count [16 bit uint]
8,464✔
887
                stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
8,464✔
888

8,464✔
889
                // Reserved [8 bit]
8,464✔
890

8,464✔
891
                // Warning count [16 bit uint]
8,464✔
892

8,464✔
893
                return columnCount, nil
8,464✔
894
        }
895
        return 0, err
×
896
}
897

898
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
899
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
64✔
900
        maxLen := stmt.mc.maxAllowedPacket - 1
64✔
901
        pktLen := maxLen
64✔
902

64✔
903
        // After the header (bytes 0-3) follows before the data:
64✔
904
        // 1 byte command
64✔
905
        // 4 bytes stmtID
64✔
906
        // 2 bytes paramID
64✔
907
        const dataOffset = 1 + 4 + 2
64✔
908

64✔
909
        // Cannot use the write buffer since
64✔
910
        // a) the buffer is too small
64✔
911
        // b) it is in use
64✔
912
        data := make([]byte, 4+1+4+2+len(arg))
64✔
913

64✔
914
        copy(data[4+dataOffset:], arg)
64✔
915

64✔
916
        for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
128✔
917
                if dataOffset+argLen < maxLen {
128✔
918
                        pktLen = dataOffset + argLen
64✔
919
                }
64✔
920

921
                stmt.mc.resetSequence()
64✔
922
                // Add command byte [1 byte]
64✔
923
                data[4] = comStmtSendLongData
64✔
924

64✔
925
                // Add stmtID [32 bit]
64✔
926
                binary.LittleEndian.PutUint32(data[5:], stmt.id)
64✔
927

64✔
928
                // Add paramID [16 bit]
64✔
929
                binary.LittleEndian.PutUint16(data[9:], uint16(paramID))
64✔
930

64✔
931
                // Send CMD packet
64✔
932
                err := stmt.mc.writePacket(data[:4+pktLen])
64✔
933
                if err == nil {
128✔
934
                        data = data[pktLen-dataOffset:]
64✔
935
                        continue
64✔
936
                }
937
                return err
×
938
        }
939

940
        // Reset Packet Sequence
941
        stmt.mc.resetSequence()
64✔
942
        return nil
64✔
943
}
944

945
// Execute Prepared Statement
946
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
947
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
8,880✔
948
        if len(args) != stmt.paramCount {
8,880✔
949
                return fmt.Errorf(
×
950
                        "argument count mismatch (got: %d; has: %d)",
×
951
                        len(args),
×
952
                        stmt.paramCount,
×
953
                )
×
954
        }
×
955

956
        const minPktLen = 4 + 1 + 4 + 1 + 4
8,880✔
957
        mc := stmt.mc
8,880✔
958

8,880✔
959
        // Determine threshold dynamically to avoid packet size shortage.
8,880✔
960
        longDataSize := max(mc.maxAllowedPacket/(stmt.paramCount+1), 64)
8,880✔
961

8,880✔
962
        // Reset packet-sequence
8,880✔
963
        mc.resetSequence()
8,880✔
964

8,880✔
965
        var data []byte
8,880✔
966
        var err error
8,880✔
967

8,880✔
968
        if len(args) == 0 {
9,520✔
969
                data, err = mc.buf.takeBuffer(minPktLen)
640✔
970
        } else {
8,880✔
971
                data, err = mc.buf.takeCompleteBuffer()
8,240✔
972
                // In this case the len(data) == cap(data) which is used to optimise the flow below.
8,240✔
973
        }
8,240✔
974
        if err != nil {
8,880✔
975
                return err
×
976
        }
×
977

978
        // command [1 byte]
979
        data[4] = comStmtExecute
8,880✔
980

8,880✔
981
        // statement_id [4 bytes]
8,880✔
982
        binary.LittleEndian.PutUint32(data[5:], stmt.id)
8,880✔
983

8,880✔
984
        // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
8,880✔
985
        data[9] = 0x00
8,880✔
986

8,880✔
987
        // iteration_count (uint32(1)) [4 bytes]
8,880✔
988
        binary.LittleEndian.PutUint32(data[10:], 1)
8,880✔
989

8,880✔
990
        if len(args) > 0 {
17,120✔
991
                pos := minPktLen
8,240✔
992

8,240✔
993
                var nullMask []byte
8,240✔
994
                if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
8,432✔
995
                        // buffer has to be extended but we don't know by how much so
192✔
996
                        // we depend on append after all data with known sizes fit.
192✔
997
                        // We stop at that because we deal with a lot of columns here
192✔
998
                        // which makes the required allocation size hard to guess.
192✔
999
                        tmp := make([]byte, pos+maskLen+typesLen)
192✔
1000
                        copy(tmp[:pos], data[:pos])
192✔
1001
                        data = tmp
192✔
1002
                        nullMask = data[pos : pos+maskLen]
192✔
1003
                        // No need to clean nullMask as make ensures that.
192✔
1004
                        pos += maskLen
192✔
1005
                } else {
8,240✔
1006
                        nullMask = data[pos : pos+maskLen]
8,048✔
1007
                        for i := range nullMask {
16,096✔
1008
                                nullMask[i] = 0
8,048✔
1009
                        }
8,048✔
1010
                        pos += maskLen
8,048✔
1011
                }
1012

1013
                // newParameterBoundFlag 1 [1 byte]
1014
                data[pos] = 0x01
8,240✔
1015
                pos++
8,240✔
1016

8,240✔
1017
                // type of each parameter [len(args)*2 bytes]
8,240✔
1018
                paramTypes := data[pos:]
8,240✔
1019
                pos += len(args) * 2
8,240✔
1020

8,240✔
1021
                // value of each parameter [n bytes]
8,240✔
1022
                paramValues := data[pos:pos]
8,240✔
1023
                valuesCap := cap(paramValues)
8,240✔
1024

8,240✔
1025
                for i, arg := range args {
12,599,680✔
1026
                        // build NULL-bitmap
12,591,440✔
1027
                        if arg == nil {
18,882,896✔
1028
                                nullMask[i/8] |= 1 << (uint(i) & 7)
6,291,456✔
1029
                                paramTypes[i+i] = byte(fieldTypeNULL)
6,291,456✔
1030
                                paramTypes[i+i+1] = 0x00
6,291,456✔
1031
                                continue
6,291,456✔
1032
                        }
1033

1034
                        if v, ok := arg.(json.RawMessage); ok {
6,300,048✔
1035
                                arg = []byte(v)
64✔
1036
                        }
64✔
1037
                        // cache types and values
1038
                        switch v := arg.(type) {
6,299,984✔
1039
                        case int64:
1,280✔
1040
                                paramTypes[i+i] = byte(fieldTypeLongLong)
1,280✔
1041
                                paramTypes[i+i+1] = 0x00
1,280✔
1042
                                paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
1,280✔
1043

1044
                        case uint64:
128✔
1045
                                paramTypes[i+i] = byte(fieldTypeLongLong)
128✔
1046
                                paramTypes[i+i+1] = 0x80 // type is unsigned
128✔
1047
                                paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
128✔
1048

1049
                        case float64:
64✔
1050
                                paramTypes[i+i] = byte(fieldTypeDouble)
64✔
1051
                                paramTypes[i+i+1] = 0x00
64✔
1052
                                paramValues = binary.LittleEndian.AppendUint64(paramValues, math.Float64bits(v))
64✔
1053

1054
                        case bool:
96✔
1055
                                paramTypes[i+i] = byte(fieldTypeTiny)
96✔
1056
                                paramTypes[i+i+1] = 0x00
96✔
1057

96✔
1058
                                if v {
128✔
1059
                                        paramValues = append(paramValues, 0x01)
32✔
1060
                                } else {
96✔
1061
                                        paramValues = append(paramValues, 0x00)
64✔
1062
                                }
64✔
1063

1064
                        case []byte:
224✔
1065
                                // Common case (non-nil value) first
224✔
1066
                                if v != nil {
384✔
1067
                                        paramTypes[i+i] = byte(fieldTypeString)
160✔
1068
                                        paramTypes[i+i+1] = 0x00
160✔
1069

160✔
1070
                                        if len(v) < longDataSize {
320✔
1071
                                                paramValues = appendLengthEncodedInteger(paramValues,
160✔
1072
                                                        uint64(len(v)),
160✔
1073
                                                )
160✔
1074
                                                paramValues = append(paramValues, v...)
160✔
1075
                                        } else {
160✔
1076
                                                if err := stmt.writeCommandLongData(i, v); err != nil {
×
1077
                                                        return err
×
1078
                                                }
×
1079
                                        }
1080
                                        continue
160✔
1081
                                }
1082

1083
                                // Handle []byte(nil) as a NULL value
1084
                                nullMask[i/8] |= 1 << (uint(i) & 7)
64✔
1085
                                paramTypes[i+i] = byte(fieldTypeNULL)
64✔
1086
                                paramTypes[i+i+1] = 0x00
64✔
1087

1088
                        case string:
6,295,912✔
1089
                                paramTypes[i+i] = byte(fieldTypeString)
6,295,912✔
1090
                                paramTypes[i+i+1] = 0x00
6,295,912✔
1091

6,295,912✔
1092
                                if len(v) < longDataSize {
12,591,760✔
1093
                                        paramValues = appendLengthEncodedInteger(paramValues,
6,295,848✔
1094
                                                uint64(len(v)),
6,295,848✔
1095
                                        )
6,295,848✔
1096
                                        paramValues = append(paramValues, v...)
6,295,848✔
1097
                                } else {
6,295,912✔
1098
                                        if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
64✔
1099
                                                return err
×
1100
                                        }
×
1101
                                }
1102

1103
                        case time.Time:
2,280✔
1104
                                paramTypes[i+i] = byte(fieldTypeString)
2,280✔
1105
                                paramTypes[i+i+1] = 0x00
2,280✔
1106

2,280✔
1107
                                var a [64]byte
2,280✔
1108
                                var b = a[:0]
2,280✔
1109

2,280✔
1110
                                if v.IsZero() {
3,152✔
1111
                                        b = append(b, "0000-00-00"...)
872✔
1112
                                } else {
2,280✔
1113
                                        b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
1,408✔
1114
                                        if err != nil {
1,408✔
1115
                                                return err
×
1116
                                        }
×
1117
                                }
1118

1119
                                paramValues = appendLengthEncodedInteger(paramValues,
2,280✔
1120
                                        uint64(len(b)),
2,280✔
1121
                                )
2,280✔
1122
                                paramValues = append(paramValues, b...)
2,280✔
1123

1124
                        default:
×
1125
                                return fmt.Errorf("cannot convert type: %T", arg)
×
1126
                        }
1127
                }
1128

1129
                // Check if param values exceeded the available buffer
1130
                // In that case we must build the data packet with the new values buffer
1131
                if valuesCap != cap(paramValues) {
8,336✔
1132
                        data = append(data[:pos], paramValues...)
96✔
1133
                        mc.buf.store(data) // allow this buffer to be reused
96✔
1134
                }
96✔
1135

1136
                pos += len(paramValues)
8,240✔
1137
                data = data[:pos]
8,240✔
1138
        }
1139

1140
        err = mc.writePacket(data)
8,880✔
1141
        mc.syncSequence()
8,880✔
1142
        return err
8,880✔
1143
}
1144

1145
// For each remaining resultset in the stream, discards its rows and updates
1146
// mc.affectedRows and mc.insertIds.
1147
func (mc *okHandler) discardResults() error {
46,342✔
1148
        for mc.status&statusMoreResultsExists != 0 {
46,598✔
1149
                resLen, err := mc.readResultSetHeaderPacket()
256✔
1150
                if err != nil {
320✔
1151
                        return err
64✔
1152
                }
64✔
1153
                if resLen > 0 {
192✔
1154
                        // columns
×
1155
                        if err := mc.conn().readUntilEOF(); err != nil {
×
1156
                                return err
×
1157
                        }
×
1158
                        // rows
1159
                        if err := mc.conn().readUntilEOF(); err != nil {
×
1160
                                return err
×
1161
                        }
×
1162
                }
1163
        }
1164
        return nil
46,278✔
1165
}
1166

1167
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
1168
func (rows *binaryRows) readRow(dest []driver.Value) error {
7,664✔
1169
        data, err := rows.mc.readPacket()
7,664✔
1170
        if err != nil {
7,664✔
1171
                return err
×
1172
        }
×
1173

1174
        // packet indicator [1 byte]
1175
        if data[0] != iOK {
8,336✔
1176
                // EOF Packet
672✔
1177
                if data[0] == iEOF && len(data) == 5 {
1,344✔
1178
                        rows.mc.status = readStatus(data[3:])
672✔
1179
                        rows.rs.done = true
672✔
1180
                        if !rows.HasNextResultSet() {
1,056✔
1181
                                rows.mc = nil
384✔
1182
                        }
384✔
1183
                        return io.EOF
672✔
1184
                }
1185
                mc := rows.mc
×
1186
                rows.mc = nil
×
1187

×
1188
                // Error otherwise
×
1189
                return mc.handleErrorPacket(data)
×
1190
        }
1191

1192
        // NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
1193
        pos := 1 + (len(dest)+7+2)>>3
6,992✔
1194
        nullMask := data[1:pos]
6,992✔
1195

6,992✔
1196
        for i := range dest {
15,264✔
1197
                // Field is NULL
8,272✔
1198
                // (byte >> bit-pos) % 2 == 1
8,272✔
1199
                if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
8,464✔
1200
                        dest[i] = nil
192✔
1201
                        continue
192✔
1202
                }
1203

1204
                // Convert to byte-coded string
1205
                switch rows.rs.columns[i].fieldType {
8,080✔
1206
                case fieldTypeNULL:
×
1207
                        dest[i] = nil
×
1208
                        continue
×
1209

1210
                // Numeric Types
1211
                case fieldTypeTiny:
64✔
1212
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
64✔
1213
                                dest[i] = int64(data[pos])
×
1214
                        } else {
64✔
1215
                                dest[i] = int64(int8(data[pos]))
64✔
1216
                        }
64✔
1217
                        pos++
64✔
1218
                        continue
64✔
1219

1220
                case fieldTypeShort, fieldTypeYear:
32✔
1221
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
32✔
1222
                                dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
×
1223
                        } else {
32✔
1224
                                dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
32✔
1225
                        }
32✔
1226
                        pos += 2
32✔
1227
                        continue
32✔
1228

1229
                case fieldTypeInt24, fieldTypeLong:
1,014✔
1230
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
1,046✔
1231
                                dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
32✔
1232
                        } else {
1,014✔
1233
                                dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
982✔
1234
                        }
982✔
1235
                        pos += 4
1,014✔
1236
                        continue
1,014✔
1237

1238
                case fieldTypeLongLong:
1,002✔
1239
                        if rows.rs.columns[i].flags&flagUnsigned != 0 {
1,130✔
1240
                                val := binary.LittleEndian.Uint64(data[pos : pos+8])
128✔
1241
                                if val > math.MaxInt64 {
192✔
1242
                                        dest[i] = uint64ToString(val)
64✔
1243
                                } else {
128✔
1244
                                        dest[i] = int64(val)
64✔
1245
                                }
64✔
1246
                        } else {
874✔
1247
                                dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
874✔
1248
                        }
874✔
1249
                        pos += 8
1,002✔
1250
                        continue
1,002✔
1251

1252
                case fieldTypeFloat:
64✔
1253
                        dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
64✔
1254
                        pos += 4
64✔
1255
                        continue
64✔
1256

1257
                case fieldTypeDouble:
64✔
1258
                        dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
64✔
1259
                        pos += 8
64✔
1260
                        continue
64✔
1261

1262
                // Length coded Binary Strings
1263
                case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
1264
                        fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
1265
                        fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
1266
                        fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
1267
                        fieldTypeVector:
256✔
1268
                        var isNull bool
256✔
1269
                        var n int
256✔
1270
                        dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
256✔
1271
                        pos += n
256✔
1272
                        if err == nil {
512✔
1273
                                if !isNull {
512✔
1274
                                        continue
256✔
1275
                                } else {
×
1276
                                        dest[i] = nil
×
1277
                                        continue
×
1278
                                }
1279
                        }
1280
                        return err
×
1281

1282
                case
1283
                        fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
1284
                        fieldTypeTime,                         // Time [-][H]HH:MM:SS[.fractal]
1285
                        fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
5,584✔
1286

5,584✔
1287
                        num, isNull, n := readLengthEncodedInteger(data[pos:])
5,584✔
1288
                        pos += n
5,584✔
1289

5,584✔
1290
                        switch {
5,584✔
1291
                        case isNull:
×
1292
                                dest[i] = nil
×
1293
                                continue
×
1294
                        case rows.rs.columns[i].fieldType == fieldTypeTime:
3,584✔
1295
                                // database/sql does not support an equivalent to TIME, return a string
3,584✔
1296
                                var dstlen uint8
3,584✔
1297
                                switch decimals := rows.rs.columns[i].decimals; decimals {
3,584✔
1298
                                case 0x00, 0x1f:
1,792✔
1299
                                        dstlen = 8
1,792✔
1300
                                case 1, 2, 3, 4, 5, 6:
1,792✔
1301
                                        dstlen = 8 + 1 + decimals
1,792✔
1302
                                default:
×
1303
                                        return fmt.Errorf(
×
1304
                                                "protocol error, illegal decimals value %d",
×
1305
                                                rows.rs.columns[i].decimals,
×
1306
                                        )
×
1307
                                }
1308
                                dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
3,584✔
1309
                        case rows.mc.parseTime:
1,000✔
1310
                                dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
1,000✔
1311
                        default:
1,000✔
1312
                                var dstlen uint8
1,000✔
1313
                                if rows.rs.columns[i].fieldType == fieldTypeDate {
1,200✔
1314
                                        dstlen = 10
200✔
1315
                                } else {
1,000✔
1316
                                        switch decimals := rows.rs.columns[i].decimals; decimals {
800✔
1317
                                        case 0x00, 0x1f:
400✔
1318
                                                dstlen = 19
400✔
1319
                                        case 1, 2, 3, 4, 5, 6:
400✔
1320
                                                dstlen = 19 + 1 + decimals
400✔
1321
                                        default:
×
1322
                                                return fmt.Errorf(
×
1323
                                                        "protocol error, illegal decimals value %d",
×
1324
                                                        rows.rs.columns[i].decimals,
×
1325
                                                )
×
1326
                                        }
1327
                                }
1328
                                dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
1,000✔
1329
                        }
1330

1331
                        if err == nil {
11,168✔
1332
                                pos += int(num)
5,584✔
1333
                                continue
5,584✔
1334
                        } else {
×
1335
                                return err
×
1336
                        }
×
1337

1338
                // Please report if this happens!
1339
                default:
×
1340
                        return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
×
1341
                }
1342
        }
1343

1344
        return nil
6,992✔
1345
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc