You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
615 lines
14 KiB
Go
615 lines
14 KiB
Go
package notify
|
|
|
|
import (
	"context"
	crand "crypto/rand"
	"errors"
	"fmt"
	"math"
	"math/rand"
	"net"
	"os"
	"sync"
	"sync/atomic"
	"time"

	"b612.me/starcrypto"
	"b612.me/stario"
)
|
|
|
|
type ClientCommon struct {
|
|
alive atomic.Value
|
|
status Status
|
|
byeFromServer bool
|
|
conn net.Conn
|
|
mu sync.Mutex
|
|
msgID uint64
|
|
queue *stario.StarQueue
|
|
stopFn context.CancelFunc
|
|
stopCtx context.Context
|
|
parallelNum int
|
|
maxReadTimeout time.Duration
|
|
maxWriteTimeout time.Duration
|
|
keyExchangeFn func(c Client) error
|
|
linkFns map[string]func(message *Message)
|
|
defaultFns func(message *Message)
|
|
msgEn func([]byte, []byte) []byte
|
|
msgDe func([]byte, []byte) []byte
|
|
noFinSyncMsgPool sync.Map
|
|
handshakeRsaPubKey []byte
|
|
SecretKey []byte
|
|
noFinSyncMsgMaxKeepSeconds int
|
|
lastHeartbeat int64
|
|
heartbeatPeriod time.Duration
|
|
wg stario.WaitGroup
|
|
netType NetType
|
|
showError bool
|
|
skipKeyExchange bool
|
|
useHeartBeat bool
|
|
sequenceDe func([]byte) (interface{}, error)
|
|
sequenceEn func(interface{}) ([]byte, error)
|
|
debugMode bool
|
|
}
|
|
|
|
func (c *ClientCommon) Connect(network string, addr string) error {
|
|
if c.alive.Load().(bool) {
|
|
return errors.New("client already run")
|
|
}
|
|
c.stopCtx, c.stopFn = context.WithCancel(context.Background())
|
|
c.queue = stario.NewQueueCtx(c.stopCtx, 4, math.MaxUint32)
|
|
conn, err := net.Dial(network, addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.alive.Store(true)
|
|
c.status.Alive = true
|
|
c.conn = conn
|
|
if c.useHeartBeat {
|
|
go c.Heartbeat()
|
|
}
|
|
return c.clientPostInit()
|
|
}
|
|
|
|
func (c *ClientCommon) DebugMode(dmg bool) {
|
|
c.mu.Lock()
|
|
c.debugMode = dmg
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
func (c *ClientCommon) IsDebugMode() bool {
|
|
return c.debugMode
|
|
}
|
|
|
|
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
|
|
if c.alive.Load().(bool) {
|
|
return errors.New("client already run")
|
|
}
|
|
c.stopCtx, c.stopFn = context.WithCancel(context.Background())
|
|
c.queue = stario.NewQueueCtx(c.stopCtx, 4, math.MaxUint32)
|
|
conn, err := net.DialTimeout(network, addr, timeout)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.alive.Store(true)
|
|
c.status.Alive = true
|
|
c.conn = conn
|
|
if c.useHeartBeat {
|
|
go c.Heartbeat()
|
|
}
|
|
return c.clientPostInit()
|
|
}
|
|
|
|
func (c *ClientCommon) monitorPool() {
|
|
for {
|
|
select {
|
|
case <-c.stopCtx.Done():
|
|
c.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
|
|
data := v.(WaitMsg)
|
|
close(data.Reply)
|
|
c.noFinSyncMsgPool.Delete(k)
|
|
return true
|
|
})
|
|
return
|
|
case <-time.After(time.Second * 30):
|
|
}
|
|
now := time.Now()
|
|
if c.noFinSyncMsgMaxKeepSeconds > 0 {
|
|
c.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
|
|
data := v.(WaitMsg)
|
|
if data.Time.Add(time.Duration(c.noFinSyncMsgMaxKeepSeconds) * time.Second).Before(now) {
|
|
close(data.Reply)
|
|
c.noFinSyncMsgPool.Delete(k)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) SkipExchangeKey() bool {
|
|
return c.skipKeyExchange
|
|
}
|
|
|
|
func (c *ClientCommon) SetSkipExchangeKey(val bool) {
|
|
c.skipKeyExchange = val
|
|
}
|
|
|
|
func (c *ClientCommon) clientPostInit() error {
|
|
go c.readMessage()
|
|
go c.loadMessage()
|
|
if !c.skipKeyExchange {
|
|
err := c.keyExchangeFn(c)
|
|
if err != nil {
|
|
c.alive.Store(false)
|
|
c.mu.Lock()
|
|
c.status = Status{
|
|
Alive: false,
|
|
Reason: "key exchange failed",
|
|
Err: err,
|
|
}
|
|
c.mu.Unlock()
|
|
c.stopFn()
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
func NewClient() Client {
|
|
var client = ClientCommon{
|
|
maxReadTimeout: 0,
|
|
maxWriteTimeout: 0,
|
|
sequenceEn: encode,
|
|
sequenceDe: Decode,
|
|
keyExchangeFn: aesRsaHello,
|
|
SecretKey: defaultAesKey,
|
|
handshakeRsaPubKey: defaultRsaPubKey,
|
|
msgEn: defaultMsgEn,
|
|
msgDe: defaultMsgDe,
|
|
}
|
|
client.alive.Store(false)
|
|
//heartbeat should not controlable for user
|
|
client.useHeartBeat = true
|
|
client.heartbeatPeriod = time.Second * 20
|
|
client.linkFns = make(map[string]func(*Message))
|
|
client.defaultFns = func(message *Message) {
|
|
return
|
|
}
|
|
client.wg = stario.NewWaitGroup(0)
|
|
client.stopCtx, client.stopFn = context.WithCancel(context.Background())
|
|
return &client
|
|
}
|
|
|
|
func (c *ClientCommon) Heartbeat() {
|
|
failedCount := 0
|
|
for {
|
|
select {
|
|
case <-c.stopCtx.Done():
|
|
return
|
|
case <-time.After(c.heartbeatPeriod):
|
|
}
|
|
_, err := c.sendWait(TransferMsg{
|
|
ID: 10000,
|
|
Key: "heartbeat",
|
|
Value: nil,
|
|
Type: MSG_SYS_WAIT,
|
|
}, time.Second*5)
|
|
if err == nil {
|
|
c.lastHeartbeat = time.Now().Unix()
|
|
failedCount = 0
|
|
}
|
|
if c.debugMode {
|
|
fmt.Println("failed to recv heartbeat,timeout!")
|
|
}
|
|
failedCount++
|
|
if failedCount >= 3 {
|
|
if c.debugMode {
|
|
fmt.Println("heatbeat failed more than 3 times,stop client")
|
|
}
|
|
c.alive.Store(false)
|
|
c.mu.Lock()
|
|
c.status = Status{
|
|
Alive: false,
|
|
Reason: "heartbeat failed more than 3 times",
|
|
Err: errors.New("heartbeat failed more than 3 times"),
|
|
}
|
|
c.mu.Unlock()
|
|
c.stopFn()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) ShowError(std bool) {
|
|
c.mu.Lock()
|
|
c.showError = std
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
func (c *ClientCommon) readMessage() {
|
|
for {
|
|
select {
|
|
case <-c.stopCtx.Done():
|
|
c.conn.Close()
|
|
return
|
|
default:
|
|
}
|
|
data := make([]byte, 8192)
|
|
if c.maxReadTimeout.Seconds() != 0 {
|
|
if err := c.conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)); err != nil {
|
|
//TODO:ALERT
|
|
}
|
|
}
|
|
readNum, err := c.conn.Read(data)
|
|
if err == os.ErrDeadlineExceeded {
|
|
if readNum != 0 {
|
|
c.queue.ParseMessage(data[:readNum], "b612")
|
|
}
|
|
continue
|
|
}
|
|
if err != nil {
|
|
if c.showError || c.debugMode {
|
|
fmt.Println("client read error", err)
|
|
}
|
|
c.alive.Store(false)
|
|
c.mu.Lock()
|
|
c.status = Status{
|
|
Alive: false,
|
|
Reason: "client read error",
|
|
Err: err,
|
|
}
|
|
c.mu.Unlock()
|
|
c.stopFn()
|
|
continue
|
|
}
|
|
c.queue.ParseMessage(data[:readNum], "b612")
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) sayGoodBye() error {
|
|
_, err := c.sendWait(TransferMsg{
|
|
ID: 10010,
|
|
Key: "bye",
|
|
Value: nil,
|
|
Type: MSG_SYS_WAIT,
|
|
}, time.Second*3)
|
|
return err
|
|
}
|
|
|
|
func (c *ClientCommon) loadMessage() {
|
|
for {
|
|
select {
|
|
case <-c.stopCtx.Done():
|
|
//say goodbye
|
|
if !c.byeFromServer {
|
|
c.sayGoodBye()
|
|
}
|
|
c.conn.Close()
|
|
return
|
|
case data, ok := <-c.queue.RestoreChan():
|
|
if !ok {
|
|
continue
|
|
}
|
|
c.wg.Add(1)
|
|
go func(data stario.MsgQueue) {
|
|
defer c.wg.Done()
|
|
//fmt.Println("c received:", float64(time.Now().UnixNano()-nowd)/1000000)
|
|
now := time.Now()
|
|
//transfer to Msg
|
|
msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data.Msg))
|
|
if err != nil {
|
|
if c.showError || c.debugMode {
|
|
fmt.Println("client decode data error", err)
|
|
}
|
|
return
|
|
}
|
|
message := Message{
|
|
ServerConn: c,
|
|
TransferMsg: msg.(TransferMsg),
|
|
NetType: NET_CLIENT,
|
|
}
|
|
message.Time = now
|
|
c.dispatchMsg(message)
|
|
}(data)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) dispatchMsg(message Message) {
|
|
switch message.TransferMsg.Type {
|
|
case MSG_SYS_WAIT:
|
|
fallthrough
|
|
case MSG_SYS:
|
|
c.sysMsg(message)
|
|
return
|
|
case MSG_KEY_CHANGE:
|
|
fallthrough
|
|
case MSG_SYS_REPLY:
|
|
fallthrough
|
|
case MSG_SYNC_REPLY:
|
|
data, ok := c.noFinSyncMsgPool.Load(message.ID)
|
|
if ok {
|
|
wait := data.(WaitMsg)
|
|
wait.Reply <- message
|
|
c.noFinSyncMsgPool.Delete(message.ID)
|
|
return
|
|
}
|
|
//return
|
|
fallthrough
|
|
default:
|
|
}
|
|
callFn := func(fn func(*Message)) {
|
|
fn(&message)
|
|
}
|
|
fn, ok := c.linkFns[message.Key]
|
|
if ok {
|
|
callFn(fn)
|
|
}
|
|
if c.defaultFns != nil {
|
|
callFn(c.defaultFns)
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) sysMsg(message Message) {
|
|
switch message.Key {
|
|
case "bye":
|
|
if message.TransferMsg.Type == MSG_SYS_WAIT {
|
|
//fmt.Println("recv stop signal from server")
|
|
c.byeFromServer = true
|
|
message.Reply(nil)
|
|
}
|
|
c.alive.Store(false)
|
|
c.mu.Lock()
|
|
c.status = Status{
|
|
Alive: false,
|
|
Reason: "recv stop signal from server",
|
|
Err: nil,
|
|
}
|
|
c.mu.Unlock()
|
|
c.stopFn()
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) {
|
|
c.defaultFns = fn
|
|
}
|
|
|
|
func (c *ClientCommon) SetLink(key string, fn func(*Message)) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.linkFns[key] = fn
|
|
}
|
|
|
|
func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) {
|
|
var wait WaitMsg
|
|
if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
|
|
msg.ID = atomic.AddUint64(&c.msgID, 1)
|
|
}
|
|
data, err := c.sequenceEn(msg)
|
|
if err != nil {
|
|
return WaitMsg{}, err
|
|
}
|
|
data = c.msgEn(c.SecretKey, data)
|
|
data = c.queue.BuildMessage(data)
|
|
if c.maxWriteTimeout.Seconds() != 0 {
|
|
c.conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
|
|
}
|
|
_, err = c.conn.Write(data)
|
|
if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT) {
|
|
wait.Time = time.Now()
|
|
wait.TransferMsg = msg
|
|
wait.Reply = make(chan Message, 1)
|
|
c.noFinSyncMsgPool.Store(msg.ID, wait)
|
|
}
|
|
return wait, err
|
|
}
|
|
|
|
func (c *ClientCommon) Send(key string, value MsgVal) error {
|
|
_, err := c.send(TransferMsg{
|
|
Key: key,
|
|
Value: value,
|
|
Type: MSG_ASYNC,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) {
|
|
data, err := c.send(msg)
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
if timeout.Seconds() == 0 {
|
|
msg, ok := <-data.Reply
|
|
if !ok {
|
|
return msg, os.ErrInvalid
|
|
}
|
|
return msg, nil
|
|
}
|
|
select {
|
|
case <-time.After(timeout):
|
|
close(data.Reply)
|
|
c.noFinSyncMsgPool.Delete(data.TransferMsg.ID)
|
|
return Message{}, os.ErrDeadlineExceeded
|
|
case <-c.stopCtx.Done():
|
|
return Message{}, errors.New("service shutdown")
|
|
case msg, ok := <-data.Reply:
|
|
if !ok {
|
|
return msg, os.ErrInvalid
|
|
}
|
|
return msg, nil
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) {
|
|
data, err := c.send(msg)
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
close(data.Reply)
|
|
c.noFinSyncMsgPool.Delete(data.TransferMsg.ID)
|
|
return Message{}, os.ErrDeadlineExceeded
|
|
case <-c.stopCtx.Done():
|
|
return Message{}, errors.New("service shutdown")
|
|
case msg, ok := <-data.Reply:
|
|
if !ok {
|
|
return msg, os.ErrInvalid
|
|
}
|
|
return msg, nil
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) {
|
|
data, err := c.sequenceEn(val)
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
return c.sendCtx(TransferMsg{
|
|
Key: key,
|
|
Value: data,
|
|
Type: MSG_SYNC_ASK,
|
|
}, ctx)
|
|
}
|
|
|
|
func (c *ClientCommon) SendObj(key string, val interface{}) error {
|
|
data, err := encode(val)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = c.send(TransferMsg{
|
|
Key: key,
|
|
Value: data,
|
|
Type: MSG_ASYNC,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) {
|
|
return c.sendCtx(TransferMsg{
|
|
Key: key,
|
|
Value: value,
|
|
Type: MSG_SYNC_ASK,
|
|
}, ctx)
|
|
}
|
|
|
|
func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) {
|
|
return c.sendWait(TransferMsg{
|
|
Key: key,
|
|
Value: value,
|
|
Type: MSG_SYNC_ASK,
|
|
}, timeout)
|
|
}
|
|
|
|
func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) {
|
|
data, err := c.sequenceEn(value)
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
return c.SendWait(key, data, timeout)
|
|
}
|
|
|
|
func (c *ClientCommon) Reply(m Message, value MsgVal) error {
|
|
return m.Reply(value)
|
|
}
|
|
|
|
func (c *ClientCommon) ExchangeKey(newKey []byte) error {
|
|
pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
data, err := c.sendWait(TransferMsg{
|
|
ID: 19961127,
|
|
Key: "sirius",
|
|
Value: newSendKey,
|
|
Type: MSG_KEY_CHANGE,
|
|
}, time.Second*10)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if string(data.Value) != "success" {
|
|
return errors.New("cannot exchange new aes-key")
|
|
}
|
|
c.SecretKey = newKey
|
|
time.Sleep(time.Millisecond * 100)
|
|
return nil
|
|
}
|
|
|
|
func aesRsaHello(c Client) error {
|
|
newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me"))
|
|
newAesKey = []byte(starcrypto.Md5Str(newAesKey))
|
|
return c.ExchangeKey(newAesKey)
|
|
}
|
|
|
|
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
|
|
return c.msgEn
|
|
}
|
|
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
|
|
c.msgEn = fn
|
|
}
|
|
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
|
|
return c.msgDe
|
|
}
|
|
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
|
|
c.msgDe = fn
|
|
}
|
|
|
|
func (c *ClientCommon) HeartbeatPeroid() time.Duration {
|
|
return c.heartbeatPeriod
|
|
}
|
|
func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
|
|
c.heartbeatPeriod = duration
|
|
}
|
|
|
|
func (c *ClientCommon) GetSecretKey() []byte {
|
|
return c.SecretKey
|
|
}
|
|
func (c *ClientCommon) SetSecretKey(key []byte) {
|
|
c.SecretKey = key
|
|
}
|
|
func (c *ClientCommon) RsaPubKey() []byte {
|
|
return c.handshakeRsaPubKey
|
|
}
|
|
func (c *ClientCommon) SetRsaPubKey(key []byte) {
|
|
c.handshakeRsaPubKey = key
|
|
}
|
|
func (c *ClientCommon) Stop() error {
|
|
if !c.alive.Load().(bool) {
|
|
return nil
|
|
}
|
|
c.alive.Store(false)
|
|
c.mu.Lock()
|
|
c.status = Status{
|
|
Alive: false,
|
|
Reason: "recv stop signal from user",
|
|
Err: nil,
|
|
}
|
|
c.mu.Unlock()
|
|
c.stopFn()
|
|
return nil
|
|
}
|
|
func (c *ClientCommon) StopMonitorChan() <-chan struct{} {
|
|
return c.stopCtx.Done()
|
|
}
|
|
|
|
func (c *ClientCommon) Status() Status {
|
|
return c.status
|
|
}
|
|
|
|
func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
|
|
return c.sequenceEn
|
|
}
|
|
func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
|
|
c.sequenceEn = fn
|
|
}
|
|
func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) {
|
|
return c.sequenceDe
|
|
}
|
|
func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
|
|
c.sequenceDe = fn
|
|
}
|