You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
328 lines
9.8 KiB
Go
328 lines
9.8 KiB
Go
2 years ago
|
package client
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"crypto/tls"
|
||
|
"encoding/binary"
|
||
|
"fmt"
|
||
|
|
||
|
"github.com/pingcap/errors"
|
||
|
. "github.com/starainrt/go-mysql/mysql"
|
||
|
"github.com/starainrt/go-mysql/packet"
|
||
|
)
|
||
|
|
||
|
const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
|
||
|
|
||
|
// defines the supported auth plugins
|
||
|
var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD}
|
||
|
|
||
|
// helper function to determine what auth methods are allowed by this client
|
||
|
func authPluginAllowed(pluginName string) bool {
|
||
|
for _, p := range supportedAuthPlugins {
|
||
|
if pluginName == p {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// See:
|
||
|
// - https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
|
||
|
// - https://github.com/alibaba/canal/blob/0ec46991499a22870dde4ae736b2586cbcbfea94/driver/src/main/java/com/alibaba/otter/canal/parse/driver/mysql/packets/server/HandshakeInitializationPacket.java#L89
|
||
|
// - https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift
|
||
|
// - https://github.com/github/vitess-gh/blob/70ae1a2b3a116ff6411b0f40852d6e71382f6e07/go/mysql/client.go
|
||
|
func (c *Conn) readInitialHandshake() error {
|
||
|
data, err := c.ReadPacket()
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
if data[0] == ERR_HEADER {
|
||
|
return errors.Annotate(c.handleErrorPacket(data), "read initial handshake error")
|
||
|
}
|
||
|
|
||
|
if data[0] < MinProtocolVersion {
|
||
|
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
|
||
|
}
|
||
|
pos := 1
|
||
|
|
||
|
// skip mysql version
|
||
|
// mysql version end with 0x00
|
||
|
version := data[pos : bytes.IndexByte(data[pos:], 0x00)+1]
|
||
|
c.serverVersion = string(version)
|
||
|
pos += len(version) + 1 /*trailing zero byte*/
|
||
|
|
||
|
// connection id length is 4
|
||
|
c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
|
||
|
pos += 4
|
||
|
|
||
|
// first 8 bytes of the plugin provided data (scramble)
|
||
|
c.salt = append(c.salt[:0], data[pos:pos+8]...)
|
||
|
pos += 8
|
||
|
|
||
|
if data[pos] != 0 { // 0x00 byte, terminating the first part of a scramble
|
||
|
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
|
||
|
}
|
||
|
pos++
|
||
|
|
||
|
// The lower 2 bytes of the Capabilities Flags
|
||
|
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
||
|
// check protocol
|
||
|
if c.capability&CLIENT_PROTOCOL_41 == 0 {
|
||
|
return errors.New("the MySQL server can not support protocol 41 and above required by the client")
|
||
|
}
|
||
|
if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil {
|
||
|
return errors.New("the MySQL Server does not support TLS required by the client")
|
||
|
}
|
||
|
pos += 2
|
||
|
|
||
|
if len(data) > pos {
|
||
|
// default server a_protocol_character_set, only the lower 8-bits
|
||
|
// c.charset = data[pos]
|
||
|
pos += 1
|
||
|
|
||
|
c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
|
||
|
pos += 2
|
||
|
|
||
|
// The upper 2 bytes of the Capabilities Flags
|
||
|
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
|
||
|
pos += 2
|
||
|
|
||
|
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
|
||
|
authPluginDataLen := data[pos]
|
||
|
if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
|
||
|
return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen)
|
||
|
}
|
||
|
pos++
|
||
|
|
||
|
// skip reserved (all [00] ?)
|
||
|
pos += 10
|
||
|
|
||
|
if c.capability&CLIENT_SECURE_CONNECTION != 0 {
|
||
|
// Rest of the plugin provided data (scramble)
|
||
|
|
||
|
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
|
||
|
// $len=MAX(13, length of auth-plugin-data - 8)
|
||
|
//
|
||
|
// https://github.com/mysql/mysql-server/blob/1bfe02bdad6604d54913c62614bde57a055c8332/sql/auth/sql_authentication.cc#L1641-L1642
|
||
|
// the first packet *must* have at least 20 bytes of a scramble.
|
||
|
// if a plugin provided less, we pad it to 20 with zeros
|
||
|
rest := int(authPluginDataLen) - 8
|
||
|
if max := 12 + 1; rest < max {
|
||
|
rest = max
|
||
|
}
|
||
|
|
||
|
authPluginDataPart2 := data[pos : pos+rest]
|
||
|
pos += rest
|
||
|
|
||
|
c.salt = append(c.salt, authPluginDataPart2...)
|
||
|
}
|
||
|
|
||
|
if c.capability&CLIENT_PLUGIN_AUTH != 0 {
|
||
|
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
|
||
|
pos += len(c.authPluginName)
|
||
|
|
||
|
if data[pos] != 0 {
|
||
|
return errors.Errorf("expect 0x00 after authPluginName, got %q", rune(data[pos]))
|
||
|
}
|
||
|
// pos++ // ineffectual
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// if server gives no default auth plugin name, use a client default
|
||
|
if c.authPluginName == "" {
|
||
|
c.authPluginName = defaultAuthPluginName
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// generate auth response data according to auth plugin
|
||
|
//
|
||
|
// NOTE: the returned boolean value indicates whether to add a \NUL to the end of data.
|
||
|
// it is quite tricky because MySQL server expects different formats of responses in different auth situations.
|
||
|
// here the \NUL needs to be added when sending back the empty password or cleartext password in 'sha256_password'
|
||
|
// authentication.
|
||
|
func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) {
|
||
|
// password hashing
|
||
|
switch c.authPluginName {
|
||
|
case AUTH_NATIVE_PASSWORD:
|
||
|
return CalcPassword(authData[:20], []byte(c.password)), false, nil
|
||
|
case AUTH_CACHING_SHA2_PASSWORD:
|
||
|
return CalcCachingSha2Password(authData, c.password), false, nil
|
||
|
case AUTH_CLEAR_PASSWORD:
|
||
|
return []byte(c.password), true, nil
|
||
|
case AUTH_SHA256_PASSWORD:
|
||
|
if len(c.password) == 0 {
|
||
|
return nil, true, nil
|
||
|
}
|
||
|
if c.tlsConfig != nil || c.proto == "unix" {
|
||
|
// write cleartext auth packet
|
||
|
// see: https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html
|
||
|
return []byte(c.password), true, nil
|
||
|
} else {
|
||
|
// request public key from server
|
||
|
// see: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html
|
||
|
return []byte{1}, false, nil
|
||
|
}
|
||
|
default:
|
||
|
// not reachable
|
||
|
return nil, false, fmt.Errorf("auth plugin '%s' is not supported", c.authPluginName)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// generate connection attributes data
|
||
|
func (c *Conn) genAttributes() []byte {
|
||
|
if len(c.attributes) == 0 {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
attrData := make([]byte, 0)
|
||
|
for k, v := range c.attributes {
|
||
|
attrData = append(attrData, PutLengthEncodedString([]byte(k))...)
|
||
|
attrData = append(attrData, PutLengthEncodedString([]byte(v))...)
|
||
|
}
|
||
|
return append(PutLengthEncodedInt(uint64(len(attrData))), attrData...)
|
||
|
}
|
||
|
|
||
|
// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||
|
func (c *Conn) writeAuthHandshake() error {
|
||
|
if !authPluginAllowed(c.authPluginName) {
|
||
|
return fmt.Errorf("unknow auth plugin name '%s'", c.authPluginName)
|
||
|
}
|
||
|
|
||
|
// Set default client capabilities that reflect the abilities of this library
|
||
|
capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
|
||
|
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH
|
||
|
// Adjust client capability flags based on server support
|
||
|
capability |= c.capability & CLIENT_LONG_FLAG
|
||
|
// Adjust client capability flags on specific client requests
|
||
|
// Only flags that would make any sense setting and aren't handled elsewhere
|
||
|
// in the library are supported here
|
||
|
capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE |
|
||
|
c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS |
|
||
|
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS
|
||
|
|
||
|
// To enable TLS / SSL
|
||
|
if c.tlsConfig != nil {
|
||
|
capability |= CLIENT_SSL
|
||
|
}
|
||
|
|
||
|
auth, addNull, err := c.genAuthResponse(c.salt)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// encode length of the auth plugin data
|
||
|
// here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte
|
||
|
// see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer
|
||
|
var authRespLEIBuf [9]byte
|
||
|
authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth)))
|
||
|
if len(authRespLEI) > 1 {
|
||
|
// if the length can not be written in 1 byte, it must be written as a
|
||
|
// length encoded integer
|
||
|
capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
|
||
|
}
|
||
|
|
||
|
// packet length
|
||
|
// capability 4
|
||
|
// max-packet size 4
|
||
|
// charset 1
|
||
|
// reserved all[0] 23
|
||
|
// username
|
||
|
// auth
|
||
|
// mysql_native_password + null-terminated
|
||
|
length := 4 + 4 + 1 + 23 + len(c.user) + 1 + len(authRespLEI) + len(auth) + 21 + 1
|
||
|
if addNull {
|
||
|
length++
|
||
|
}
|
||
|
// db name
|
||
|
if len(c.db) > 0 {
|
||
|
capability |= CLIENT_CONNECT_WITH_DB
|
||
|
length += len(c.db) + 1
|
||
|
}
|
||
|
// connection attributes
|
||
|
attrData := c.genAttributes()
|
||
|
if len(attrData) > 0 {
|
||
|
capability |= CLIENT_CONNECT_ATTRS
|
||
|
length += len(attrData)
|
||
|
}
|
||
|
|
||
|
data := make([]byte, length+4)
|
||
|
|
||
|
// capability [32 bit]
|
||
|
data[4] = byte(capability)
|
||
|
data[5] = byte(capability >> 8)
|
||
|
data[6] = byte(capability >> 16)
|
||
|
data[7] = byte(capability >> 24)
|
||
|
|
||
|
// MaxPacketSize [32 bit] (none)
|
||
|
data[8] = 0x00
|
||
|
data[9] = 0x00
|
||
|
data[10] = 0x00
|
||
|
data[11] = 0x00
|
||
|
|
||
|
// Charset [1 byte]
|
||
|
// use default collation id 33 here, is utf-8
|
||
|
data[12] = DEFAULT_COLLATION_ID
|
||
|
|
||
|
// SSL Connection Request Packet
|
||
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
|
||
|
if c.tlsConfig != nil {
|
||
|
// Send TLS / SSL request packet
|
||
|
if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Switch to TLS
|
||
|
tlsConn := tls.Client(c.Conn.Conn, c.tlsConfig)
|
||
|
if err := tlsConn.Handshake(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
currentSequence := c.Sequence
|
||
|
c.Conn = packet.NewConn(tlsConn)
|
||
|
c.Sequence = currentSequence
|
||
|
}
|
||
|
|
||
|
// Filler [23 bytes] (all 0x00)
|
||
|
pos := 13
|
||
|
for ; pos < 13+23; pos++ {
|
||
|
data[pos] = 0
|
||
|
}
|
||
|
|
||
|
// User [null terminated string]
|
||
|
if len(c.user) > 0 {
|
||
|
pos += copy(data[pos:], c.user)
|
||
|
}
|
||
|
data[pos] = 0x00
|
||
|
pos++
|
||
|
|
||
|
// auth [length encoded integer]
|
||
|
pos += copy(data[pos:], authRespLEI)
|
||
|
pos += copy(data[pos:], auth)
|
||
|
if addNull {
|
||
|
data[pos] = 0x00
|
||
|
pos++
|
||
|
}
|
||
|
|
||
|
// db [null terminated string]
|
||
|
if len(c.db) > 0 {
|
||
|
pos += copy(data[pos:], c.db)
|
||
|
data[pos] = 0x00
|
||
|
pos++
|
||
|
}
|
||
|
|
||
|
// Assume native client during response
|
||
|
pos += copy(data[pos:], c.authPluginName)
|
||
|
data[pos] = 0x00
|
||
|
pos++
|
||
|
|
||
|
// connection attributes
|
||
|
if len(attrData) > 0 {
|
||
|
copy(data[pos:], attrData)
|
||
|
}
|
||
|
|
||
|
return c.WritePacket(data)
|
||
|
}
|