You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

615 lines
14 KiB
Go

package notify
import (
"b612.me/starcrypto"
"b612.me/stario"
"b612.me/starnet"
"context"
"errors"
"fmt"
"math/rand"
"net"
"os"
"sync"
"sync/atomic"
"time"
)
// ClientCommon is the client-side connection state: the network connection,
// the framing queue, the registered message handlers, the pending-reply pool
// and the encoding/encryption hooks used by the methods in this file.
type ClientCommon struct {
// alive holds a bool: stored true once a dial succeeds and false on every stop path.
alive atomic.Value
// status is the most recent Status; the stop paths replace it while holding mu.
status Status
// byeFromServer is set when the server sent a "bye" of type MSG_SYS_WAIT;
// loadMessage then skips sending its own goodbye.
byeFromServer bool
// conn is the connection obtained from net.Dial / net.DialTimeout.
conn net.Conn
// mu is held when writing status, debugMode, showError and linkFns.
mu sync.Mutex
// msgID is incremented atomically in send to number outgoing messages.
msgID uint64
// queue frames outgoing payloads (BuildMessage) and reassembles incoming bytes (ParseMessage).
queue *starnet.StarQueue
// stopFn cancels stopCtx; the background loops exit when stopCtx is done.
stopFn context.CancelFunc
stopCtx context.Context
// parallelNum is not referenced by the code in this file.
parallelNum int
// maxReadTimeout / maxWriteTimeout, when non-zero, are applied as per-call deadlines on conn.
maxReadTimeout time.Duration
maxWriteTimeout time.Duration
// keyExchangeFn is invoked by clientPostInit unless skipKeyExchange is set.
keyExchangeFn func(c Client) error
// linkFns maps a message key to its handler.
linkFns map[string]func(message *Message)
// defaultFns, when non-nil, is called for every dispatched message that was not consumed as a reply.
defaultFns func(message *Message)
// msgEn / msgDe are called as fn(SecretKey, data) on outgoing / incoming payloads.
msgEn func([]byte, []byte) []byte
msgDe func([]byte, []byte) []byte
// noFinSyncMsgPool maps a message ID to the WaitMsg still awaiting its reply.
noFinSyncMsgPool sync.Map
// handshakeRsaPubKey is decoded in ExchangeKey to encrypt the new secret key.
handshakeRsaPubKey []byte
// SecretKey is the key passed to msgEn and msgDe.
SecretKey []byte
// noFinSyncMsgMaxKeepSeconds is the age after which monitorPool drops a pending entry; values <= 0 disable the sweep.
noFinSyncMsgMaxKeepSeconds int
// lastHeartbeat is the Unix time (seconds) of the last heartbeat that got a reply.
lastHeartbeat int64
// heartbeatPeriod is the wait between heartbeat attempts.
heartbeatPeriod time.Duration
// wg counts the per-message goroutines started by loadMessage.
wg stario.WaitGroup
// netType is not referenced by the code in this file.
netType NetType
// showError enables printing of read/decode errors.
showError bool
// skipKeyExchange makes clientPostInit skip keyExchangeFn.
skipKeyExchange bool
// useHeartBeat makes Connect start the Heartbeat goroutine.
useHeartBeat bool
// sequenceDe / sequenceEn deserialize / serialize values exchanged with the peer.
sequenceDe func([]byte) (interface{}, error)
sequenceEn func(interface{}) ([]byte, error)
// debugMode enables extra diagnostic prints.
debugMode bool
}
// Connect dials addr on the given network without a dial timeout and starts
// the client's background processing.
//
// It returns an error if the client is already running, if the dial fails, or
// if the post-connect key exchange fails.
func (c *ClientCommon) Connect(network string, addr string) error {
	// net.DialTimeout with a zero timeout behaves like net.Dial, so the whole
	// connection setup lives in one place instead of being duplicated here.
	return c.ConnectTimeout(network, addr, 0)
}
// DebugMode switches the client's diagnostic output on or off.
// The flag is written while holding the client's mutex.
func (c *ClientCommon) DebugMode(dmg bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.debugMode = dmg
}
// IsDebugMode reports whether diagnostic output is enabled.
// NOTE(review): the flag is read without taking c.mu, unlike the write in DebugMode.
func (c *ClientCommon) IsDebugMode() bool {
return c.debugMode
}
// ConnectTimeout dials addr on the given network, giving up after timeout
// (a zero timeout means no dial timeout), and starts the client's background
// processing.
//
// It returns an error if the client is already running, if the dial fails, or
// if the post-connect key exchange fails.
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
	if c.alive.Load().(bool) {
		return errors.New("client already run")
	}
	// Dial first: the stop context and the queue bound to it are only created
	// once a connection exists, so a failed dial does not leave behind a
	// context that nobody will ever cancel.
	conn, err := net.DialTimeout(network, addr, timeout)
	if err != nil {
		return err
	}
	c.stopCtx, c.stopFn = context.WithCancel(context.Background())
	c.queue = starnet.NewQueueCtx(c.stopCtx, 4)
	c.alive.Store(true)
	// Every other writer of c.status holds c.mu; do the same here.
	c.mu.Lock()
	c.status.Alive = true
	c.mu.Unlock()
	c.conn = conn
	if c.useHeartBeat {
		go c.Heartbeat()
	}
	return c.clientPostInit()
}
// monitorPool periodically sweeps the pending-reply pool.
//
// Every 30 seconds, if noFinSyncMsgMaxKeepSeconds is positive, entries older
// than that many seconds have their reply channel closed and are removed.
// When the stop context is done, every remaining entry is closed and removed
// and the loop exits.
func (c *ClientCommon) monitorPool() {
	// drop closes the waiter's reply channel and forgets the entry.
	drop := func(id interface{}, pending WaitMsg) {
		close(pending.Reply)
		c.noFinSyncMsgPool.Delete(id)
	}
	for {
		select {
		case <-c.stopCtx.Done():
			c.noFinSyncMsgPool.Range(func(id, entry interface{}) bool {
				drop(id, entry.(WaitMsg))
				return true
			})
			return
		case <-time.After(30 * time.Second):
		}
		sweepAt := time.Now()
		if c.noFinSyncMsgMaxKeepSeconds <= 0 {
			continue
		}
		c.noFinSyncMsgPool.Range(func(id, entry interface{}) bool {
			pending := entry.(WaitMsg)
			deadline := pending.Time.Add(time.Duration(c.noFinSyncMsgMaxKeepSeconds) * time.Second)
			if deadline.Before(sweepAt) {
				drop(id, pending)
			}
			return true
		})
	}
}
// SkipExchangeKey reports whether the key exchange step is skipped after connecting.
func (c *ClientCommon) SkipExchangeKey() bool {
return c.skipKeyExchange
}
// SetSkipExchangeKey sets whether clientPostInit skips the key exchange step.
func (c *ClientCommon) SetSkipExchangeKey(val bool) {
c.skipKeyExchange = val
}
// clientPostInit starts the read and dispatch goroutines and then, unless
// skipKeyExchange is set, runs the configured key exchange.
//
// If the key exchange fails the client is marked not alive, its status is
// set to "key exchange failed", the stop context is cancelled and the error
// is returned.
func (c *ClientCommon) clientPostInit() error {
	go c.readMessage()
	go c.loadMessage()
	if c.skipKeyExchange {
		return nil
	}
	exchangeErr := c.keyExchangeFn(c)
	if exchangeErr == nil {
		return nil
	}
	c.alive.Store(false)
	failure := Status{
		Alive:  false,
		Reason: "key exchange failed",
		Err:    exchangeErr,
	}
	c.mu.Lock()
	c.status = failure
	c.mu.Unlock()
	c.stopFn()
	return exchangeErr
}
// NewClient returns a client configured with the package defaults: the
// default encoder/decoder, default message encryption hooks and keys, the
// aesRsaHello key exchange, a 20 second heartbeat, an empty handler table and
// a no-op default handler. The client starts in the not-alive state.
func NewClient() Client {
	ctx, cancel := context.WithCancel(context.Background())
	client := &ClientCommon{
		maxReadTimeout:     0,
		maxWriteTimeout:    0,
		sequenceEn:         encode,
		sequenceDe:         Decode,
		keyExchangeFn:      aesRsaHello,
		SecretKey:          defaultAesKey,
		handshakeRsaPubKey: defaultRsaPubKey,
		msgEn:              defaultMsgEn,
		msgDe:              defaultMsgDe,
		// The heartbeat is always on; it is not meant to be user-controllable.
		useHeartBeat:    true,
		heartbeatPeriod: 20 * time.Second,
		linkFns:         make(map[string]func(*Message)),
		defaultFns:      func(*Message) {},
		wg:              stario.NewWaitGroup(0),
		stopCtx:         ctx,
		stopFn:          cancel,
	}
	client.alive.Store(false)
	return client
}
// Heartbeat sends a "heartbeat" system message every heartbeatPeriod and
// waits up to five seconds for the reply.
//
// A successful round records the time in lastHeartbeat and resets the failure
// counter. After three consecutive failures the client is marked not alive,
// its status is updated, the stop context is cancelled and the loop exits.
// The loop also exits as soon as the stop context is done.
func (c *ClientCommon) Heartbeat() {
	failedCount := 0
	for {
		select {
		case <-c.stopCtx.Done():
			return
		case <-time.After(c.heartbeatPeriod):
		}
		_, err := c.sendWait(TransferMsg{
			ID:    10000,
			Key:   "heartbeat",
			Value: nil,
			Type:  MSG_SYS_WAIT,
		}, time.Second*5)
		if err == nil {
			c.lastHeartbeat = time.Now().Unix()
			failedCount = 0
			// Success: skip the failure accounting below. Without this the
			// failure message was printed and the counter bumped on every
			// round, so two misses after a success already stopped the client.
			continue
		}
		if c.debugMode {
			fmt.Println("failed to recv heartbeat,timeout!")
		}
		failedCount++
		if failedCount >= 3 {
			if c.debugMode {
				fmt.Println("heatbeat failed more than 3 times,stop client")
			}
			c.alive.Store(false)
			c.mu.Lock()
			c.status = Status{
				Alive:  false,
				Reason: "heartbeat failed more than 3 times",
				Err:    errors.New("heartbeat failed more than 3 times"),
			}
			c.mu.Unlock()
			c.stopFn()
			return
		}
	}
}
// ShowError switches printing of read and decode errors on or off.
// The flag is written while holding the client's mutex.
func (c *ClientCommon) ShowError(std bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.showError = std
}
// readMessage is the receive loop: it reads raw bytes from the connection and
// feeds them to the queue for reassembly.
//
// When maxReadTimeout is non-zero a read deadline is set before each read; a
// read that merely hits that deadline is not an error and the loop continues.
// Any other read error marks the client not alive, records the error in the
// status and cancels the stop context. Once the stop context is done the
// connection is closed and the loop returns.
func (c *ClientCommon) readMessage() {
	for {
		select {
		case <-c.stopCtx.Done():
			c.conn.Close()
			return
		default:
		}
		data := make([]byte, 8192)
		if c.maxReadTimeout.Seconds() != 0 {
			if err := c.conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)); err != nil {
				//TODO:ALERT
			}
		}
		readNum, err := c.conn.Read(data)
		// net.Conn wraps the deadline error (e.g. in *net.OpError), so a plain
		// == comparison never matches; errors.Is unwraps it.
		if errors.Is(err, os.ErrDeadlineExceeded) {
			if readNum != 0 {
				c.queue.ParseMessage(data[:readNum], "b612")
			}
			continue
		}
		if err != nil {
			if c.showError || c.debugMode {
				fmt.Println("client read error", err)
			}
			c.alive.Store(false)
			c.mu.Lock()
			c.status = Status{
				Alive:  false,
				Reason: "client read error",
				Err:    err,
			}
			c.mu.Unlock()
			c.stopFn()
			// The next iteration sees the cancelled context, closes the
			// connection and returns.
			continue
		}
		c.queue.ParseMessage(data[:readNum], "b612")
	}
}
// sayGoodBye sends a "bye" system message and waits up to three seconds for
// the acknowledgement, returning any send or wait error.
func (c *ClientCommon) sayGoodBye() error {
	farewell := TransferMsg{
		ID:    10010,
		Key:   "bye",
		Value: nil,
		Type:  MSG_SYS_WAIT,
	}
	_, err := c.sendWait(farewell, 3*time.Second)
	return err
}
// loadMessage is the dispatch loop: it takes reassembled frames from the
// queue, decrypts and decodes each one in its own goroutine (tracked by wg)
// and hands the resulting Message to dispatchMsg.
//
// When the stop context is done it says goodbye to the server (unless the
// server initiated the shutdown), closes the connection and returns.
func (c *ClientCommon) loadMessage() {
	for {
		select {
		case <-c.stopCtx.Done():
			//say goodbye
			if !c.byeFromServer {
				c.sayGoodBye()
			}
			c.conn.Close()
			return
		case data, ok := <-c.queue.RestoreChan():
			if !ok {
				continue
			}
			c.wg.Add(1)
			go func(data starnet.MsgQueue) {
				defer c.wg.Done()
				now := time.Now()
				//transfer to Msg
				msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data.Msg))
				if err != nil {
					if c.showError || c.debugMode {
						fmt.Println("client decode data error", err)
					}
					return
				}
				// The payload comes from the network: a frame that decodes to
				// something other than a TransferMsg must be dropped rather
				// than crash the process through a failed type assertion.
				transfer, isTransfer := msg.(TransferMsg)
				if !isTransfer {
					if c.showError || c.debugMode {
						fmt.Printf("client decode data error unexpected message type %T\n", msg)
					}
					return
				}
				message := Message{
					ServerConn:  c,
					TransferMsg: transfer,
					NetType:     NET_CLIENT,
				}
				message.Time = now
				c.dispatchMsg(message)
			}(data)
		}
	}
}
// dispatchMsg routes one decoded message.
//
//   - MSG_SYS_WAIT and MSG_SYS go to sysMsg.
//   - MSG_KEY_CHANGE, MSG_SYS_REPLY and MSG_SYNC_REPLY are delivered to the
//     waiter registered under the message ID, if there is one.
//   - Everything else (including replies nobody is waiting for) is passed to
//     the handler registered for the message key, if any, and then to the
//     default handler, if set.
func (c *ClientCommon) dispatchMsg(message Message) {
	switch message.TransferMsg.Type {
	case MSG_SYS_WAIT, MSG_SYS:
		c.sysMsg(message)
		return
	case MSG_KEY_CHANGE, MSG_SYS_REPLY, MSG_SYNC_REPLY:
		// Claim the waiter atomically so that exactly one party ever delivers
		// to (or disposes of) a given pending entry.
		if pending, ok := c.noFinSyncMsgPool.LoadAndDelete(message.ID); ok {
			pending.(WaitMsg).Reply <- message
			return
		}
	}
	// SetLink mutates linkFns under c.mu; read it under the same lock, but
	// release the lock before running user code.
	c.mu.Lock()
	fn, ok := c.linkFns[message.Key]
	c.mu.Unlock()
	if ok {
		fn(&message)
	}
	if c.defaultFns != nil {
		c.defaultFns(&message)
	}
}
// sysMsg handles system messages. Only the "bye" key is acted upon: the
// client acknowledges it when the server expects a reply (MSG_SYS_WAIT),
// marks itself not alive, records the reason in its status and cancels the
// stop context. Any other key is ignored.
func (c *ClientCommon) sysMsg(message Message) {
	if message.Key != "bye" {
		return
	}
	if message.TransferMsg.Type == MSG_SYS_WAIT {
		// Remember that the server started the shutdown so that no goodbye
		// is sent back later.
		c.byeFromServer = true
		message.Reply(nil)
	}
	c.alive.Store(false)
	stopped := Status{
		Alive:  false,
		Reason: "recv stop signal from server",
		Err:    nil,
	}
	c.mu.Lock()
	c.status = stopped
	c.mu.Unlock()
	c.stopFn()
}
// SetDefaultLink sets the handler that dispatchMsg calls for every message it
// dispatches to handlers, in addition to any key-specific handler.
func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) {
c.defaultFns = fn
}
// SetLink registers fn as the handler for messages whose key equals key,
// replacing any previous handler for that key. The handler table is updated
// while holding the client's mutex.
func (c *ClientCommon) SetLink(key string, fn func(*Message)) {
c.mu.Lock()
defer c.mu.Unlock()
c.linkFns[key] = fn
}
// send serializes, encrypts, frames and writes msg to the connection.
//
// Messages that are not replies (or that carry ID 0) get a fresh ID from the
// client's counter. For message types that expect an answer (MSG_SYNC_ASK,
// MSG_KEY_CHANGE, MSG_SYS_WAIT) a WaitMsg with a buffered reply channel is
// registered in the pending pool and returned; for all other types the
// returned WaitMsg is the zero value. On any error the zero WaitMsg is
// returned together with the error.
func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) {
	var wait WaitMsg
	if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
		msg.ID = atomic.AddUint64(&c.msgID, 1)
	}
	data, err := c.sequenceEn(msg)
	if err != nil {
		return WaitMsg{}, err
	}
	data = c.msgEn(c.SecretKey, data)
	data = c.queue.BuildMessage(data)

	expectsReply := msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT
	if expectsReply {
		// Register the waiter BEFORE writing: a reply may arrive before
		// Write returns, and it would be dropped if no waiter existed yet.
		wait.Time = time.Now()
		wait.TransferMsg = msg
		wait.Reply = make(chan Message, 1)
		c.noFinSyncMsgPool.Store(msg.ID, wait)
	}
	if c.maxWriteTimeout.Seconds() != 0 {
		c.conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
	}
	if _, err = c.conn.Write(data); err != nil {
		if expectsReply {
			// Nothing was sent, so no reply will ever come for this ID.
			c.noFinSyncMsgPool.Delete(msg.ID)
		}
		return WaitMsg{}, err
	}
	return wait, nil
}
// Send transmits value under key as a fire-and-forget (MSG_ASYNC) message and
// returns any error from sending it.
func (c *ClientCommon) Send(key string, value MsgVal) error {
	outgoing := TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_ASYNC,
	}
	_, err := c.send(outgoing)
	return err
}
// sendWait sends msg and blocks until its reply arrives.
//
// With a zero timeout it waits indefinitely and returns os.ErrInvalid if the
// reply channel is closed without a reply. Otherwise it returns
// os.ErrDeadlineExceeded when timeout elapses first, or a "service shutdown"
// error when the client is stopped first.
func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) {
	data, err := c.send(msg)
	if err != nil {
		return Message{}, err
	}
	if timeout.Seconds() == 0 {
		reply, ok := <-data.Reply
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}
	select {
	case <-time.After(timeout):
		// Only forget the waiter; do NOT close the channel. Closing it here
		// can panic: the pool sweeper may close the same channel, and the
		// dispatcher may already hold this entry and be about to send on it.
		// The channel is buffered, so a late reply does not block the
		// dispatcher and the channel is simply garbage collected.
		c.noFinSyncMsgPool.Delete(data.TransferMsg.ID)
		return Message{}, os.ErrDeadlineExceeded
	case <-c.stopCtx.Done():
		return Message{}, errors.New("service shutdown")
	case reply, ok := <-data.Reply:
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}
}
// sendCtx sends msg and blocks until its reply arrives, ctx is done, or the
// client is stopped. A nil ctx is treated as context.Background().
//
// It returns os.ErrDeadlineExceeded when ctx finishes first, a
// "service shutdown" error when the client is stopped first, and
// os.ErrInvalid if the reply channel is closed without a reply.
func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) {
	data, err := c.send(msg)
	if err != nil {
		return Message{}, err
	}
	if ctx == nil {
		ctx = context.Background()
	}
	select {
	case <-ctx.Done():
		// Only forget the waiter; do NOT close the channel. Closing it here
		// can panic: the pool sweeper may close the same channel, and the
		// dispatcher may already hold this entry and be about to send on it.
		// The channel is buffered, so a late reply does not block the
		// dispatcher and the channel is simply garbage collected.
		c.noFinSyncMsgPool.Delete(data.TransferMsg.ID)
		return Message{}, os.ErrDeadlineExceeded
	case <-c.stopCtx.Done():
		return Message{}, errors.New("service shutdown")
	case reply, ok := <-data.Reply:
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}
}
// SendObjCtx serializes val with the client's encoder, sends it under key as
// a synchronous request (MSG_SYNC_ASK) and waits for the reply until ctx is
// done. A serialization error is returned without sending anything.
func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) {
	payload, encErr := c.sequenceEn(val)
	if encErr != nil {
		return Message{}, encErr
	}
	request := TransferMsg{
		Key:   key,
		Value: payload,
		Type:  MSG_SYNC_ASK,
	}
	return c.sendCtx(request, ctx)
}
// SendObj serializes val and sends it under key as a fire-and-forget
// (MSG_ASYNC) message. A serialization error is returned without sending
// anything.
func (c *ClientCommon) SendObj(key string, val interface{}) error {
	// Use the client's configured encoder, like SendObjCtx and SendWaitObj
	// do, so that an encoder installed via SetSequenceEn is honoured here too.
	data, err := c.sequenceEn(val)
	if err != nil {
		return err
	}
	_, err = c.send(TransferMsg{
		Key:   key,
		Value: data,
		Type:  MSG_ASYNC,
	})
	return err
}
// SendCtx sends value under key as a synchronous request (MSG_SYNC_ASK) and
// waits for the reply until ctx is done.
func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) {
	request := TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_SYNC_ASK,
	}
	return c.sendCtx(request, ctx)
}
// SendWait sends value under key as a synchronous request (MSG_SYNC_ASK) and
// waits up to timeout for the reply; a zero timeout waits indefinitely.
func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) {
	request := TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_SYNC_ASK,
	}
	return c.sendWait(request, timeout)
}
// SendWaitObj serializes value with the client's encoder and sends it with
// SendWait. A serialization error is returned without sending anything.
func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) {
	payload, encErr := c.sequenceEn(value)
	if encErr != nil {
		return Message{}, encErr
	}
	return c.SendWait(key, payload, timeout)
}
// Reply answers message m with value by delegating to m.Reply.
func (c *ClientCommon) Reply(m Message, value MsgVal) error {
return m.Reply(value)
}
// ExchangeKey proposes newKey to the server as the new secret key.
//
// The key is encrypted with the handshake RSA public key and sent as a
// MSG_KEY_CHANGE message; the call waits up to ten seconds for the answer.
// Only when the answer's value is "success" is the client's SecretKey
// replaced; otherwise an error is returned and the key is left unchanged.
func (c *ClientCommon) ExchangeKey(newKey []byte) error {
	rsaPub, err := starcrypto.DecodePublicKey(c.handshakeRsaPubKey)
	if err != nil {
		return err
	}
	sealedKey, err := starcrypto.RSAEncrypt(rsaPub, newKey)
	if err != nil {
		return err
	}
	request := TransferMsg{
		ID:    19961127,
		Key:   "sirius",
		Value: sealedKey,
		Type:  MSG_KEY_CHANGE,
	}
	answer, err := c.sendWait(request, 10*time.Second)
	if err != nil {
		return err
	}
	if string(answer.Value) != "success" {
		return errors.New("cannot exchange new aes-key")
	}
	c.SecretKey = newKey
	// NOTE(review): short pause after switching keys; presumably gives the
	// peer time to switch as well — confirm.
	time.Sleep(100 * time.Millisecond)
	return nil
}
// aesRsaHello is the default key exchange: it derives a fresh key from the
// current time and two pseudo-random numbers, hashes it with MD5 (string
// form) and proposes it to the server through ExchangeKey.
//
// NOTE(review): the key material comes from math/rand and the clock, which
// are not cryptographically secure sources; consider crypto/rand.
func aesRsaHello(c Client) error {
	seed := fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me")
	digest := starcrypto.Md5Str([]byte(seed))
	return c.ExchangeKey([]byte(digest))
}
// GetMsgEn returns the function applied to outgoing payloads; send calls it
// as fn(SecretKey, data).
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
return c.msgEn
}
// SetMsgEn replaces the function applied to outgoing payloads; send calls it
// as fn(SecretKey, data).
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
c.msgEn = fn
}
// GetMsgDe returns the function applied to incoming payloads; loadMessage
// calls it as fn(SecretKey, data).
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
return c.msgDe
}
// SetMsgDe replaces the function applied to incoming payloads; loadMessage
// calls it as fn(SecretKey, data).
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
c.msgDe = fn
}
// HeartbeatPeroid returns the interval between heartbeat attempts.
// (The misspelling in the name is part of the public API.)
func (c *ClientCommon) HeartbeatPeroid() time.Duration {
return c.heartbeatPeriod
}
// SetHeartbeatPeroid sets the interval between heartbeat attempts.
// (The misspelling in the name is part of the public API.)
func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
c.heartbeatPeriod = duration
}
// GetSecretKey returns the key currently passed to the message
// encryption/decryption hooks. The returned slice is not copied.
func (c *ClientCommon) GetSecretKey() []byte {
return c.SecretKey
}
// SetSecretKey replaces the key passed to the message encryption/decryption
// hooks. The slice is stored as given, not copied.
func (c *ClientCommon) SetSecretKey(key []byte) {
c.SecretKey = key
}
// RsaPubKey returns the RSA public key that ExchangeKey uses to encrypt a new
// secret key. The returned slice is not copied.
func (c *ClientCommon) RsaPubKey() []byte {
return c.handshakeRsaPubKey
}
// SetRsaPubKey replaces the RSA public key that ExchangeKey uses to encrypt a
// new secret key. The slice is stored as given, not copied.
func (c *ClientCommon) SetRsaPubKey(key []byte) {
c.handshakeRsaPubKey = key
}
// Stop shuts the client down on behalf of the user: it marks the client not
// alive, records the reason in the status and cancels the stop context.
// It does nothing if the client is not running, and always returns nil.
func (c *ClientCommon) Stop() error {
	if running := c.alive.Load().(bool); !running {
		return nil
	}
	c.alive.Store(false)
	final := Status{
		Alive:  false,
		Reason: "recv stop signal from user",
		Err:    nil,
	}
	c.mu.Lock()
	c.status = final
	c.mu.Unlock()
	c.stopFn()
	return nil
}
// StopMonitorChan returns the Done channel of the client's current stop
// context; it is closed when that context is cancelled.
func (c *ClientCommon) StopMonitorChan() <-chan struct{} {
return c.stopCtx.Done()
}
// Status returns a copy of the client's most recent status.
func (c *ClientCommon) Status() Status {
	// The status is replaced from background goroutines (heartbeat, reader,
	// system-message handler) while they hold c.mu; take the same lock so the
	// caller never observes a half-written value.
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.status
}
// GetSequenceEn returns the function used to serialize outgoing values.
func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
return c.sequenceEn
}
// SetSequenceEn replaces the function used to serialize outgoing values.
func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
c.sequenceEn = fn
}
// GetSequenceDe returns the function used to deserialize incoming payloads.
func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) {
return c.sequenceDe
}
// SetSequenceDe replaces the function used to deserialize incoming payloads.
func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
c.sequenceDe = fn
}