• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

go-sql-driver / mysql / 14665995848

25 Apr 2025 01:39PM UTC coverage: 82.772% (-0.2%) from 82.961%
14665995848

Pull #1696

github

rusher
use named constants for caching_sha2_password
Pull Request #1696: Authentication

321 of 387 new or added lines in 9 files covered. (82.95%)

7 existing lines in 3 files now uncovered.

3315 of 4005 relevant lines covered (82.77%)

2440698.09 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.0
/auth.go
1
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2
//
3
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
4
//
5
// This Source Code Form is subject to the terms of the Mozilla Public
6
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7
// You can obtain one at http://mozilla.org/MPL/2.0/.
8

9
package mysql
10

11
import (
12
        "bytes"
13
        "crypto/rsa"
14
        "fmt"
15
        "sync"
16
)
17

18
// Registry of server RSA public keys, keyed by the name they were registered
// under. serverPubKeyLock guards every access to serverPubKeyRegistry; the map
// itself is only allocated on the first call to RegisterServerPubKey.
var (
	serverPubKeyLock     sync.RWMutex
	serverPubKeyRegistry map[string]*rsa.PublicKey
)

// RegisterServerPubKey stores a server RSA public key under the given name so
// that data can be sent securely to the server without first receiving the
// server's public key over a potentially insecure channel. A registered key is
// selected by adding serverPubKey=<name> to the DSN. Registering a name again
// replaces the previously stored key.
//
// Note: once registered, the provided rsa.PublicKey instance is exclusively
// owned by the driver and must not be modified.
//
//	data, err := os.ReadFile("mykey.pem")
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	block, _ := pem.Decode(data)
//	if block == nil || block.Type != "PUBLIC KEY" {
//		log.Fatal("failed to decode PEM block containing public key")
//	}
//
//	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
//		mysql.RegisterServerPubKey("mykey", rsaPubKey)
//	} else {
//		log.Fatal("not a RSA public key")
//	}
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
	serverPubKeyLock.Lock()
	defer serverPubKeyLock.Unlock()

	// Lazily allocate the registry on first use.
	if serverPubKeyRegistry == nil {
		serverPubKeyRegistry = make(map[string]*rsa.PublicKey)
	}
	serverPubKeyRegistry[name] = pubKey
}

// DeregisterServerPubKey removes the public key registered under name.
// Removing a name that was never registered does nothing.
func DeregisterServerPubKey(name string) {
	serverPubKeyLock.Lock()
	defer serverPubKeyLock.Unlock()

	// delete is a no-op on a nil map, so this is safe even before any key
	// has been registered.
	delete(serverPubKeyRegistry, name)
}

// getServerPubKey returns the public key registered under name, or nil when
// nothing is registered under that name.
func getServerPubKey(name string) *rsa.PublicKey {
	serverPubKeyLock.RLock()
	defer serverPubKeyLock.RUnlock()

	// Indexing a nil map or a missing key yields the zero value (nil).
	return serverPubKeyRegistry[name]
}
79

80
// handleAuthResult processes the initial authentication packet and manages subsequent
81
// authentication flow. It reads the first authentication packet and hands off processing
82
// to the appropriate auth plugin.
83
func (mc *mysqlConn) handleAuthResult(remainingSwitch uint, initialSeed []byte, authPlugin AuthPlugin) error {
16,352✔
84
        if remainingSwitch == 0 {
16,352✔
NEW
85
                return fmt.Errorf("maximum of %d authentication switch reached", authMaximumSwitch)
×
UNCOV
86
        }
×
87

88
        data, err := mc.readPacket()
16,352✔
89
        if err != nil {
18,101✔
90
                return err
1,749✔
91
        }
1,749✔
92
        if len(data) == 0 {
14,603✔
NEW
93
                return fmt.Errorf("%w: empty auth response packet", ErrMalformPkt)
×
UNCOV
94
        }
×
95

96
        data, err = authPlugin.continuationAuth(data, initialSeed, mc)
14,603✔
97
        if err != nil {
14,618✔
98
                return err
15✔
99
        }
15✔
100

101
        switch data[0] {
14,588✔
102
        case iOK:
13,690✔
103
                return mc.resultUnchanged().handleOkPacket(data)
13,690✔
104
        case iERR:
98✔
105
                return mc.handleErrorPacket(data)
98✔
106
        case iEOF:
800✔
107
                plugin, authData := mc.parseAuthSwitchData(data, initialSeed)
800✔
108

800✔
109
                authPlugin, exists := globalPluginRegistry.GetPlugin(plugin)
800✔
110
                if !exists {
800✔
NEW
111
                        return fmt.Errorf("this authentication plugin '%s' is not supported", plugin)
×
UNCOV
112
                }
×
113

114
                initialAuthResponse, err := authPlugin.InitAuth(authData, mc.cfg)
800✔
115
                if err != nil {
928✔
116
                        return err
128✔
117
                }
128✔
118

119
                if err := mc.writeAuthSwitchPacket(initialAuthResponse); err != nil {
672✔
NEW
120
                        return err
×
UNCOV
121
                }
×
122

123
                remainingSwitch--
672✔
124
                return mc.handleAuthResult(remainingSwitch, authData, authPlugin)
672✔
125

126
        default:
×
NEW
127
                return ErrMalformPkt
×
128
        }
129
}
130

131
// parseAuthSwitchData extracts the authentication plugin name and associated data
132
// from an authentication switch request packet.
133
func (mc *mysqlConn) parseAuthSwitchData(data []byte, initialSeed []byte) (string, []byte) {
800✔
134
        if len(data) == 1 {
896✔
135
                // Special case for the old authentication protocol
96✔
136
                return "mysql_old_password", initialSeed
96✔
137
        }
96✔
138

139
        pluginEndIndex := bytes.IndexByte(data, 0x00)
704✔
140
        if pluginEndIndex < 0 {
704✔
NEW
141
                return "", nil
×
UNCOV
142
        }
×
143

144
        plugin := string(data[1:pluginEndIndex])
704✔
145
        authData := data[pluginEndIndex+1:]
704✔
146
        if len(authData) > 0 && authData[len(authData)-1] == 0 {
1,248✔
147
                authData = authData[:len(authData)-1]
544✔
148
        }
544✔
149

150
        savedAuthData := make([]byte, len(authData))
704✔
151
        copy(savedAuthData, authData)
704✔
152
        return plugin, savedAuthData
704✔
153
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc