You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
537 lines
14 KiB
Go
537 lines
14 KiB
Go
package client
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pingcap/errors"
|
|
|
|
. "github.com/starainrt/go-mysql/mysql"
|
|
"github.com/starainrt/go-mysql/packet"
|
|
"github.com/starainrt/go-mysql/utils"
|
|
)
|
|
|
|
type Conn struct {
|
|
*packet.Conn
|
|
|
|
user string
|
|
password string
|
|
db string
|
|
tlsConfig *tls.Config
|
|
proto string
|
|
|
|
serverVersion string
|
|
// server capabilities
|
|
capability uint32
|
|
// client-set capabilities only
|
|
ccaps uint32
|
|
|
|
attributes map[string]string
|
|
|
|
status uint16
|
|
|
|
charset string
|
|
|
|
salt []byte
|
|
authPluginName string
|
|
|
|
connectionID uint32
|
|
}
|
|
|
|
// This function will be called for every row in resultset from ExecuteSelectStreaming.
|
|
type SelectPerRowCallback func(row []FieldValue) error
|
|
|
|
// This function will be called once per result from ExecuteSelectStreaming
|
|
type SelectPerResultCallback func(result *Result) error
|
|
|
|
// This function will be called once per result from ExecuteMultiple
|
|
type ExecPerResultCallback func(result *Result, err error)
|
|
|
|
// getNetProto chooses the network name for addr: "unix" when addr contains a
// '/' (treated as a socket path), and "tcp" otherwise.
func getNetProto(addr string) string {
	if strings.ContainsRune(addr, '/') {
		return "unix"
	}
	return "tcp"
}
|
|
|
|
// Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock.
|
|
// Accepts a series of configuration functions as a variadic argument.
|
|
func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer cancel()
|
|
|
|
dialer := &net.Dialer{}
|
|
|
|
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
|
|
}
|
|
|
|
// Dialer opens a connection to address on the named network (e.g. "tcp" or
// "unix"), honouring ctx. It has the same shape as net.Dialer.DialContext.
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
// Connect to a MySQL server using the given Dialer.
|
|
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
|
|
c := new(Conn)
|
|
|
|
if network == "" {
|
|
network = getNetProto(addr)
|
|
}
|
|
|
|
var err error
|
|
conn, err := dialer(ctx, network, addr)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
c.user = user
|
|
c.password = password
|
|
c.db = dbName
|
|
c.proto = network
|
|
c.Conn = packet.NewConn(conn)
|
|
|
|
// use default charset here, utf-8
|
|
c.charset = DEFAULT_CHARSET
|
|
|
|
// Apply configuration functions.
|
|
for i := range options {
|
|
options[i](c)
|
|
}
|
|
|
|
if c.tlsConfig != nil {
|
|
seq := c.Conn.Sequence
|
|
c.Conn = packet.NewTLSConn(conn)
|
|
c.Conn.Sequence = seq
|
|
}
|
|
|
|
if err = c.handshake(); err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func (c *Conn) handshake() error {
|
|
var err error
|
|
if err = c.readInitialHandshake(); err != nil {
|
|
c.Close()
|
|
return errors.Trace(fmt.Errorf("readInitialHandshake: %w", err))
|
|
}
|
|
|
|
if err := c.writeAuthHandshake(); err != nil {
|
|
c.Close()
|
|
|
|
return errors.Trace(fmt.Errorf("writeAuthHandshake: %w", err))
|
|
}
|
|
|
|
if err := c.handleAuthResult(); err != nil {
|
|
c.Close()
|
|
return errors.Trace(fmt.Errorf("handleAuthResult: %w", err))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
func (c *Conn) Ping() error {
|
|
if err := c.writeCommand(COM_PING); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
if _, err := c.readOK(); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetCapability enables the use of a specific capability
|
|
func (c *Conn) SetCapability(cap uint32) {
|
|
c.ccaps |= cap
|
|
}
|
|
|
|
// UnsetCapability disables the use of a specific capability
|
|
func (c *Conn) UnsetCapability(cap uint32) {
|
|
c.ccaps &= ^cap
|
|
}
|
|
|
|
// UseSSL: use default SSL
|
|
// pass to options when connect
|
|
func (c *Conn) UseSSL(insecureSkipVerify bool) {
|
|
c.tlsConfig = &tls.Config{InsecureSkipVerify: insecureSkipVerify}
|
|
}
|
|
|
|
// SetTLSConfig: use user-specified TLS config
|
|
// pass to options when connect
|
|
func (c *Conn) SetTLSConfig(config *tls.Config) {
|
|
c.tlsConfig = config
|
|
}
|
|
|
|
func (c *Conn) UseDB(dbName string) error {
|
|
if c.db == dbName {
|
|
return nil
|
|
}
|
|
|
|
if err := c.writeCommandStr(COM_INIT_DB, dbName); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
if _, err := c.readOK(); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
c.db = dbName
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) GetDB() string {
|
|
return c.db
|
|
}
|
|
|
|
func (c *Conn) GetServerVersion() string {
|
|
return c.serverVersion
|
|
}
|
|
|
|
func (c *Conn) CompareServerVersion(v string) (int, error) {
|
|
return CompareServerVersions(c.serverVersion, v)
|
|
}
|
|
|
|
func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
|
|
if len(args) == 0 {
|
|
return c.exec(command)
|
|
} else {
|
|
if s, err := c.Prepare(command); err != nil {
|
|
return nil, errors.Trace(err)
|
|
} else {
|
|
var r *Result
|
|
r, err = s.Execute(args...)
|
|
s.Close()
|
|
return r, err
|
|
}
|
|
}
|
|
}
|
|
|
|
// ExecuteMultiple will call perResultCallback for every result of the multiple queries
|
|
// that are executed.
|
|
//
|
|
// When ExecuteMultiple is used, the connection should have the SERVER_MORE_RESULTS_EXISTS
|
|
// flag set to signal the server multiple queries are executed. Handling the responses
|
|
// is up to the implementation of perResultCallback.
|
|
//
|
|
// Example:
|
|
//
|
|
// queries := "SELECT 1; SELECT NOW();"
|
|
// conn.ExecuteMultiple(queries, func(result *mysql.Result, err error) {
|
|
// // Use the result as you want
|
|
// })
|
|
func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) {
|
|
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
var err error
|
|
var result *Result
|
|
|
|
bs := utils.ByteSliceGet(16)
|
|
defer utils.ByteSlicePut(bs)
|
|
|
|
for {
|
|
bs.B, err = c.ReadPacketReuseMem(bs.B[:0])
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
switch bs.B[0] {
|
|
case OK_HEADER:
|
|
result, err = c.handleOKPacket(bs.B)
|
|
case ERR_HEADER:
|
|
err = c.handleErrorPacket(bytes.Repeat(bs.B, 1))
|
|
result = nil
|
|
case LocalInFile_HEADER:
|
|
err = ErrMalformPacket
|
|
result = nil
|
|
default:
|
|
result, err = c.readResultset(bs.B, false)
|
|
}
|
|
// call user-defined callback
|
|
perResultCallback(result, err)
|
|
|
|
// if there was an error of this was the last result, stop looping
|
|
if err != nil || result.Status&SERVER_MORE_RESULTS_EXISTS == 0 {
|
|
break
|
|
}
|
|
}
|
|
|
|
// return an empty result(set) signaling we're done streaming a multiple
|
|
// streaming session
|
|
// if this would end up in WriteValue, it would just be ignored as all
|
|
// responses should have been handled in perResultCallback
|
|
return &Result{Resultset: &Resultset{
|
|
Streaming: StreamingMultiple,
|
|
StreamingDone: true,
|
|
}}, nil
|
|
}
|
|
|
|
// ExecuteSelectStreaming will call perRowCallback for every row in resultset
|
|
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
|
|
// When given, perResultCallback will be called once per result
|
|
//
|
|
// ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving.
|
|
//
|
|
// Example:
|
|
//
|
|
// var result mysql.Result
|
|
// conn.ExecuteSelectStreaming(`SELECT ... LIMIT 100500`, &result, func(row []mysql.FieldValue) error {
|
|
// // Use the row as you want.
|
|
// // You must not save FieldValue.AsString() value after this callback is done. Copy it if you need.
|
|
// return nil
|
|
// }, nil)
|
|
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
|
|
if err := c.writeCommandStr(COM_QUERY, command); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
return c.readResultStreaming(false, result, perRowCallback, perResultCallback)
|
|
}
|
|
|
|
func (c *Conn) Begin() error {
|
|
_, err := c.exec("BEGIN")
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
func (c *Conn) Commit() error {
|
|
_, err := c.exec("COMMIT")
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
func (c *Conn) Rollback() error {
|
|
_, err := c.exec("ROLLBACK")
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
func (c *Conn) SetAttributes(attributes map[string]string) {
|
|
c.attributes = attributes
|
|
}
|
|
|
|
func (c *Conn) SetCharset(charset string) error {
|
|
if c.charset == charset {
|
|
return nil
|
|
}
|
|
|
|
if _, err := c.exec(fmt.Sprintf("SET NAMES %s", charset)); err != nil {
|
|
return errors.Trace(err)
|
|
} else {
|
|
c.charset = charset
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
|
|
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
fs := make([]*Field, 0, 4)
|
|
var f *Field
|
|
for {
|
|
data, err := c.ReadPacket()
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
// ERR Packet
|
|
if data[0] == ERR_HEADER {
|
|
return nil, c.handleErrorPacket(data)
|
|
}
|
|
|
|
// EOF Packet
|
|
if c.isEOFPacket(data) {
|
|
return fs, nil
|
|
}
|
|
|
|
if f, err = FieldData(data).Parse(); err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
fs = append(fs, f)
|
|
}
|
|
}
|
|
|
|
func (c *Conn) SetAutoCommit() error {
|
|
if !c.IsAutoCommit() {
|
|
if _, err := c.exec("SET AUTOCOMMIT = 1"); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) IsAutoCommit() bool {
|
|
return c.status&SERVER_STATUS_AUTOCOMMIT > 0
|
|
}
|
|
|
|
func (c *Conn) IsInTransaction() bool {
|
|
return c.status&SERVER_STATUS_IN_TRANS > 0
|
|
}
|
|
|
|
func (c *Conn) GetCharset() string {
|
|
return c.charset
|
|
}
|
|
|
|
func (c *Conn) GetConnectionID() uint32 {
|
|
return c.connectionID
|
|
}
|
|
|
|
func (c *Conn) HandleOKPacket(data []byte) *Result {
|
|
r, _ := c.handleOKPacket(data)
|
|
return r
|
|
}
|
|
|
|
func (c *Conn) HandleErrorPacket(data []byte) error {
|
|
return c.handleErrorPacket(data)
|
|
}
|
|
|
|
func (c *Conn) ReadOKPacket() (*Result, error) {
|
|
return c.readOK()
|
|
}
|
|
|
|
func (c *Conn) exec(query string) (*Result, error) {
|
|
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
return c.readResult(false)
|
|
}
|
|
|
|
func (c *Conn) CapabilityString() string {
|
|
var caps []string
|
|
capability := c.capability
|
|
for i := 0; capability != 0; i++ {
|
|
field := uint32(1 << i)
|
|
if capability&field == 0 {
|
|
continue
|
|
}
|
|
capability ^= field
|
|
|
|
switch field {
|
|
case CLIENT_LONG_PASSWORD:
|
|
caps = append(caps, "CLIENT_LONG_PASSWORD")
|
|
case CLIENT_FOUND_ROWS:
|
|
caps = append(caps, "CLIENT_FOUND_ROWS")
|
|
case CLIENT_LONG_FLAG:
|
|
caps = append(caps, "CLIENT_LONG_FLAG")
|
|
case CLIENT_CONNECT_WITH_DB:
|
|
caps = append(caps, "CLIENT_CONNECT_WITH_DB")
|
|
case CLIENT_NO_SCHEMA:
|
|
caps = append(caps, "CLIENT_NO_SCHEMA")
|
|
case CLIENT_COMPRESS:
|
|
caps = append(caps, "CLIENT_COMPRESS")
|
|
case CLIENT_ODBC:
|
|
caps = append(caps, "CLIENT_ODBC")
|
|
case CLIENT_LOCAL_FILES:
|
|
caps = append(caps, "CLIENT_LOCAL_FILES")
|
|
case CLIENT_IGNORE_SPACE:
|
|
caps = append(caps, "CLIENT_IGNORE_SPACE")
|
|
case CLIENT_PROTOCOL_41:
|
|
caps = append(caps, "CLIENT_PROTOCOL_41")
|
|
case CLIENT_INTERACTIVE:
|
|
caps = append(caps, "CLIENT_INTERACTIVE")
|
|
case CLIENT_SSL:
|
|
caps = append(caps, "CLIENT_SSL")
|
|
case CLIENT_IGNORE_SIGPIPE:
|
|
caps = append(caps, "CLIENT_IGNORE_SIGPIPE")
|
|
case CLIENT_TRANSACTIONS:
|
|
caps = append(caps, "CLIENT_TRANSACTIONS")
|
|
case CLIENT_RESERVED:
|
|
caps = append(caps, "CLIENT_RESERVED")
|
|
case CLIENT_SECURE_CONNECTION:
|
|
caps = append(caps, "CLIENT_SECURE_CONNECTION")
|
|
case CLIENT_MULTI_STATEMENTS:
|
|
caps = append(caps, "CLIENT_MULTI_STATEMENTS")
|
|
case CLIENT_MULTI_RESULTS:
|
|
caps = append(caps, "CLIENT_MULTI_RESULTS")
|
|
case CLIENT_PS_MULTI_RESULTS:
|
|
caps = append(caps, "CLIENT_PS_MULTI_RESULTS")
|
|
case CLIENT_PLUGIN_AUTH:
|
|
caps = append(caps, "CLIENT_PLUGIN_AUTH")
|
|
case CLIENT_CONNECT_ATTRS:
|
|
caps = append(caps, "CLIENT_CONNECT_ATTRS")
|
|
case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA:
|
|
caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA")
|
|
case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS:
|
|
caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS")
|
|
case CLIENT_SESSION_TRACK:
|
|
caps = append(caps, "CLIENT_SESSION_TRACK")
|
|
case CLIENT_DEPRECATE_EOF:
|
|
caps = append(caps, "CLIENT_DEPRECATE_EOF")
|
|
case CLIENT_OPTIONAL_RESULTSET_METADATA:
|
|
caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA")
|
|
case CLIENT_ZSTD_COMPRESSION_ALGORITHM:
|
|
caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM")
|
|
case CLIENT_QUERY_ATTRIBUTES:
|
|
caps = append(caps, "CLIENT_QUERY_ATTRIBUTES")
|
|
case MULTI_FACTOR_AUTHENTICATION:
|
|
caps = append(caps, "MULTI_FACTOR_AUTHENTICATION")
|
|
case CLIENT_CAPABILITY_EXTENSION:
|
|
caps = append(caps, "CLIENT_CAPABILITY_EXTENSION")
|
|
case CLIENT_SSL_VERIFY_SERVER_CERT:
|
|
caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT")
|
|
case CLIENT_REMEMBER_OPTIONS:
|
|
caps = append(caps, "CLIENT_REMEMBER_OPTIONS")
|
|
default:
|
|
caps = append(caps, fmt.Sprintf("(%d)", field))
|
|
}
|
|
}
|
|
|
|
return strings.Join(caps, "|")
|
|
}
|
|
|
|
func (c *Conn) StatusString() string {
|
|
var stats []string
|
|
status := c.status
|
|
for i := 0; status != 0; i++ {
|
|
field := uint16(1 << i)
|
|
if status&field == 0 {
|
|
continue
|
|
}
|
|
status ^= field
|
|
|
|
switch field {
|
|
case SERVER_STATUS_IN_TRANS:
|
|
stats = append(stats, "SERVER_STATUS_IN_TRANS")
|
|
case SERVER_STATUS_AUTOCOMMIT:
|
|
stats = append(stats, "SERVER_STATUS_AUTOCOMMIT")
|
|
case SERVER_MORE_RESULTS_EXISTS:
|
|
stats = append(stats, "SERVER_MORE_RESULTS_EXISTS")
|
|
case SERVER_STATUS_NO_GOOD_INDEX_USED:
|
|
stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED")
|
|
case SERVER_STATUS_NO_INDEX_USED:
|
|
stats = append(stats, "SERVER_STATUS_NO_INDEX_USED")
|
|
case SERVER_STATUS_CURSOR_EXISTS:
|
|
stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS")
|
|
case SERVER_STATUS_LAST_ROW_SEND:
|
|
stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND")
|
|
case SERVER_STATUS_DB_DROPPED:
|
|
stats = append(stats, "SERVER_STATUS_DB_DROPPED")
|
|
case SERVER_STATUS_NO_BACKSLASH_ESCAPED:
|
|
stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED")
|
|
case SERVER_STATUS_METADATA_CHANGED:
|
|
stats = append(stats, "SERVER_STATUS_METADATA_CHANGED")
|
|
case SERVER_QUERY_WAS_SLOW:
|
|
stats = append(stats, "SERVER_QUERY_WAS_SLOW")
|
|
case SERVER_PS_OUT_PARAMS:
|
|
stats = append(stats, "SERVER_PS_OUT_PARAMS")
|
|
default:
|
|
stats = append(stats, fmt.Sprintf("(%d)", field))
|
|
}
|
|
}
|
|
|
|
return strings.Join(stats, "|")
|
|
}
|