package notify

import (
	"context"
	"errors"
	"fmt"
	"math"
	"math/rand"
	"net"
	"os"
	"sync"
	"sync/atomic"
	"time"

	"b612.me/starcrypto"
	"b612.me/stario"
	"b612.me/starnet"
)

// ClientCommon is the default Client implementation. It owns one network
// connection, frames outgoing messages through a starnet queue, and routes
// incoming messages either to a pending synchronous waiter or to the
// user-registered link functions.
//
// Create instances with NewClient; several fields (linkFns, alive, wg) are
// initialised there and are required by the methods below.
type ClientCommon struct {
	// alive holds a bool: true between a successful connect and the next stop.
	alive atomic.Value
	// status is the last recorded lifecycle state; guarded by mu.
	status Status
	// byeFromServer is set when the server asked us to stop, so that no
	// goodbye message is sent back on shutdown.
	byeFromServer bool
	conn          net.Conn
	// mu guards status, linkFns, showError and debugMode writes.
	mu sync.Mutex
	// msgID is the atomically incremented source of outgoing message IDs.
	msgID uint64
	queue *starnet.StarQueue
	// stopFn cancels stopCtx; all background goroutines exit on stopCtx.Done().
	stopFn          context.CancelFunc
	stopCtx         context.Context
	parallelNum     int
	maxReadTimeout  time.Duration
	maxWriteTimeout time.Duration
	// keyExchangeFn runs right after connecting unless skipKeyExchange is set.
	keyExchangeFn func(c Client) error
	// linkFns maps a message key to its handler.
	linkFns map[string]func(message *Message)
	// defaultFns, when non-nil, is invoked for every dispatched message.
	defaultFns func(message *Message)
	// msgEn/msgDe are called as fn(SecretKey, payload).
	msgEn func([]byte, []byte) []byte
	msgDe func([]byte, []byte) []byte
	// noFinSyncMsgPool maps message ID -> WaitMsg for requests awaiting a reply.
	noFinSyncMsgPool   sync.Map
	handshakeRsaPubKey []byte
	SecretKey          []byte
	// noFinSyncMsgMaxKeepSeconds bounds how long monitorPool keeps an
	// unanswered wait entry; values <= 0 disable expiry.
	noFinSyncMsgMaxKeepSeconds int
	// lastHeartbeat is the Unix time of the last acknowledged heartbeat.
	lastHeartbeat   int64
	heartbeatPeriod time.Duration
	wg              stario.WaitGroup
	netType         NetType
	showError       bool
	skipKeyExchange bool
	useHeartBeat    bool
	// sequenceDe/sequenceEn deserialize/serialize message bodies.
	sequenceDe func([]byte) (interface{}, error)
	sequenceEn func(interface{}) ([]byte, error)
	debugMode  bool
}

// Connect dials addr on the given network and starts the client.
// It returns an error if the client is already running, if dialing fails,
// or if the initial key exchange fails.
func (c *ClientCommon) Connect(network string, addr string) error {
	return c.connect(func() (net.Conn, error) {
		return net.Dial(network, addr)
	})
}

// DebugMode enables or disables debug output.
func (c *ClientCommon) DebugMode(dmg bool) {
	c.mu.Lock()
	c.debugMode = dmg
	c.mu.Unlock()
}

// IsDebugMode reports whether debug output is enabled.
func (c *ClientCommon) IsDebugMode() bool {
	// Read under the same lock DebugMode writes under.
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.debugMode
}

// ConnectTimeout behaves like Connect but bounds the dial with timeout.
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
	return c.connect(func() (net.Conn, error) {
		return net.DialTimeout(network, addr, timeout)
	})
}

// connect is the shared implementation of Connect and ConnectTimeout.
// The dial happens first so that a failed attempt leaves the client state
// untouched and does not leave behind a fresh context and queue.
func (c *ClientCommon) connect(dial func() (net.Conn, error)) error {
	// Checked assertion: a zero-value atomic.Value yields nil, not a panic.
	if alive, _ := c.alive.Load().(bool); alive {
		return errors.New("client already run")
	}
	conn, err := dial()
	if err != nil {
		return err
	}
	c.stopCtx, c.stopFn = context.WithCancel(context.Background())
	c.queue = starnet.NewQueueCtx(c.stopCtx, 4, math.MaxUint32)
	c.alive.Store(true)
	c.mu.Lock()
	c.status.Alive = true
	c.mu.Unlock()
	c.conn = conn
	// A new connection has not received "bye" yet, even if a previous one did.
	c.byeFromServer = false
	if c.useHeartBeat {
		go c.Heartbeat()
	}
	return c.clientPostInit()
}

// monitorPool periodically expires unanswered synchronous requests and, on
// shutdown, releases every remaining waiter. Closing a waiter's Reply channel
// makes the blocked sender return os.ErrInvalid.
func (c *ClientCommon) monitorPool() {
	for {
		select {
		case <-c.stopCtx.Done():
			c.expireWaits(func(WaitMsg) bool { return true })
			return
		case <-time.After(time.Second * 30):
		}
		if c.noFinSyncMsgMaxKeepSeconds <= 0 {
			continue
		}
		now := time.Now()
		keep := time.Duration(c.noFinSyncMsgMaxKeepSeconds) * time.Second
		c.expireWaits(func(w WaitMsg) bool {
			return w.Time.Add(keep).Before(now)
		})
	}
}

// expireWaits removes every pending wait for which expired returns true and
// closes its Reply channel. Ownership is taken with LoadAndDelete so the
// channel is only closed when dispatchMsg can no longer send on it.
func (c *ClientCommon) expireWaits(expired func(WaitMsg) bool) {
	c.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
		if !expired(v.(WaitMsg)) {
			return true
		}
		if cur, loaded := c.noFinSyncMsgPool.LoadAndDelete(k); loaded {
			close(cur.(WaitMsg).Reply)
		}
		return true
	})
}

// SkipExchangeKey reports whether the key exchange after connecting is skipped.
func (c *ClientCommon) SkipExchangeKey() bool {
	return c.skipKeyExchange
}

// SetSkipExchangeKey controls whether the key exchange after connecting is skipped.
func (c *ClientCommon) SetSkipExchangeKey(val bool) {
	c.skipKeyExchange = val
}

// clientPostInit starts the read and dispatch goroutines and then, unless
// disabled, performs the key exchange. A failed exchange stops the client.
func (c *ClientCommon) clientPostInit() error {
	go c.readMessage()
	go c.loadMessage()
	if c.skipKeyExchange {
		return nil
	}
	if err := c.keyExchangeFn(c); err != nil {
		c.markStopped("key exchange failed", err)
		return err
	}
	return nil
}

// markStopped records the client as dead with the given reason and cancels
// the stop context so every background goroutine exits.
func (c *ClientCommon) markStopped(reason string, err error) {
	c.alive.Store(false)
	c.mu.Lock()
	c.status = Status{
		Alive:  false,
		Reason: reason,
		Err:    err,
	}
	c.mu.Unlock()
	c.stopFn()
}

// NewClient returns a Client with the package defaults: default serializers,
// default encryption functions and keys, AES/RSA key exchange, and a 20 second
// heartbeat. The returned client is not connected.
func NewClient() Client {
	client := &ClientCommon{
		maxReadTimeout:     0,
		maxWriteTimeout:    0,
		sequenceEn:         encode,
		sequenceDe:         Decode,
		keyExchangeFn:      aesRsaHello,
		SecretKey:          defaultAesKey,
		handshakeRsaPubKey: defaultRsaPubKey,
		msgEn:              defaultMsgEn,
		msgDe:              defaultMsgDe,
	}
	client.alive.Store(false)
	// The heartbeat is always on; it is not meant to be user-controllable.
	client.useHeartBeat = true
	client.heartbeatPeriod = time.Second * 20
	client.linkFns = make(map[string]func(*Message))
	client.defaultFns = func(message *Message) {}
	client.wg = stario.NewWaitGroup(0)
	client.stopCtx, client.stopFn = context.WithCancel(context.Background())
	return client
}

// Heartbeat sends a heartbeat every heartbeatPeriod until the client stops.
// Three consecutive failed heartbeats stop the client.
func (c *ClientCommon) Heartbeat() {
	failedCount := 0
	for {
		select {
		case <-c.stopCtx.Done():
			return
		case <-time.After(c.heartbeatPeriod):
		}
		_, err := c.sendWait(TransferMsg{
			ID:    10000,
			Key:   "heartbeat",
			Value: nil,
			Type:  MSG_SYS_WAIT,
		}, time.Second*5)
		if err == nil {
			c.lastHeartbeat = time.Now().Unix()
			failedCount = 0
			// Success: do not fall through into the failure accounting below.
			continue
		}
		if c.debugMode {
			fmt.Println("failed to recv heartbeat,timeout!")
		}
		failedCount++
		if failedCount >= 3 {
			if c.debugMode {
				fmt.Println("heatbeat failed more than 3 times,stop client")
			}
			c.markStopped("heartbeat failed more than 3 times",
				errors.New("heartbeat failed more than 3 times"))
			return
		}
	}
}

// ShowError enables or disables printing of read/decode errors.
func (c *ClientCommon) ShowError(std bool) {
	c.mu.Lock()
	c.showError = std
	c.mu.Unlock()
}

// readMessage reads raw bytes from the connection and feeds them to the
// framing queue until the client stops. A read deadline expiry is not an
// error; any other read error stops the client.
func (c *ClientCommon) readMessage() {
	for {
		select {
		case <-c.stopCtx.Done():
			c.conn.Close()
			return
		default:
		}
		// A fresh buffer per read: it is handed to the queue, which may keep it.
		data := make([]byte, 8192)
		if c.maxReadTimeout != 0 {
			if err := c.conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)); err != nil {
				if c.showError || c.debugMode {
					fmt.Println("client set read deadline error", err)
				}
			}
		}
		readNum, err := c.conn.Read(data)
		// net.Conn wraps the deadline error (e.g. in *net.OpError), so a plain
		// == comparison never matches; errors.Is unwraps it.
		if errors.Is(err, os.ErrDeadlineExceeded) {
			if readNum != 0 {
				c.queue.ParseMessage(data[:readNum], "b612")
			}
			continue
		}
		if err != nil {
			if c.showError || c.debugMode {
				fmt.Println("client read error", err)
			}
			c.markStopped("client read error", err)
			// The next iteration sees stopCtx done, closes the conn and returns.
			continue
		}
		c.queue.ParseMessage(data[:readNum], "b612")
	}
}

// sayGoodBye tells the server this client is leaving.
func (c *ClientCommon) sayGoodBye() error {
	_, err := c.sendWait(TransferMsg{
		ID:    10010,
		Key:   "bye",
		Value: nil,
		Type:  MSG_SYS_WAIT,
	}, time.Second*3)
	return err
}

// loadMessage takes framed messages off the queue and handles each one in its
// own goroutine. On shutdown it says goodbye (unless the server initiated the
// stop) and closes the connection.
func (c *ClientCommon) loadMessage() {
	for {
		select {
		case <-c.stopCtx.Done():
			if !c.byeFromServer {
				// Best effort: the connection is being torn down regardless.
				_ = c.sayGoodBye()
			}
			c.conn.Close()
			return
		case data, ok := <-c.queue.RestoreChan():
			if !ok {
				continue
			}
			c.wg.Add(1)
			go func(data starnet.MsgQueue) {
				defer c.wg.Done()
				c.handleRaw(data)
			}(data)
		}
	}
}

// handleRaw decrypts and decodes one framed message and dispatches it.
// Undecodable or unexpectedly typed payloads are dropped (and reported when
// showError or debugMode is on) instead of panicking.
func (c *ClientCommon) handleRaw(data starnet.MsgQueue) {
	now := time.Now()
	decoded, err := c.sequenceDe(c.msgDe(c.SecretKey, data.Msg))
	if err != nil {
		if c.showError || c.debugMode {
			fmt.Println("client decode data error", err)
		}
		return
	}
	transfer, ok := decoded.(TransferMsg)
	if !ok {
		if c.showError || c.debugMode {
			fmt.Println("client decode data error", fmt.Errorf("unexpected message type %T", decoded))
		}
		return
	}
	message := Message{
		ServerConn:  c,
		TransferMsg: transfer,
		NetType:     NET_CLIENT,
	}
	message.Time = now
	c.dispatchMsg(message)
}

// dispatchMsg routes one decoded message:
//   - system messages go to sysMsg;
//   - reply-type messages are delivered to the matching pending waiter, if any;
//   - everything else (including replies nobody waits for) goes to the link
//     function registered for its key and then to the default function.
func (c *ClientCommon) dispatchMsg(message Message) {
	switch message.TransferMsg.Type {
	case MSG_SYS_WAIT, MSG_SYS:
		c.sysMsg(message)
		return
	case MSG_KEY_CHANGE, MSG_SYS_REPLY, MSG_SYNC_REPLY:
		// LoadAndDelete makes this goroutine the sole owner of the entry, so
		// the buffered send below can neither block nor hit a closed channel.
		if data, ok := c.noFinSyncMsgPool.LoadAndDelete(message.ID); ok {
			data.(WaitMsg).Reply <- message
			return
		}
	}
	// linkFns is written under mu by SetLink; read it under the same lock but
	// release the lock before running user code.
	c.mu.Lock()
	fn, ok := c.linkFns[message.Key]
	c.mu.Unlock()
	if ok && fn != nil {
		fn(&message)
	}
	if c.defaultFns != nil {
		c.defaultFns(&message)
	}
}

// sysMsg handles system messages. Only "bye" is recognised: it is
// acknowledged when the server waits for a reply, and the client is stopped.
func (c *ClientCommon) sysMsg(message Message) {
	switch message.Key {
	case "bye":
		if message.TransferMsg.Type == MSG_SYS_WAIT {
			c.byeFromServer = true
			// Best effort: the client stops whether or not the ack is sent.
			_ = message.Reply(nil)
		}
		c.markStopped("recv stop signal from server", nil)
	}
}

// SetDefaultLink sets the function invoked for every dispatched message.
func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) {
	c.defaultFns = fn
}

// SetLink registers fn as the handler for messages with the given key.
func (c *ClientCommon) SetLink(key string, fn func(*Message)) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.linkFns[key] = fn
}

// send serializes, encrypts, frames and writes msg. For message types that
// expect an answer (MSG_SYNC_ASK, MSG_KEY_CHANGE, MSG_SYS_WAIT) it returns a
// WaitMsg whose Reply channel receives the answer; for other types the
// returned WaitMsg is the zero value.
//
// A fresh ID is assigned unless msg is a reply-type message that already
// carries a non-zero ID.
func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) {
	if c.conn == nil || c.queue == nil {
		return WaitMsg{}, errors.New("client not connected")
	}
	keepsID := msg.Type == MSG_SYNC_REPLY || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_REPLY
	if !keepsID || msg.ID == 0 {
		msg.ID = atomic.AddUint64(&c.msgID, 1)
	}
	data, err := c.sequenceEn(msg)
	if err != nil {
		return WaitMsg{}, err
	}
	data = c.msgEn(c.SecretKey, data)
	data = c.queue.BuildMessage(data)

	var wait WaitMsg
	expectsReply := msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT
	if expectsReply {
		wait.Time = time.Now()
		wait.TransferMsg = msg
		wait.Reply = make(chan Message, 1)
		// Register before writing: the reply may arrive before Write returns.
		c.noFinSyncMsgPool.Store(msg.ID, wait)
	}
	if err = c.write(data); err != nil {
		if expectsReply {
			c.noFinSyncMsgPool.Delete(msg.ID)
		}
		return WaitMsg{}, err
	}
	return wait, nil
}

// write applies the write deadline, if configured, and writes data.
func (c *ClientCommon) write(data []byte) error {
	if c.maxWriteTimeout != 0 {
		if err := c.conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)); err != nil {
			return err
		}
	}
	_, err := c.conn.Write(data)
	return err
}

// Send sends value under key without waiting for a reply.
func (c *ClientCommon) Send(key string, value MsgVal) error {
	_, err := c.send(TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_ASYNC,
	})
	return err
}

// sendWait sends msg and waits for its reply. A zero timeout waits without a
// time limit. It returns os.ErrDeadlineExceeded on timeout, a "service
// shutdown" error when the client stops, and os.ErrInvalid when the wait was
// released without a reply.
func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) {
	wait, err := c.send(msg)
	if err != nil {
		return Message{}, err
	}
	var expired <-chan time.Time // nil (never fires) when there is no timeout
	if timeout != 0 {
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		expired = timer.C
	}
	return c.awaitReply(wait, expired, nil)
}

// sendCtx sends msg and waits for its reply until ctx is done. A nil ctx is
// treated as context.Background(). When ctx is done it returns
// os.ErrDeadlineExceeded.
func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) {
	wait, err := c.send(msg)
	if err != nil {
		return Message{}, err
	}
	if ctx == nil {
		ctx = context.Background()
	}
	return c.awaitReply(wait, nil, ctx.Done())
}

// awaitReply blocks until the reply for wait arrives, expired or cancelled
// fires, or the client stops. Either of expired and cancelled may be nil.
//
// When giving up, the wait entry is only removed from the pool; its Reply
// channel is deliberately not closed here, because dispatchMsg may
// concurrently be sending on it.
func (c *ClientCommon) awaitReply(wait WaitMsg, expired <-chan time.Time, cancelled <-chan struct{}) (Message, error) {
	select {
	case <-expired:
		c.noFinSyncMsgPool.Delete(wait.TransferMsg.ID)
		return Message{}, os.ErrDeadlineExceeded
	case <-cancelled:
		c.noFinSyncMsgPool.Delete(wait.TransferMsg.ID)
		return Message{}, os.ErrDeadlineExceeded
	case <-c.stopCtx.Done():
		c.noFinSyncMsgPool.Delete(wait.TransferMsg.ID)
		return Message{}, errors.New("service shutdown")
	case reply, ok := <-wait.Reply:
		if !ok {
			return Message{}, os.ErrInvalid
		}
		return reply, nil
	}
}

// SendObjCtx serializes val with the client's serializer, sends it under key
// and waits for the reply until ctx is done.
func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) {
	data, err := c.sequenceEn(val)
	if err != nil {
		return Message{}, err
	}
	return c.sendCtx(TransferMsg{
		Key:   key,
		Value: data,
		Type:  MSG_SYNC_ASK,
	}, ctx)
}

// SendObj serializes val with the client's serializer and sends it under key
// without waiting for a reply.
func (c *ClientCommon) SendObj(key string, val interface{}) error {
	// Use the configured serializer, like SendObjCtx and SendWaitObj do.
	data, err := c.sequenceEn(val)
	if err != nil {
		return err
	}
	_, err = c.send(TransferMsg{
		Key:   key,
		Value: data,
		Type:  MSG_ASYNC,
	})
	return err
}

// SendCtx sends value under key and waits for the reply until ctx is done.
func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) {
	return c.sendCtx(TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_SYNC_ASK,
	}, ctx)
}

// SendWait sends value under key and waits up to timeout for the reply.
// A zero timeout waits without a time limit.
func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) {
	return c.sendWait(TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_SYNC_ASK,
	}, timeout)
}

// SendWaitObj serializes value with the client's serializer, sends it under
// key and waits up to timeout for the reply.
func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) {
	data, err := c.sequenceEn(value)
	if err != nil {
		return Message{}, err
	}
	return c.SendWait(key, data, timeout)
}

// Reply answers message m with value.
func (c *ClientCommon) Reply(m Message, value MsgVal) error {
	return m.Reply(value)
}

// ExchangeKey sends newKey to the server, RSA-encrypted with the handshake
// public key, and switches SecretKey to it once the server answers "success".
func (c *ClientCommon) ExchangeKey(newKey []byte) error {
	pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey)
	if err != nil {
		return err
	}
	newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey)
	if err != nil {
		return err
	}
	data, err := c.sendWait(TransferMsg{
		ID:    19961127,
		Key:   "sirius",
		Value: newSendKey,
		Type:  MSG_KEY_CHANGE,
	}, time.Second*10)
	if err != nil {
		return err
	}
	if string(data.Value) != "success" {
		return errors.New("cannot exchange new aes-key")
	}
	c.SecretKey = newKey
	// NOTE(review): presumably gives the server time to switch keys before the
	// next message is sent - confirm against the server implementation.
	time.Sleep(time.Millisecond * 100)
	return nil
}

// aesRsaHello is the default key exchange: it derives a new AES key and
// negotiates it through c.ExchangeKey.
//
// NOTE(review): the key material comes from math/rand and the clock, which is
// not cryptographically secure; consider crypto/rand.
func aesRsaHello(c Client) error {
	newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me"))
	newAesKey = []byte(starcrypto.Md5Str(newAesKey))
	return c.ExchangeKey(newAesKey)
}

// GetMsgEn returns the message encryption function.
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
	return c.msgEn
}

// SetMsgEn sets the message encryption function, called as fn(key, data).
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
	c.msgEn = fn
}

// GetMsgDe returns the message decryption function.
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
	return c.msgDe
}

// SetMsgDe sets the message decryption function, called as fn(key, data).
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
	c.msgDe = fn
}

// HeartbeatPeroid returns the interval between heartbeats.
func (c *ClientCommon) HeartbeatPeroid() time.Duration {
	return c.heartbeatPeriod
}

// SetHeartbeatPeroid sets the interval between heartbeats.
func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
	c.heartbeatPeriod = duration
}

// GetSecretKey returns the key passed to the message encryption functions.
func (c *ClientCommon) GetSecretKey() []byte {
	return c.SecretKey
}

// SetSecretKey sets the key passed to the message encryption functions.
func (c *ClientCommon) SetSecretKey(key []byte) {
	c.SecretKey = key
}

// RsaPubKey returns the RSA public key used during key exchange.
func (c *ClientCommon) RsaPubKey() []byte {
	return c.handshakeRsaPubKey
}

// SetRsaPubKey sets the RSA public key used during key exchange.
func (c *ClientCommon) SetRsaPubKey(key []byte) {
	c.handshakeRsaPubKey = key
}

// Stop stops a running client. It is a no-op when the client is not running
// and always returns nil.
func (c *ClientCommon) Stop() error {
	if alive, _ := c.alive.Load().(bool); !alive {
		return nil
	}
	c.markStopped("recv stop signal from user", nil)
	return nil
}

// StopMonitorChan returns a channel that is closed when the client stops.
func (c *ClientCommon) StopMonitorChan() <-chan struct{} {
	return c.stopCtx.Done()
}

// Status returns the last recorded lifecycle status.
func (c *ClientCommon) Status() Status {
	// Read under the same lock the status is written under.
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.status
}

// GetSequenceEn returns the serializer used for messages and objects.
func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
	return c.sequenceEn
}

// SetSequenceEn sets the serializer used for messages and objects.
func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
	c.sequenceEn = fn
}

// GetSequenceDe returns the deserializer used for incoming messages.
func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) {
	return c.sequenceDe
}

// SetSequenceDe sets the deserializer used for incoming messages.
func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
	c.sequenceDe = fn
}