package mysql

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/binary"
	"fmt"
	"io"
	mrand "math/rand"
	"runtime"
	"strings"
	"time"

	"github.com/Masterminds/semver"
	"github.com/pingcap/errors"
	"github.com/siddontang/go/hack"
)

// Pstack returns the stack trace of the calling goroutine. The trace is
// written into a fixed 1024-byte buffer, so longer traces are truncated.
func Pstack() string {
	buf := make([]byte, 1024)
	n := runtime.Stack(buf, false)
	return string(buf[:n])
}

// CalcPassword computes the SHA1-based authentication token
//
//	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// It returns nil when password is empty.
func CalcPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	// stage1Hash = SHA1(password)
	crypt := sha1.New()
	crypt.Write(password)
	stage1 := crypt.Sum(nil)

	// inner hash: SHA1(stage1Hash)
	crypt.Reset()
	crypt.Write(stage1)
	inner := crypt.Sum(nil)

	// outer hash: scrambleHash = SHA1(scramble + SHA1(stage1Hash))
	crypt.Reset()
	crypt.Write(scramble)
	crypt.Write(inner)
	token := crypt.Sum(nil)

	// token = scrambleHash XOR stage1Hash
	for i := range token {
		token[i] ^= stage1[i]
	}
	return token
}

// CalcCachingSha2Password hashes password using the MySQL 8+ method (SHA256):
//
//	XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
//
// It returns nil when password is empty.
func CalcCachingSha2Password(scramble []byte, password string) []byte {
	if len(password) == 0 {
		return nil
	}

	crypt := sha256.New()
	crypt.Write([]byte(password))
	message1 := crypt.Sum(nil)

	crypt.Reset()
	crypt.Write(message1)
	message1Hash := crypt.Sum(nil)

	crypt.Reset()
	crypt.Write(message1Hash)
	crypt.Write(scramble)
	message2 := crypt.Sum(nil)

	for i := range message1 {
		message1[i] ^= message2[i]
	}
	return message1
}

// EncryptPassword XORs the password, followed by one zero byte, with seed
// (repeated as often as needed) and encrypts the result with RSA-OAEP (SHA1)
// under pub. It returns an error when seed is empty or when encryption fails.
func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
	// An empty seed would otherwise cause a division by zero below.
	if len(seed) == 0 {
		return nil, errors.New("seed must not be empty")
	}

	// The extra trailing byte is left at zero by copy before the XOR.
	plain := make([]byte, len(password)+1)
	copy(plain, password)
	for i := range plain {
		plain[i] ^= seed[i%len(seed)]
	}
	return rsa.EncryptOAEP(sha1.New(), rand.Reader, pub, plain, nil)
}

// AppendLengthEncodedInteger encodes n as a length-encoded integer and
// appends it to b, returning the extended slice.
func AppendLengthEncodedInteger(b []byte, n uint64) []byte {
	switch {
	case n <= 250:
		return append(b, byte(n))
	case n <= 0xffff:
		return append(b, 0xfc, byte(n), byte(n>>8))
	case n <= 0xffffff:
		return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
	}
	return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}

// RandomBuf returns size pseudo-random bytes, each in the range [30, 126].
// The generator is math/rand seeded from the current time, so the output is
// not cryptographically secure.
func RandomBuf(size int) []byte {
	const lo, hi = 30, 127

	// A private generator is used so the process-global math/rand state is
	// not reseeded as a side effect of every call.
	rng := mrand.New(mrand.NewSource(time.Now().UTC().UnixNano()))

	buf := make([]byte, size)
	for i := range buf {
		buf[i] = byte(lo + rng.Intn(hi-lo))
	}
	return buf
}

// FixedLengthInt decodes buf as a little-endian unsigned integer.
func FixedLengthInt(buf []byte) uint64 {
	var num uint64
	for i, b := range buf {
		num |= uint64(b) << (uint(i) * 8)
	}
	return num
}

// BFixedLengthInt decodes buf as a big-endian unsigned integer.
func BFixedLengthInt(buf []byte) uint64 {
	var num uint64
	for i, b := range buf {
		num |= uint64(b) << (uint(len(buf)-i-1) * 8)
	}
	return num
}

// lengthEncodedIntSize returns how many bytes a length-encoded integer
// occupies, prefix included, given its first byte.
func lengthEncodedIntSize(first byte) int {
	switch first {
	case 0xfc:
		return 3
	case 0xfd:
		return 4
	case 0xfe:
		return 9
	default:
		return 1
	}
}

// LengthEncodedInt decodes a length-encoded integer from the start of b. It
// returns the value, whether it is NULL, and the number of bytes consumed.
// An empty b yields (0, true, 0).
//
// b must hold the whole encoding: a 0xfc, 0xfd or 0xfe prefix followed by
// fewer than 2, 3 or 8 bytes respectively causes an index-out-of-range panic.
func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
	if len(b) == 0 {
		return 0, true, 0
	}

	switch b[0] {
	// 251: NULL
	case 0xfb:
		return 0, true, 1

	// 252: value of following 2
	case 0xfc:
		return uint64(b[1]) | uint64(b[2])<<8, false, 3

	// 253: value of following 3
	case 0xfd:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4

	// 254: value of following 8
	case 0xfe:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
				uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
				uint64(b[7])<<48 | uint64(b[8])<<56,
			false, 9
	}

	// 0-250: value of first byte
	return uint64(b[0]), false, 1
}

// PutLengthEncodedInt returns n encoded as a length-encoded integer in a
// newly allocated slice.
func PutLengthEncodedInt(n uint64) []byte {
	switch {
	case n <= 250:
		return []byte{byte(n)}
	case n <= 0xffff:
		return []byte{0xfc, byte(n), byte(n >> 8)}
	case n <= 0xffffff:
		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
	default:
		return []byte{0xfe, byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
			byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}
	}
}

// LengthEncodedString returns the string read as a bytes slice, whether the
// value is NULL, the number of bytes read and an error, in case the string is
// longer than the input slice.
//
// The returned slice aliases b, with its capacity clipped to its length. When
// the length prefix or the payload is truncated the error is io.EOF and the
// returned count is the number of bytes that would have been needed.
func LengthEncodedString(b []byte) ([]byte, bool, int, error) {
	// Reject a truncated length prefix here; LengthEncodedInt would index
	// past the end of b.
	if len(b) > 0 {
		if need := lengthEncodedIntSize(b[0]); len(b) < need {
			return nil, false, need, io.EOF
		}
	}

	// Get length
	num, isNull, n := LengthEncodedInt(b)
	if num < 1 {
		return b[n:n], isNull, n, nil
	}

	// Check data length. The comparison is done in uint64 so that an absurd
	// declared length cannot overflow int and slip past the check.
	if num > uint64(len(b)-n) {
		return nil, false, n + int(num), io.EOF
	}

	end := n + int(num)
	return b[n:end:end], false, end, nil
}

// SkipLengthEncodedString returns the number of bytes occupied by the
// length-encoded string at the start of b, without returning its contents.
// When the length prefix or the payload is truncated the error is io.EOF and
// the returned count is the number of bytes that would have been needed.
func SkipLengthEncodedString(b []byte) (int, error) {
	// Reject a truncated length prefix here; LengthEncodedInt would index
	// past the end of b.
	if len(b) > 0 {
		if need := lengthEncodedIntSize(b[0]); len(b) < need {
			return need, io.EOF
		}
	}

	// Get length
	num, _, n := LengthEncodedInt(b)
	if num < 1 {
		return n, nil
	}

	// Check data length in uint64 to avoid int overflow on huge lengths.
	if num > uint64(len(b)-n) {
		return n + int(num), io.EOF
	}
	return n + int(num), nil
}

// PutLengthEncodedString returns b prefixed with its length as a
// length-encoded integer, in a newly allocated slice.
func PutLengthEncodedString(b []byte) []byte {
	// 9 is the largest possible size of the length prefix.
	data := make([]byte, 0, len(b)+9)
	data = AppendLengthEncodedInteger(data, uint64(len(b)))
	data = append(data, b...)
	return data
}

// Uint16ToBytes returns n as 2 little-endian bytes.
func Uint16ToBytes(n uint16) []byte {
	return []byte{
		byte(n),
		byte(n >> 8),
	}
}

// Uint32ToBytes returns n as 4 little-endian bytes.
func Uint32ToBytes(n uint32) []byte {
	return []byte{
		byte(n),
		byte(n >> 8),
		byte(n >> 16),
		byte(n >> 24),
	}
}

// Uint64ToBytes returns n as 8 little-endian bytes.
func Uint64ToBytes(n uint64) []byte {
	return []byte{
		byte(n),
		byte(n >> 8),
		byte(n >> 16),
		byte(n >> 24),
		byte(n >> 32),
		byte(n >> 40),
		byte(n >> 48),
		byte(n >> 56),
	}
}

// FormatBinaryDate formats a binary date value of n bytes as "YYYY-MM-DD".
// n must be 0 (the zero date) or 4 (little-endian 16-bit year, month, day).
// An error is returned for any other n, or when data holds fewer than n bytes.
func FormatBinaryDate(n int, data []byte) ([]byte, error) {
	switch n {
	case 0:
		return []byte("0000-00-00"), nil
	case 4:
		if len(data) < n {
			return nil, errors.Errorf("date packet truncated: need %d bytes, got %d", n, len(data))
		}
		return []byte(fmt.Sprintf("%04d-%02d-%02d",
			binary.LittleEndian.Uint16(data[:2]),
			data[2],
			data[3])), nil
	default:
		return nil, errors.Errorf("invalid date packet length %d", n)
	}
}

// FormatBinaryDateTime formats a binary datetime value of n bytes as
// "YYYY-MM-DD hh:mm:ss", with a ".uuuuuu" suffix when n is 11. n must be
// 0 (the zero datetime), 4 (date only), 7 (date and time) or 11 (date, time
// and a little-endian 32-bit fraction). An error is returned for any other n,
// or when data holds fewer than n bytes.
func FormatBinaryDateTime(n int, data []byte) ([]byte, error) {
	switch n {
	case 0:
		return []byte("0000-00-00 00:00:00"), nil
	case 4, 7, 11:
		if len(data) < n {
			return nil, errors.Errorf("datetime packet truncated: need %d bytes, got %d", n, len(data))
		}
	default:
		return nil, errors.Errorf("invalid datetime packet length %d", n)
	}

	year := binary.LittleEndian.Uint16(data[:2])
	month, day := data[2], data[3]

	switch n {
	case 4:
		return []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00", year, month, day)), nil
	case 7:
		return []byte(fmt.Sprintf(
			"%04d-%02d-%02d %02d:%02d:%02d",
			year, month, day,
			data[4], data[5], data[6])), nil
	default: // n == 11
		return []byte(fmt.Sprintf(
			"%04d-%02d-%02d %02d:%02d:%02d.%06d",
			year, month, day,
			data[4], data[5], data[6],
			binary.LittleEndian.Uint32(data[7:11]))), nil
	}
}

// FormatBinaryTime formats a binary time value of n bytes as "[-]hh:mm:ss",
// with a ".uuuuuu" suffix when n is 12. n must be 0 (zero time), 8 or 12.
// data[0] == 1 marks a negative value; the hour field is data[1]*24 + data[5].
// An error is returned for any other n, or when data holds fewer than n bytes.
func FormatBinaryTime(n int, data []byte) ([]byte, error) {
	switch n {
	case 0:
		return []byte("00:00:00"), nil
	case 8, 12:
		if len(data) < n {
			return nil, errors.Errorf("time packet truncated: need %d bytes, got %d", n, len(data))
		}
	default:
		return nil, errors.Errorf("invalid time packet length %d", n)
	}

	// The sign is a string so that a positive value adds nothing to the
	// output; formatting a zero byte with %c would emit a NUL character.
	sign := ""
	if data[0] == 1 {
		sign = "-"
	}

	// NOTE(review): only data[1] contributes to the day count; presumably the
	// bytes data[2:5] are always zero for valid values. Confirm against the
	// protocol specification.
	hours := uint16(data[1])*24 + uint16(data[5])

	if n == 8 {
		return []byte(fmt.Sprintf(
			"%s%02d:%02d:%02d",
			sign, hours, data[6], data[7])), nil
	}
	return []byte(fmt.Sprintf(
		"%s%02d:%02d:%02d.%06d",
		sign, hours, data[6], data[7],
		binary.LittleEndian.Uint32(data[8:12]))), nil
}

var (
	// DONTESCAPE is the EncodeMap value marking a byte that Escape copies
	// through unchanged.
	DONTESCAPE = byte(255)

	// EncodeMap maps every byte either to the character that follows the
	// backslash in its escaped form, or to DONTESCAPE. It is filled by init.
	EncodeMap [256]byte
)

// Escape returns sql with every byte that has an EncodeMap entry replaced by
// a backslash followed by that entry. Only UTF-8 input is supported.
func Escape(sql string) string {
	// Worst case every byte is escaped into two bytes.
	dest := make([]byte, 0, 2*len(sql))

	for _, w := range hack.Slice(sql) {
		c := EncodeMap[w]
		if c == DONTESCAPE {
			dest = append(dest, w)
			continue
		}
		dest = append(dest, '\\', c)
	}

	return string(dest)
}

// GetNetProto returns "unix" when addr contains a "/", and "tcp" otherwise.
func GetNetProto(addr string) string {
	if strings.Contains(addr, "/") {
		return "unix"
	}
	return "tcp"
}

// ErrorEqual returns a boolean indicating whether err1 is equal to err2.
// The causes of both errors are compared: they are equal when they are the
// same value, or when both are non-nil and have the same message.
func ErrorEqual(err1, err2 error) bool {
	e1 := errors.Cause(err1)
	e2 := errors.Cause(err2)

	if e1 == e2 {
		return true
	}
	// The causes differ, so a nil on either side means they are not equal.
	if e1 == nil || e2 == nil {
		return false
	}
	return e1.Error() == e2.Error()
}

// CompareServerVersions parses a and b as semantic versions and returns the
// result of comparing a with b. It returns an error when either string cannot
// be parsed.
func CompareServerVersions(a, b string) (int, error) {
	aVer, err := semver.NewVersion(a)
	if err != nil {
		return 0, fmt.Errorf("cannot parse %q as semver: %w", a, err)
	}

	bVer, err := semver.NewVersion(b)
	if err != nil {
		return 0, fmt.Errorf("cannot parse %q as semver: %w", b, err)
	}

	return aVer.Compare(bVer), nil
}

// encodeRef lists the bytes that Escape rewrites, keyed by the raw byte and
// valued by the character written after the backslash.
var encodeRef = map[byte]byte{
	'\x00': '0',
	'\'':   '\'',
	'"':    '"',
	'\b':   'b',
	'\n':   'n',
	'\r':   'r',
	'\t':   't',
	26:     'Z', // ctl-Z
	'\\':   '\\',
}

// init fills EncodeMap: every byte defaults to DONTESCAPE and the bytes
// listed in encodeRef receive their escape character.
func init() {
	for i := range EncodeMap {
		EncodeMap[i] = DONTESCAPE
	}
	for from, to := range encodeRef {
		EncodeMap[from] = to
	}
}