From 555bc3653eca0f1fb34c39ee0538a13d64077322 Mon Sep 17 00:00:00 2001 From: starainrt Date: Fri, 12 Nov 2021 16:04:39 +0800 Subject: [PATCH] v2 version release --- client.go | 817 ++++++++++++++--------- clienttype.go | 47 ++ default.go | 86 +++ msg.go | 472 +++++++++++++ serialization.go | 12 +- server.go | 995 +++++++++++++++------------- servertype.go | 46 ++ starnotify/define.go | 82 +-- v1/client.go | 394 +++++++++++ client_test.go => v1/client_test.go | 0 v1/serialization.go | 37 ++ v1/server.go | 534 +++++++++++++++ v1/starnotify/define.go | 103 +++ v1/v2cs_test.go | 51 ++ v2cs_test.go | 146 ++++ 15 files changed, 3021 insertions(+), 801 deletions(-) create mode 100644 clienttype.go create mode 100644 default.go create mode 100644 msg.go create mode 100644 servertype.go create mode 100644 v1/client.go rename client_test.go => v1/client_test.go (100%) create mode 100644 v1/serialization.go create mode 100644 v1/server.go create mode 100644 v1/starnotify/define.go create mode 100644 v1/v2cs_test.go create mode 100644 v2cs_test.go diff --git a/client.go b/client.go index 7d3de6e..16efbdc 100644 --- a/client.go +++ b/client.go @@ -1,377 +1,570 @@ package notify import ( + "b612.me/starcrypto" + "b612.me/stario" + "b612.me/starnet" "context" "errors" "fmt" "math/rand" "net" - "strings" + "os" "sync" + "sync/atomic" "time" - - "b612.me/starcrypto" - "b612.me/starnet" ) -// StarNotifyC 为Client端 -type StarNotifyC struct { - Connc net.Conn - dialTimeout time.Duration - clientSign map[string]chan string - mu sync.Mutex - // FuncLists 当不使用channel时,使用此记录调用函数 - FuncLists map[string]func(CMsg) - stopSign context.Context - cancel context.CancelFunc - defaultFunc func(CMsg) - // UseChannel 是否使用channel作为信息传递 - UseChannel bool - isUDP bool - Sync bool - // Queue 是用来处理收发信息的简单消息队列 - Queue *starnet.StarQueue - // Online 当前链接是否处于活跃状态 - Online bool - lockPool map[string]CMsg - aesKey []byte -} - -// CMsg 指明当前客户端被通知的关键字 -type CMsg struct { - Key string - Value string - mode string - wait chan int -} - -func (star *StarNotifyC) starinitc() { - builder := starnet.NewQueue() - builder.EncodeFunc = encodeFunc - builder.DecodeFunc = decodeFunc - builder.Encode = true - star.stopSign, star.cancel = context.WithCancel(context.Background()) - star.Queue = builder - star.FuncLists = make(map[string]func(CMsg)) - star.UseChannel = false - star.clientSign = make(map[string]chan string) - star.Online = false - star.lockPool = make(map[string]CMsg) - star.Queue.RestoreDuration(time.Millisecond * 50) -} - -func (star *StarNotifyC) SetAesKey(key []byte) { - star.aesKey = key - star.Queue.EncodeFunc = func(data []byte) []byte { - return starcrypto.AesEncryptCFB(data, key) - } - star.Queue.DecodeFunc = func(data []byte) []byte { - return starcrypto.AesDecryptCFB(data, key) - } -} - -func (star *StarNotifyC) GetAesKey() []byte { - if len(star.aesKey) == 0 { - return aesKey - } - return star.aesKey -} - -// Notify 用于获取一个通知 -func (star *StarNotifyC) Notify(key string) chan string { - if _, ok := star.clientSign[key]; !ok { - ch := make(chan string, 20) - star.mu.Lock() - star.clientSign[key] = ch - star.mu.Unlock() - } - return star.clientSign[key] -} - -func (star *StarNotifyC) store(key, value string) { - if _, ok := star.clientSign[key]; !ok { - ch := make(chan string, 20) - ch <- value - star.mu.Lock() - star.clientSign[key] = ch - star.mu.Unlock() - return +//var nowd int64 +type ClientCommon struct { + alive atomic.Value + status Status + byeFromServer bool + conn net.Conn + mu sync.Mutex + msgID uint64 + queue *starnet.StarQueue + stopFn context.CancelFunc + stopCtx context.Context + parallelNum int + maxReadTimeout time.Duration + maxWriteTimeout time.Duration + keyExchangeFn func(c Client) error + linkFns map[string]func(message *Message) + defaultFns func(message *Message) + msgEn func([]byte, []byte) []byte + msgDe func([]byte, []byte) []byte + noFinSyncMsgPool sync.Map + handshakeRsaPubKey []byte + SecretKey []byte + noFinSyncMsgMaxKeepSeconds int + lastHeartbeat int64 + heartbeatPeriod time.Duration + wg stario.WaitGroup + netType NetType + skipKeyExchange bool + useHeartBeat bool + sequenceDe func([]byte) (interface{}, error) + sequenceEn func(interface{}) ([]byte, error) +} + +func (c *ClientCommon) Connect(network string, addr string) error { + if c.alive.Load().(bool) { + return errors.New("client already run") + } + c.stopCtx, c.stopFn = context.WithCancel(context.Background()) + c.queue = starnet.NewQueueCtx(c.stopCtx, 4) + conn, err := net.Dial(network, addr) + if err != nil { + return err + } + c.alive.Store(true) + c.status.Alive = true + c.conn = conn + if c.useHeartBeat { + go c.Heartbeat() } - star.clientSign[key] <- value + return c.clientPostInit() } -func NewNotifyCWithTimeOut(netype, value string, timeout time.Duration) (*StarNotifyC, error) { - var err error - var star StarNotifyC - star.starinitc() - star.isUDP = false - if strings.Index(netype, "udp") >= 0 { - star.isUDP = true + +func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error { + if c.alive.Load().(bool) { + return errors.New("client already run") } - star.Connc, err = net.DialTimeout(netype, value, timeout) + c.stopCtx, c.stopFn = context.WithCancel(context.Background()) + c.queue = starnet.NewQueueCtx(c.stopCtx, 4) + conn, err := net.DialTimeout(network, addr, timeout) if err != nil { - return nil, err - } - star.dialTimeout = timeout - go star.cnotify() - go func() { - <-star.stopSign.Done() - star.Connc.Close() - star.Online = false - return - }() - go func() { - for { - buf := make([]byte, 8192) - n, err := star.Connc.Read(buf) - if n != 0 { - star.Queue.ParseMessage(buf[0:n], star.Connc) - } - if err != nil { - star.Connc.Close() - star.ClientStop() - //star, _ = NewNotifyC(netype, value) - star.Online = false - return - } + return err + } + c.alive.Store(true) + c.status.Alive = true + c.conn = conn + if c.useHeartBeat { + go c.Heartbeat() + } + return c.clientPostInit() +} + +func (c *ClientCommon) monitorPool() { + for { + select { + case <-c.stopCtx.Done(): + c.noFinSyncMsgPool.Range(func(k, v interface{}) bool { + data := v.(WaitMsg) + close(data.Reply) + c.noFinSyncMsgPool.Delete(k) + return true + }) + return + case <-time.After(time.Second * 30): + } + now := time.Now() + if c.noFinSyncMsgMaxKeepSeconds > 0 { + c.noFinSyncMsgPool.Range(func(k, v interface{}) bool { + data := v.(WaitMsg) + if data.Time.Add(time.Duration(c.noFinSyncMsgMaxKeepSeconds) * time.Second).Before(now) { + close(data.Reply) + c.noFinSyncMsgPool.Delete(k) + } + return true + }) } - }() - star.Online = true - return &star, nil + } } -// NewNotifyC 用于新建一个Client端进程 -func NewNotifyC(netype, value string) (*StarNotifyC, error) { - var err error - var star StarNotifyC - star.starinitc() - star.isUDP = false - if strings.Index(netype, "udp") >= 0 { - star.isUDP = true +func (c *ClientCommon) SkipExchangeKey() bool { + return c.skipKeyExchange +} + +func (c *ClientCommon) SetSkipExchangeKey(val bool) { + c.skipKeyExchange = val +} + +func (c *ClientCommon) clientPostInit() error { + go c.readMessage() + go c.loadMessage() + if !c.skipKeyExchange { + err := c.keyExchangeFn(c) + if err != nil { + c.alive.Store(false) + c.mu.Lock() + c.status = Status{ + Alive: false, + Reason: "key exchange failed", + Err: err, + } + c.mu.Unlock() + c.stopFn() + return err + } } - star.Connc, err = net.Dial(netype, value) - if err != nil { - return nil, err + return nil +} +func NewClient() Client { + var client = ClientCommon{ + maxReadTimeout: 0, + maxWriteTimeout: 0, + sequenceEn: encode, + sequenceDe: Decode, + keyExchangeFn: aesRsaHello, + SecretKey: defaultAesKey, + handshakeRsaPubKey: defaultRsaPubKey, + msgEn: defaultMsgEn, + msgDe: defaultMsgDe, } - go star.cnotify() - go func() { - <-star.stopSign.Done() - star.Connc.Close() - star.Online = false + client.alive.Store(false) + //heartbeat should not controlable for user + client.useHeartBeat = true + client.heartbeatPeriod = time.Second * 20 + client.linkFns = make(map[string]func(*Message)) + client.defaultFns = func(message *Message) { return - }() - go func() { - for { - buf := make([]byte, 8192) - n, err := star.Connc.Read(buf) - if n != 0 { - star.Queue.ParseMessage(buf[0:n], star.Connc) - } - if err != nil { - star.Connc.Close() - star.ClientStop() - //star, _ = NewNotifyC(netype, value) - star.Online = false - return + } + client.wg = stario.NewWaitGroup(0) + client.stopCtx, client.stopFn = context.WithCancel(context.Background()) + return &client +} + +func (c *ClientCommon) Heartbeat() { + failedCount := 0 + for { + select { + case <-c.stopCtx.Done(): + return + case <-time.After(c.heartbeatPeriod): + } + _, err := c.sendWait(TransferMsg{ + ID: 10000, + Key: "heartbeat", + Value: nil, + Type: MSG_SYS_WAIT, + }, time.Second*5) + if err == nil { + c.lastHeartbeat = time.Now().Unix() + failedCount = 0 + } + failedCount++ + if failedCount >= 3 { + //fmt.Println("heatbeat failed,stop client") + c.alive.Store(false) + c.mu.Lock() + c.status = Status{ + Alive: false, + Reason: "heartbeat failed more than 3 times", + Err: errors.New("heartbeat failed more than 3 times"), } + c.mu.Unlock() + c.stopFn() + return } - }() - star.Online = true - return &star, nil + } } -// Send 用于向Server端发送数据 -func (star *StarNotifyC) Send(name string) error { - return star.SendValue(name, "") +func (c *ClientCommon) readMessage() { + for { + select { + case <-c.stopCtx.Done(): + c.conn.Close() + return + default: + } + data := make([]byte, 8192) + if c.maxReadTimeout.Seconds() != 0 { + c.conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)) + } + readNum, err := c.conn.Read(data) + if err == os.ErrDeadlineExceeded { + if readNum != 0 { + c.queue.ParseMessage(data[:readNum], "b612") + } + continue + } + if err != nil { + fmt.Println("client read error", err) + c.alive.Store(false) + c.mu.Lock() + c.status = Status{ + Alive: false, + Reason: "client read error", + Err: err, + } + c.mu.Unlock() + c.stopFn() + continue + } + c.queue.ParseMessage(data[:readNum], "b612") + } } -func (star *StarNotifyC) Stoped() <-chan struct{} { - return star.stopSign.Done() +func (c *ClientCommon) sayGoodBye() error { + _, err := c.sendWait(TransferMsg{ + ID: 10010, + Key: "bye", + Value: nil, + Type: MSG_SYS_WAIT, + }, time.Second*3) + return err } -func (star *StarNotifyC) SendValueRaw(key string, msg interface{}) error { - encodeData, err := encode(msg) - if err != nil { - return err +func (c *ClientCommon) loadMessage() { + for { + select { + case <-c.stopCtx.Done(): + //say goodbye + if !c.byeFromServer { + c.sayGoodBye() + } + c.conn.Close() + return + case data, ok := <-c.queue.RestoreChan(): + if !ok { + continue + } + c.wg.Add(1) + go func(data starnet.MsgQueue) { + defer c.wg.Done() + //fmt.Println("c received:", float64(time.Now().UnixNano()-nowd)/1000000) + now := time.Now() + //transfer to Msg + msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data.Msg)) + if err != nil { + fmt.Println("client decode data error", err) + return + } + message := Message{ + ServerConn: c, + TransferMsg: msg.(TransferMsg), + NetType: NET_CLIENT, + } + message.Time = now + c.dispatchMsg(message) + }(data) + } } - return star.SendValue(key, string(encodeData)) } -// SendValue 用于向Server端发送key-value类型数据 -func (star *StarNotifyC) SendValue(name, value string) error { - var err error - var key []byte - for _, v := range []byte(name) { - if v == byte(124) || v == byte(92) { - key = append(key, byte(92)) +func (c *ClientCommon) dispatchMsg(message Message) { + switch message.TransferMsg.Type { + case MSG_SYS_WAIT: + fallthrough + case MSG_SYS: + c.sysMsg(message) + return + case MSG_KEY_CHANGE: + fallthrough + case MSG_SYS_REPLY: + fallthrough + case MSG_SYNC_REPLY: + data, ok := c.noFinSyncMsgPool.Load(message.ID) + if ok { + wait := data.(WaitMsg) + wait.Reply <- message + c.noFinSyncMsgPool.Delete(message.ID) + return } - key = append(key, v) + return + //fallthrough + default: + } + callFn := func(fn func(*Message)) { + fn(&message) + } + fn, ok := c.linkFns[message.Key] + if ok { + callFn(fn) + } + if c.defaultFns != nil { + callFn(c.defaultFns) } - _, err = star.Connc.Write(star.Queue.BuildMessage([]byte("pa" + "||" + string(key) + "||" + value))) - return err } -func (star *StarNotifyC) trim(name string) string { - var slash bool = false - var key []byte - for _, v := range []byte(name) { - if v == byte(92) && !slash { - slash = true - continue +func (c *ClientCommon) sysMsg(message Message) { + switch message.Key { + case "bye": + if message.TransferMsg.Type == MSG_SYS_WAIT { + //fmt.Println("recv stop signal from server") + c.byeFromServer = true + message.Reply(nil) } - slash = false - key = append(key, v) + c.alive.Store(false) + c.mu.Lock() + c.status = Status{ + Alive: false, + Reason: "recv stop signal from server", + Err: nil, + } + c.mu.Unlock() + c.stopFn() } - return string(key) } -func (star *StarNotifyC) SendValueWaitRaw(key string, msg interface{}, tmout time.Duration) (CMsg, error) { - encodeData, err := encode(msg) + +func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) { + c.defaultFns = fn +} + +func (c *ClientCommon) SetLink(key string, fn func(*Message)) { + c.mu.Lock() + defer c.mu.Unlock() + c.linkFns[key] = fn +} + +func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) { + var wait WaitMsg + if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { + msg.ID = atomic.AddUint64(&c.msgID, 1) + } + data, err := c.sequenceEn(msg) if err != nil { - return CMsg{}, err + return WaitMsg{}, err } - return star.SendValueWait(key, string(encodeData), tmout) + data = c.msgEn(c.SecretKey, data) + data = c.queue.BuildMessage(data) + if c.maxWriteTimeout.Seconds() != 0 { + c.conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)) + } + _, err = c.conn.Write(data) + if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT) { + wait.Time = time.Now() + wait.TransferMsg = msg + wait.Reply = make(chan Message, 1) + c.noFinSyncMsgPool.Store(msg.ID, wait) + } + return wait, err +} + +func (c *ClientCommon) Send(key string, value MsgVal) error { + _, err := c.send(TransferMsg{ + Key: key, + Value: value, + Type: MSG_ASYNC, + }) + return err } -// SendValueWait 用于向Server端发送key-value类型数据并等待结果返回,此结果不会通过标准返回流程处理 -func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration) (CMsg, error) { - var err error - var tmceed <-chan time.Time - if star.UseChannel { - return CMsg{}, errors.New("Do Not Use UseChannel Mode!") +func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) { + data, err := c.send(msg) + if err != nil { + return Message{}, err } - rand.Seed(time.Now().UnixNano()) - mode := "cr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999)) - var key []byte - for _, v := range []byte(name) { - if v == byte(124) || v == byte(92) { - key = append(key, byte(92)) + if timeout.Seconds() == 0 { + msg, ok := <-data.Reply + if !ok { + return msg, os.ErrInvalid } - key = append(key, v) + return msg, nil } - _, err = star.Connc.Write(star.Queue.BuildMessage([]byte(mode + "||" + string(key) + "||" + value))) + select { + case <-time.After(timeout): + close(data.Reply) + c.noFinSyncMsgPool.Delete(data.TransferMsg.ID) + return Message{}, os.ErrDeadlineExceeded + case msg, ok := <-data.Reply: + if !ok { + return msg, os.ErrInvalid + } + return msg, nil + } +} + +func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) { + data, err := c.send(msg) if err != nil { - return CMsg{}, err + return Message{}, err } - if int64(tmout) > 0 { - tmceed = time.After(tmout) + if ctx == nil { + ctx = context.Background() } - var source CMsg - source.wait = make(chan int, 2) - star.mu.Lock() - star.lockPool[mode] = source - star.mu.Unlock() select { - case <-source.wait: - res := star.lockPool[mode] - star.mu.Lock() - delete(star.lockPool, mode) - star.mu.Unlock() - return res, nil - case <-tmceed: - return CMsg{}, errors.New("Time Exceed") - } -} - -// ReplyMsg 用于向Server端Reply信息 -func (star *StarNotifyC) ReplyMsg(data CMsg, name, value string) error { - var err error - var key []byte - for _, v := range []byte(name) { - if v == byte(124) || v == byte(92) { - key = append(key, byte(92)) + case <-ctx.Done(): + close(data.Reply) + c.noFinSyncMsgPool.Delete(data.TransferMsg.ID) + return Message{}, os.ErrDeadlineExceeded + case msg, ok := <-data.Reply: + if !ok { + return msg, os.ErrInvalid } - key = append(key, v) + return msg, nil + } +} + +func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) { + data, err := c.sequenceEn(val) + if err != nil { + return Message{}, err + } + return c.sendCtx(TransferMsg{ + Key: key, + Value: data, + Type: MSG_SYNC_ASK, + }, ctx) +} + +func (c *ClientCommon) SendObj(key string, val interface{}) error { + data, err := encode(val) + if err != nil { + return err } - _, err = star.Connc.Write(star.Queue.BuildMessage([]byte(data.mode + "||" + string(key) + "||" + value))) + _, err = c.send(TransferMsg{ + Key: key, + Value: data, + Type: MSG_ASYNC, + }) return err } -func (star *StarNotifyC) cnotify() { - for { - select { - case <-star.stopSign.Done(): - return - default: - } - data, err := star.Queue.RestoreOne() - if err != nil { - time.Sleep(time.Millisecond * 500) - continue - } - if string(data.Msg) == "b612ryzstop" { - star.ClientStop() - star.Online = false - return - } - strs := strings.SplitN(string(data.Msg), "||", 3) - if len(strs) < 3 { - continue - } - strs[1] = star.trim(strs[1]) - if star.UseChannel { - go star.store(strs[1], strs[2]) - } else { - mode, key, value := strs[0], strs[1], strs[2] - if mode[0:2] != "cr" { - if msg, ok := star.FuncLists[key]; ok { - if star.Sync { - msg(CMsg{key, value, mode, nil}) - } else { - go msg(CMsg{key, value, mode, nil}) - } - } else { - if star.defaultFunc != nil { - if star.Sync { - star.defaultFunc(CMsg{key, value, mode, nil}) - } else { - go star.defaultFunc(CMsg{key, value, mode, nil}) - } - } - } - } else { - if sa, ok := star.lockPool[mode]; ok { - sa.Key = key - sa.Value = value - sa.mode = mode - star.mu.Lock() - star.lockPool[mode] = sa - star.mu.Unlock() - sa.wait <- 1 - } else { - if msg, ok := star.FuncLists[key]; ok { - if star.Sync { - msg(CMsg{key, value, mode, nil}) - } else { - go msg(CMsg{key, value, mode, nil}) - } - } else { - if star.defaultFunc != nil { - if star.Sync { - star.defaultFunc(CMsg{key, value, mode, nil}) - } else { - go star.defaultFunc(CMsg{key, value, mode, nil}) - } - } - } - } - } - } +func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) { + return c.sendCtx(TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, ctx) +} + +func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) { + return c.sendWait(TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, timeout) +} + +func (c *ClientCommon) Reply(m Message, value MsgVal) error { + return m.Reply(value) +} + +func (c *ClientCommon) ExchangeKey(newKey []byte) error { + newSendKey, err := starcrypto.RSAEncrypt(newKey, c.handshakeRsaPubKey) + if err != nil { + return err + } + data, err := c.sendWait(TransferMsg{ + ID: 19961127, + Key: "sirius", + Value: newSendKey, + Type: MSG_KEY_CHANGE, + }, time.Second*10) + if err != nil { + return err + } + if string(data.Value) != "success" { + return errors.New("cannot exchange new aes-key") } + c.SecretKey = newKey + time.Sleep(time.Millisecond * 100) + return nil } -// ClientStop 终止client端运行 -func (star *StarNotifyC) ClientStop() { - if star.isUDP { - star.Send("b612ryzstop") +func aesRsaHello(c Client) error { + newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me")) + newAesKey = []byte(starcrypto.Md5Str(newAesKey)) + return c.ExchangeKey(newAesKey) +} + +func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte { + return c.msgEn +} +func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) { + c.msgEn = fn +} +func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte { + return c.msgDe +} +func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) { + c.msgDe = fn +} + +func (c *ClientCommon) HeartbeatPeroid() time.Duration { + return c.heartbeatPeriod +} +func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) { + c.heartbeatPeriod = duration +} + +func (c *ClientCommon) GetSecretKey() []byte { + return c.SecretKey +} +func (c *ClientCommon) SetSecretKey(key []byte) { + c.SecretKey = key +} +func (c *ClientCommon) RsaPubKey() []byte { + return c.handshakeRsaPubKey +} +func (c *ClientCommon) SetRsaPubKey(key []byte) { + c.handshakeRsaPubKey = key +} +func (c *ClientCommon) Stop() error { + if !c.alive.Load().(bool) { + return nil + } + c.alive.Store(false) + c.mu.Lock() + c.status = Status{ + Alive: false, + Reason: "recv stop signal from user", + Err: nil, } - star.cancel() + c.mu.Unlock() + c.stopFn() + return nil +} +func (c *ClientCommon) StopMonitorChan() <-chan struct{} { + return c.stopCtx.Done() } -// SetNotify 用于设置关键词的调用函数 -func (star *StarNotifyC) SetNotify(name string, data func(CMsg)) { - star.FuncLists[name] = data +func (c *ClientCommon) Status() Status { + return c.status } -// SetDefaultNotify 用于设置默认关键词的调用函数 -func (star *StarNotifyC) SetDefaultNotify(data func(CMsg)) { - star.defaultFunc = data +func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) { + return c.sequenceEn +} +func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) { + c.sequenceEn = fn +} +func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) { + return c.sequenceDe +} +func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) { + c.sequenceDe = fn } diff --git a/clienttype.go b/clienttype.go new file mode 100644 index 0000000..30b0551 --- /dev/null +++ b/clienttype.go @@ -0,0 +1,47 @@ +package notify + +import ( + "context" + "time" +) + +type Client interface { + SetDefaultLink(func(message *Message)) + SetLink(string, func(*Message)) + send(msg TransferMsg) (WaitMsg, error) + sendWait(msg TransferMsg, timeout time.Duration) (Message, error) + Send(key string, value MsgVal) error + SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) + SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) + Reply(m Message, value MsgVal) error + ExchangeKey(newKey []byte) error + Connect(network string, addr string) error + ConnectTimeout(network string, addr string, timeout time.Duration) error + SkipExchangeKey() bool + SetSkipExchangeKey(bool) + + GetMsgEn() func([]byte, []byte) []byte + SetMsgEn(func([]byte, []byte) []byte) + GetMsgDe() func([]byte, []byte) []byte + SetMsgDe(func([]byte, []byte) []byte) + + Heartbeat() + HeartbeatPeroid() time.Duration + SetHeartbeatPeroid(duration time.Duration) + + GetSecretKey() []byte + SetSecretKey(key []byte) + RsaPubKey() []byte + SetRsaPubKey([]byte) + + Stop() error + StopMonitorChan() <-chan struct{} + Status() Status + + GetSequenceEn() func(interface{}) ([]byte, error) + SetSequenceEn(func(interface{}) ([]byte, error)) + GetSequenceDe() func([]byte) (interface{}, error) + SetSequenceDe(func([]byte) (interface{}, error)) + SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) + SendObj(key string, val interface{}) error +} diff --git a/default.go b/default.go new file mode 100644 index 0000000..acb2042 --- /dev/null +++ b/default.go @@ -0,0 +1,86 @@ +package notify + +import ( + "b612.me/starcrypto" +) + +var defaultRsaKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIJKAIBAAKCAgEAxmeMqr9yfJFKZn26oe/HvC7bZXNLC9Nk55AuTkb4XuIoqXDb +AJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT0ZCEf37ILU0G+scRzVwYHiLMwOUC +bS2o4Xor3zqUi9f1piJBvoBNh8RKKtsmJW6VQZdiUGJHbgX4MdOdtf/6TvxZMwSX +U+PRSCAjy04A31Zi7DEWUWJPyqmHeu++PxXU5lvoMdCGDqpcF2j2uO7oJJUww01M +3F5FtTElMrK4/P9gD4kP7NiPhOfVPEfBsYT/DSSjvqNZJZuWnxu+cDxE7J/sBvdp +eNRLhqzdmMYagZFuUmVrz8QmsD6jKHgydW+r7irllvb8WJPK/RIMif+4Rg7rDKFb +j8+ZQ3HZ/gKELoRSyb3zL6RC2qlGLjC1tdeN7TNTinCv092y39T8jIARJ7tpfePh +NBxsBdxfXbCAzHYZIHufI9Zlsc+felQwanlDhq+q8YLcnKHvNKYVyCf/upExpAiA +rr88y/KbeKes0KorKkwMBnGUMTothWM25wHozcurixNvP4UMWX7LWD7vOZZuNDQN +utZYeTwdsniI3mTO9vlPWEK8JTfxBU7x9SePUMJNDyjfDUJM8C2DOlyhGNPkgazO +GdliH87tHkEy/7jJnGclgKmciiVPgwHfFx9GGoBHEfvmAoGGrk4qNbjm7JECAwEA +AQKCAgBYzHe05ELFZfG6tYMWf08R9pbTbSqlfFOpIGrZNgJr1SUF0TDzq+3bCXpF +qtn4VAw1en/JZkOV8Gp1+Bm6jWymWtwyg/fr7pG1I+vf0dwpgMHLg7P2UX1IjXmd +S4a4oEuds69hJ+OLZFsdm0ATeM7ssGicOaBmqd1Pz7rCfnL1bxQtNVzVex1r/paG +o77YNr3HoKCwhCPaPM4aQ7sOWSMUhwYBZabaYX0eLShf1O2pkexlPO+tobPpSLmx +WzRYZ6QC0AGEq9hwT6KsfCFA5pmQtFllNY7suhpL1AsECLWAgoMNCyb1oW68NBpq +CiBK5WBPGH2MW+pE74Pu1P0gen6kLGnApKQjprE1aGuR+xkZe3uEnXwSryU9TXki +wINTEMsX8dkmofFqaJhUwSubrb+t7gvv9E9ZZe0X6UgKzAVVqvh4z1pP8VT+xHpu +pW7SR8n9cFddaEPUijSb1rSpJrNzfJJ+G7yrB7Cw2kBgQ07vzD3z/3kA9cwFevLS +mv3l3OQuB6y9c+AG3cX5WGAt/BVOLjimj9qJt+YglG0SwG31U0PUnnx6QVz/UtJm +CbJQ2TpJd+mk0HyuMU+eycp7BWF3PMN+SE4QgKCKWnhsLeAd3gcvifsbLOYE1OPg +wv1tqyJy0VsJiSn6Ub6Qq0kPLwCLlQTnLWk5mIhnRpHYufTSwQKCAQEA4gS4FKPU +tAcQ82dEYW4OjGfhNWrjFpF+A8K5zufleQWcgzQ3fQho13zH0vZobukfkEVlVxla +OIVk7ZgNA4mCSFrATjIx3RMqzrAUvTte0O4wkjYgCwVvTdS1W8nvRLKgugLygyoo +r+MLW5IT3eNMK/2fZbftNlAkbc7NCo3c2tS6MXFgjx5JUuzChOY73Kp4p5KS38L5 +wRRiI8KTIKjBjMZ5q/l8VLKX89bKOCaWibmItoXY6QMbIjargb7YLp3X6uGEyGIu +VhPbQ80/+OC2ZqIvDecp4PYnJNZFeqfjyfhJCNqDjBKYwIscBLMU/Wf9OY258OR4 +snQaerN1M0h9lQKCAQEA4LkZIRLLw+8bIVM+7VXxFwOAGy+MH35tvuNIToItAoUh +zjL5LG34PjID8J0DPyP8VRVanak1EcxF0aTEkvnt2f2RAVsW89ytcn8Lybb12Ae8 +ia2ZWuIM+J40nuKOGPs3lJ9HqdPWmZYWsWKxFJmYBBnwD6CADYqhqambQn0HeaYl +/WUD7blLYg+4Kk1mt9/hIw93jTWP/86O2H0ia+AhYPTqyvVXfIXKhat6NlOYksGf +Hdv+aCC8Ukg6FyEgiNc/rFn0MWPnEX+cM1AwubviHIBhV8QWILLBTjupwsEBZVah +60ftH+HRUCmEeOpI7jyzIlfEUNLoBHfswKMhMPtcDQKCAQEA0JFkQX+xn/PJW6PX +AUWrXTvbIg0hw8i9DcFa76klJBnehWDhN5tUDE5Uo8PJOVgdTWgMjWSS0geezHX8 +xF/XfudoAIDnbMfsP9FTQhCQfaLf5XzW8vSv8pWwSiS9jJp+IUjo+8siwrR03aqe +dKr0tr+ToS0qVG1+QGqO4gdpX/LgYxHp9ggPx9s94aAIa6hQMOrcaGqnSNqDedZr +KL8x5LOewek3J32rJVP3Rfut/SfeFfjL4rKADoF+oPs4yUPVZSV4/+VCNyKZuyaj +uwm6qFlPrLe9+J+OHbsxYG+fj9hzpRzoOZFLrppwX5HWc8XLcpnrlXVwP9VOPh5u +r8VcRQKCAQAJFHGHfJLvH8Ig3pQ0UryjCWkrsAghXaJhjB1nzqqy514uTrDysp7N +JIg0OKPg8TtI1MwMgsG6Ll7D0bx/k8mgfTZWr6+FuuznK2r2g4X7bJSZm4IOwgN0 +KDBIGy9SoxPj1Wu32O9a1U2lbS9qfao+wC2K9Bk4ctmFWW0Eiri6mZP/YQ1/lXUO +SURPsUDtPQaDvCRAeGGRHG95H9U8NpoiqMKz4KXgSiecrwkJGOeZRml/c1wcKPZy +/KgcNyJxZQEVnazYMgksE9Pj3uGZH5ZLQISuXyXlvFNDLfX2AIZl6dIxB371QtKK +QqMvn4fC2IEEajdsbJkjVRUj03OL3xwhAoIBAAfMhDSvBbDkGTaXnNMjPPSbswqK +qcSRhSG27mjs1dDNBKuFbz6TkIOp4nxjuS9Zp19fErXlAE9mF5yXSmuiAkZmWfhs +HKpWIdjFJK1EqSfcINe2YuoyUIulz9oG7ObRHD4D8jSPjA8Ete+XsBHGyOtUl09u +X4u9uClhqjK+r1Tno2vw5yF6ZxfQtdWuL4W0UL1S8E+VO7vjTjNOYvgjAIpAM/gW +sqjA2Qw52UZqhhLXoTfRvtJilxlXXhIRJSsnUoGiYVCQ/upjqJCClEvJfIWdGY/U +I2CbFrwJcNvOG1lUsSM55JUmbrSWVPfo7yq2k9GCuFxOy2n/SVlvlQUcNkA= +-----END RSA PRIVATE KEY-----`) + +var defaultRsaPubKey = []byte(`-----BEGIN PUBLIC KEY----- +MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAxmeMqr9yfJFKZn26oe/H +vC7bZXNLC9Nk55AuTkb4XuIoqXDbAJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT +0ZCEf37ILU0G+scRzVwYHiLMwOUCbS2o4Xor3zqUi9f1piJBvoBNh8RKKtsmJW6V +QZdiUGJHbgX4MdOdtf/6TvxZMwSXU+PRSCAjy04A31Zi7DEWUWJPyqmHeu++PxXU +5lvoMdCGDqpcF2j2uO7oJJUww01M3F5FtTElMrK4/P9gD4kP7NiPhOfVPEfBsYT/ +DSSjvqNZJZuWnxu+cDxE7J/sBvdpeNRLhqzdmMYagZFuUmVrz8QmsD6jKHgydW+r +7irllvb8WJPK/RIMif+4Rg7rDKFbj8+ZQ3HZ/gKELoRSyb3zL6RC2qlGLjC1tdeN +7TNTinCv092y39T8jIARJ7tpfePhNBxsBdxfXbCAzHYZIHufI9Zlsc+felQwanlD +hq+q8YLcnKHvNKYVyCf/upExpAiArr88y/KbeKes0KorKkwMBnGUMTothWM25wHo +zcurixNvP4UMWX7LWD7vOZZuNDQNutZYeTwdsniI3mTO9vlPWEK8JTfxBU7x9SeP +UMJNDyjfDUJM8C2DOlyhGNPkgazOGdliH87tHkEy/7jJnGclgKmciiVPgwHfFx9G +GoBHEfvmAoGGrk4qNbjm7JECAwEAAQ== +-----END PUBLIC KEY-----`) + +var defaultAesKey = []byte{0x19, 0x96, 0x11, 0x27, 228, 187, 187, 231, 142, 137, 230, 179, 189, 229, 184, 133} + +func defaultMsgEn(key []byte, d []byte) []byte { + return starcrypto.AesEncryptCFB(d, key) +} + +func defaultMsgDe(key []byte, d []byte) []byte { + return starcrypto.AesDecryptCFB(d, key) +} + +func init() { + Register(TransferMsg{}) +} diff --git a/msg.go b/msg.go new file mode 100644 index 0000000..745a2eb --- /dev/null +++ b/msg.go @@ -0,0 +1,472 @@ +package notify + +import ( + "b612.me/starcrypto" + "context" + "errors" + "fmt" + "net" + "os" + "sync" + "sync/atomic" + "time" +) + +const ( + MSG_SYS MessageType = iota + MSG_SYS_WAIT + MSG_SYS_REPLY + MSG_KEY_CHANGE + MSG_ASYNC + MSG_SYNC_ASK + MSG_SYNC_REPLY +) + +type MessageType uint8 + +type NetType uint8 + +const ( + NET_SERVER NetType = iota + NET_CLIENT +) + +type MsgVal []byte +type TransferMsg struct { + ID uint64 + Key string + Value MsgVal + Type MessageType +} + +type Message struct { + NetType + ClientConn *ClientConn + ServerConn Client + TransferMsg + Time time.Time + sync.Mutex +} + +type WaitMsg struct { + TransferMsg + Time time.Time + Reply chan Message + //Ctx context.Context +} + +func (m *Message) Reply(value MsgVal) (err error) { + reply := TransferMsg{ + ID: m.ID, + Key: m.Key, + Value: value, + Type: m.Type, + } + if reply.Type == MSG_SYNC_ASK { + reply.Type = MSG_SYNC_REPLY + } + if reply.Type == MSG_SYS_WAIT { + reply.Type = MSG_SYS_REPLY + } + if m.NetType == NET_SERVER { + _, err = m.ClientConn.server.send(m.ClientConn, reply) + } + if m.NetType == NET_CLIENT { + _, err = m.ServerConn.send(reply) + } + return +} + +func (m *Message) ReplyObj(value interface{}) (err error) { + data, err := encode(value) + if err != nil { + return err + } + return m.Reply(data) +} + +type ClientConn struct { + alive atomic.Value + status Status + ClientID string + ClientAddr net.Addr + tuConn net.Conn + server Server + stopFn context.CancelFunc + stopCtx context.Context + maxReadTimeout time.Duration + maxWriteTimeout time.Duration + msgEn func([]byte, []byte) []byte + msgDe func([]byte, []byte) []byte + handshakeRsaKey []byte + SecretKey []byte + lastHeartBeat int64 +} + +type Status struct { + Alive bool + Reason string + Err error +} + +func (c *ClientConn) readTUMessage() { + for { + select { + case <-c.stopCtx.Done(): + c.tuConn.Close() + c.server.removeClient(c) + return + default: + } + if c.maxReadTimeout.Seconds() > 0 { + c.tuConn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)) + } + data := make([]byte, 8192) + num, err := c.tuConn.Read(data) + if err == os.ErrDeadlineExceeded { + if num != 0 { + c.server.pushMessage(data[:num], c.ClientID) + } + continue + } + if err != nil { + //conn is broke + c.alive.Store(false) + c.status = Status{ + Alive: false, + Reason: "read error", + Err: err, + } + c.stopFn() + continue + } + c.server.pushMessage(data[:num], c.ClientID) + //fmt.Println("finished:", float64(time.Now().UnixNano()-nowd)/1000000) + } +} + +func (c *ClientConn) rsaDecode(message Message) { + unknownKey := message.Value + data, err := starcrypto.RSADecrypt(unknownKey, c.handshakeRsaKey, "") + if err != nil { + fmt.Println(err) + message.Reply([]byte("failed")) + return + } + //fmt.Println("aes-key changed to", string(data)) + message.Reply([]byte("success")) + c.SecretKey = data +} + +func (c *ClientConn) sayGoodByeForTU() error { + _, err := c.server.sendWait(c, TransferMsg{ + ID: 10010, + Key: "bye", + Value: nil, + Type: MSG_SYS_WAIT, + }, time.Second*3) + return err +} + +func (c *ClientConn) GetSecretKey() []byte { + return c.SecretKey +} +func (c *ClientConn) SetSecretKey(key []byte) { + c.SecretKey = key +} + +func (c *ClientConn) GetMsgEn() func([]byte, []byte) []byte { + return c.msgEn +} +func (c *ClientConn) SetMsgEn(fn func([]byte, []byte) []byte) { + c.msgEn = fn +} +func (c *ClientConn) GetMsgDe() func([]byte, []byte) []byte { + return c.msgDe +} +func (c *ClientConn) SetMsgDe(fn func([]byte, []byte) []byte) { + c.msgDe = fn +} + +func (c *ClientConn) StopMonitorChan() <-chan struct{} { + return c.stopCtx.Done() +} + +func (c *ClientConn) Status() Status { + return c.status +} + +func (c *ClientConn) Server() Server { + return c.server +} + +func (c *ClientConn) GetRemoteAddr() net.Addr { + return c.ClientAddr +} + +func (m MsgVal) ToClearString() string { + return string(m) +} + +func (m MsgVal) ToInterface() (interface{}, error) { + return Decode(m) +} + +func (m MsgVal) MustToInterface() interface{} { + inf, err := m.ToInterface() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToString() (string, error) { + inf, err := m.ToInterface() + if err != nil { + return "", err + } + if data, ok := inf.(string); !ok { + return "", errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToString() string { + inf, err := m.ToString() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToInt32() (int32, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(int32); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToInt32() int32 { + inf, err := m.ToInt32() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToInt() (int, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(int); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToInt() int { + inf, err := m.ToInt() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToUint64() (uint64, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(uint64); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToUint64() uint64 { + inf, err := m.ToUint64() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToUint32() (uint32, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(uint32); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToUint32() uint32 { + inf, err := m.ToUint32() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToUint() (uint, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(uint); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToUint() uint { + inf, err := m.ToUint() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToBool() (bool, error) { + inf, err := m.ToInterface() + if err != nil { + return false, err + } + if data, ok := inf.(bool); !ok { + return false, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToBool() bool { + inf, err := m.ToBool() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToFloat64() (float64, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(float64); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToFloat64() float64 { + inf, err := m.ToFloat64() + if err != nil { + panic(err) + } + return inf +} +func (m MsgVal) ToFloat32() (float32, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(float32); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToFloat32() float32 { + inf, err := m.ToFloat32() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToSliceString() ([]string, error) { + inf, err := m.ToInterface() + if err != nil { + return []string{}, err + } + if data, ok := inf.([]string); !ok { + return []string{}, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToSliceString() []string { + inf, err := m.ToSliceString() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToSliceInt64() ([]int64, error) { + inf, err := m.ToInterface() + if err != nil { + return []int64{}, err + } + if data, ok := inf.([]int64); !ok { + return []int64{}, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToSliceInt64() []int64 { + inf, err := m.ToSliceInt64() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToSliceFloat64() ([]float64, error) { + inf, err := m.ToInterface() + if err != nil { + return []float64{}, err + } + if data, ok := inf.([]float64); !ok { + return []float64{}, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToSliceFloat64() []float64 { + inf, err := m.ToSliceFloat64() + if err != nil { + panic(err) + } + return inf +} + +func ToMsgVal(val interface{}) (MsgVal, error) { + return Encode(val) +} + +func MustToMsgVal(val interface{}) MsgVal { + d, err := ToMsgVal(val) + if err != nil { + panic(err) + } + return d +} diff --git a/serialization.go b/serialization.go index c10d889..99e611d 100644 --- a/serialization.go +++ b/serialization.go @@ -21,17 +21,13 @@ func encode(src interface{}) ([]byte, error) { return buf.Bytes(), err } +func Encode(src interface{}) ([]byte, error) { + return encode(src) +} + func Decode(src []byte) (interface{}, error) { dec := gob.NewDecoder(bytes.NewReader(src)) var dst interface{} err := dec.Decode(&dst) return dst, err } - -func (nmsg *SMsg) Decode() (interface{}, error) { - return Decode([]byte(nmsg.Value)) -} - -func (nmsg *CMsg) Decode() (interface{}, error) { - return Decode([]byte(nmsg.Value)) -} diff --git a/server.go b/server.go index 55bd6c9..87b68be 100644 --- a/server.go +++ b/server.go @@ -1,534 +1,643 @@ -// Package notify is a package which provide common tcp/udp/unix socket service package notify import ( + "b612.me/stario" + "b612.me/starnet" "context" "errors" "fmt" "math/rand" "net" + "os" "strings" "sync" + "sync/atomic" "time" - - "b612.me/starcrypto" - - "b612.me/starnet" ) -var aesKey = []byte{0x19, 0x96, 0x11, 0x27, 228, 187, 187, 231, 142, 137, 230, 179, 189, 229, 184, 133} - -func encodeFunc(data []byte) []byte { - return starcrypto.AesEncryptCFB(data, aesKey) -} - -func decodeFunc(data []byte) []byte { - return starcrypto.AesDecryptCFB(data, aesKey) -} - -// StarNotifyS 为Server端 -type StarNotifyS struct { - // Queue 是用来处理收发信息的简单消息队列 - Queue *starnet.StarQueue - // FuncLists 记录了被通知项所记录的函数 - aesKey []byte - FuncLists map[string]func(SMsg) string - funcMu sync.Mutex - defaultFunc func(SMsg) string - Connected func(SMsg) - nickName map[string]string - stopSign context.Context - cancel context.CancelFunc - connPool sync.Map - connMu sync.Mutex - lockPool map[string]SMsg - lockMu sync.Mutex - udpPool map[string]*net.UDPAddr - listener net.Listener - isUDP bool - Sync bool - // UDPConn UDP监听 - UDPConn *net.UDPConn - // Online 当前链接是否处于活跃状态 - Online bool - // ReadDeadline tcp/unix中读超时设置,udp请直接调用UDPConn - ReadDeadline time.Time - // WriteDeadline tcp/unix中写超时设置,udp请直接调用UDPConn - WriteDeadline time.Time - - // Deadline tcp/unix中超时设置,udp请直接调用UDPConn - Deadline time.Time -} - -// SMsg 指明当前服务端被通知的关键字 -type SMsg struct { - Conn net.Conn - Key string - Value string - UDP *net.UDPAddr - uconn *net.UDPConn - mode string - wait chan int - nickName func(string, string) error - getName func(string) string - queue *starnet.StarQueue -} - -func (star *StarNotifyS) SetAesKey(key []byte) { - star.aesKey = key - star.Queue.EncodeFunc = func(data []byte) []byte { - return starcrypto.AesEncryptCFB(data, key) - } - star.Queue.DecodeFunc = func(data []byte) []byte { - return starcrypto.AesDecryptCFB(data, key) - } -} - -func (star *StarNotifyS) GetAesKey() []byte { - if len(star.aesKey) == 0 { - return aesKey - } - return star.aesKey -} - -func (star *StarNotifyS) getName(conn string) string { - for k, v := range star.nickName { - if v == conn { - return k - } - } - return "" -} -func (star *StarNotifyS) Stoped() <-chan struct{} { - return star.stopSign.Done() +type ServerCommon struct { + msgID uint64 + alive atomic.Value + status Status + listener net.Listener + udpListener *net.UDPConn + queue *starnet.StarQueue + stopFn context.CancelFunc + stopCtx context.Context + maxReadTimeout time.Duration + maxWriteTimeout time.Duration + parallelNum int + wg stario.WaitGroup + clientPool map[string]*ClientConn + mu sync.RWMutex + handshakeRsaKey []byte + SecretKey []byte + defaultMsgEn func([]byte, []byte) []byte + defaultMsgDe func([]byte, []byte) []byte + linkFns map[string]func(message *Message) + defaultFns func(message *Message) + noFinSyncMsgPool sync.Map + noFinSyncMsgMaxKeepSeconds int64 + maxHeartbeatLostSeconds int64 + sequenceDe func([]byte) (interface{}, error) + sequenceEn func(interface{}) ([]byte, error) } -// GetConnPool 获取所有Client端信息 -func (star *StarNotifyS) GetConnPool() []SMsg { - var result []SMsg - star.connPool.Range(func(k, val interface{}) bool { - v := val.(net.Conn) - result = append(result, SMsg{Conn: v, mode: "pa", nickName: star.setNickName, getName: star.getName, queue: star.Queue}) - return true - }) - for _, v := range star.udpPool { - result = append(result, SMsg{UDP: v, uconn: star.UDPConn, mode: "pa0", nickName: star.setNickName, getName: star.getName, queue: star.Queue}) +func NewServer() Server { + var server ServerCommon + server.wg = stario.NewWaitGroup(0) + server.parallelNum = 0 + server.noFinSyncMsgMaxKeepSeconds = 0 + server.maxHeartbeatLostSeconds = 300 + server.stopCtx, server.stopFn = context.WithCancel(context.Background()) + server.SecretKey = defaultAesKey + server.handshakeRsaKey = defaultRsaKey + server.clientPool = make(map[string]*ClientConn) + server.defaultMsgEn = defaultMsgEn + server.defaultMsgDe = defaultMsgDe + server.sequenceEn = encode + server.sequenceDe = Decode + server.alive.Store(false) + server.linkFns = make(map[string]func(*Message)) + server.defaultFns = func(message *Message) { + return } - return result + return &server } -// GetClient 获取所有Client端信息 -func (star *StarNotifyS) GetClient(name string) (SMsg, error) { - if str, ok := star.nickName[name]; ok { - if tmp, ok := star.connPool.Load(str); ok { - conn := tmp.(net.Conn) - return SMsg{Conn: conn, mode: "pa", nickName: star.setNickName, getName: star.getName, queue: star.Queue}, nil - } - if conn, ok := star.udpPool[str]; ok { - return SMsg{UDP: conn, uconn: star.UDPConn, mode: "pa0", nickName: star.setNickName, getName: star.getName, queue: star.Queue}, nil - } +func (s *ServerCommon) Stop() error { + if !s.alive.Load().(bool) { + return nil + } + s.alive.Store(false) + s.mu.Lock() + s.status = Status{ + Alive: false, + Reason: "recv stop signal from user", + Err: nil, } - return SMsg{}, errors.New("Not Found") + s.mu.Unlock() + s.stopFn() + return nil } - -func (nmsg *SMsg) GetName() string { - if nmsg.uconn != nil { - return nmsg.getName(nmsg.UDP.String()) +func (s *ServerCommon) Listen(network string, addr string) error { + if s.alive.Load().(bool) { + return errors.New("server already run") + } + s.stopCtx, s.stopFn = context.WithCancel(context.Background()) + s.queue = starnet.NewQueueCtx(s.stopCtx, 128) + if strings.Contains(strings.ToLower(network), "udp") { + return s.ListenUDP(network, addr) } - return nmsg.getName(fmt.Sprint(nmsg.Conn)) + return s.ListenTU(network, addr) } -func (nmsg *SMsg) SetName(name string) error { - if nmsg.uconn != nil { - return nmsg.nickName(name, nmsg.UDP.String()) +func (s *ServerCommon) ListenTU(network string, addr string) error { + listener, err := net.Listen(network, addr) + if err != nil { + return err } - return nmsg.nickName(name, fmt.Sprint(nmsg.Conn)) + s.alive.Store(true) + s.status.Alive = true + s.listener = listener + go s.accept() + go s.monitorPool() + go s.loadMessage() + return nil } -func (nmsg *SMsg) addSlash(name string) string { - var key []byte - for _, v := range []byte(name) { - if v == byte(124) || v == byte(92) { - key = append(key, byte(92)) +func (s *ServerCommon) monitorPool() { + for { + select { + case <-s.stopCtx.Done(): + s.noFinSyncMsgPool.Range(func(k, v interface{}) bool { + data := v.(WaitMsg) + close(data.Reply) + s.noFinSyncMsgPool.Delete(k) + return true + }) + return + case <-time.After(time.Second * 30): + } + now := time.Now() + if s.noFinSyncMsgMaxKeepSeconds > 0 { + s.noFinSyncMsgPool.Range(func(k, v interface{}) bool { + data := v.(WaitMsg) + if data.Time.Add(time.Duration(s.noFinSyncMsgMaxKeepSeconds) * time.Second).Before(now) { + close(data.Reply) + s.noFinSyncMsgPool.Delete(k) + } + return true + }) + } + if s.maxHeartbeatLostSeconds != 0 { + for _, v := range s.clientPool { + if now.Unix()-v.lastHeartBeat > s.maxHeartbeatLostSeconds { + v.stopFn() + s.removeClient(v) + } + } } - key = append(key, v) } - return string(key) } -func (nmsg *SMsg) ReplyRaw(msg interface{}) error { - encodeData, err := encode(msg) - if err != nil { - return err - } - return nmsg.Reply(string(encodeData)) +func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) { + s.defaultMsgEn = fn } -// Reply 用于向client端回复数据 -func (nmsg *SMsg) Reply(msg string) error { - var err error - if nmsg.uconn == nil { - _, err = nmsg.Conn.Write(nmsg.queue.BuildMessage([]byte(nmsg.mode + "||" + nmsg.addSlash(nmsg.Key) + "||" + msg))) - } else { - _, err = nmsg.uconn.WriteToUDP(nmsg.queue.BuildMessage([]byte(nmsg.mode+"||"+nmsg.addSlash(nmsg.Key)+"||"+msg)), nmsg.UDP) - } - return err +func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) { + s.defaultMsgDe = fn } -// Send 用于向client端发送key-value数据 -func (nmsg *SMsg) Send(key, value string) error { - var err error - if nmsg.uconn == nil { - _, err = nmsg.Conn.Write(nmsg.queue.BuildMessage([]byte("pa||" + nmsg.addSlash(key) + "||" + value))) - } else { - _, err = nmsg.uconn.WriteToUDP(nmsg.queue.BuildMessage([]byte("pa||"+nmsg.addSlash(key)+"||"+value)), nmsg.UDP) - } - return err +func (s *ServerCommon) SetDefaultLink(fn func(message *Message)) { + s.defaultFns = fn } -func (nmsg *SMsg) SendRaw(key string, msg interface{}) error { - encodeData, err := encode(msg) - if err != nil { - return err - } - return nmsg.Send(key, string(encodeData)) +func (s *ServerCommon) SetLink(key string, fn func(*Message)) { + s.mu.Lock() + defer s.mu.Unlock() + s.linkFns[key] = fn } -func (star *StarNotifyS) SendWaitRaw(source SMsg, key string, msg interface{}, tmout time.Duration) (SMsg, error) { - encodeData, err := encode(msg) - if err != nil { - return SMsg{}, err - } - return star.SendWait(source, key, string(encodeData), tmout) +func (s *ServerCommon) pushMessage(data []byte, source string) { + s.queue.ParseMessage(data, source) } -// SendWait 用于向client端发送key-value数据,并等待 -func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Duration) (SMsg, error) { - var err error - var tmceed <-chan time.Time - rand.Seed(time.Now().UnixNano()) - mode := "sr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999)) - if source.uconn == nil { - _, err = source.Conn.Write(star.Queue.BuildMessage([]byte(mode + "||" + source.addSlash(key) + "||" + value))) - } else { - _, err = source.uconn.WriteToUDP(star.Queue.BuildMessage([]byte(mode+"||"+source.addSlash(key)+"||"+value)), source.UDP) - } - if err != nil { - return SMsg{}, err - } - if int64(tmout) > 0 { - tmceed = time.After(tmout) - } - source.wait = make(chan int, 2) - star.lockMu.Lock() - star.lockPool[mode] = source - star.lockMu.Unlock() - select { - case <-source.wait: - star.lockMu.Lock() - res := star.lockPool[mode] - delete(star.lockPool, mode) - star.lockMu.Unlock() - return res, nil - case <-tmceed: - return SMsg{}, errors.New("Time Exceed") - } -} - -func (star *StarNotifyS) starinits() { - builder := starnet.NewQueue() - builder.EncodeFunc = encodeFunc - builder.DecodeFunc = decodeFunc - builder.Encode = true - star.stopSign, star.cancel = context.WithCancel(context.Background()) - star.Queue = builder - star.udpPool = make(map[string]*net.UDPAddr) - star.FuncLists = make(map[string]func(SMsg) string) - star.nickName = make(map[string]string) - star.lockPool = make(map[string]SMsg) - star.Online = false - star.Queue.RestoreDuration(time.Millisecond * 50) -} - -// NewNotifyS 开启一个新的Server端通知 -func NewNotifyS(netype, value string) (*StarNotifyS, error) { - if netype[0:3] != "udp" { - return notudps(netype, value) - } - return doudps(netype, value) -} - -func doudps(netype, value string) (*StarNotifyS, error) { - var star StarNotifyS - star.starinits() - star.isUDP = true - udpaddr, err := net.ResolveUDPAddr(netype, value) - if err != nil { - return nil, err +func (s *ServerCommon) removeClient(client *ClientConn) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.clientPool, client.ClientID) +} + +func (s *ServerCommon) accept() { + if s.udpListener != nil { + s.acceptUDP() } - star.UDPConn, err = net.ListenUDP(netype, udpaddr) - if err != nil { - return nil, err - } - go star.notify() - go func() { - <-star.stopSign.Done() - for k, v := range star.udpPool { - star.UDPConn.WriteToUDP(star.Queue.BuildMessage([]byte("b612ryzstop")), v) - star.connMu.Lock() - delete(star.udpPool, k) - star.connMu.Unlock() - for k2, v2 := range star.nickName { - if v2 == k { - delete(star.nickName, k2) - } - } + s.acceptTU() +} +func (s *ServerCommon) acceptTU() { + for { + select { + case <-s.stopCtx.Done(): + return + default: } - star.UDPConn.Close() - star.Online = false - return - }() - go func() { + conn, err := s.listener.Accept() + if err != nil { + fmt.Println("error accept:", err) + continue + } + var id string for { - buf := make([]byte, 8192) - n, addr, err := star.UDPConn.ReadFromUDP(buf) - if n != 0 { - star.Queue.ParseMessage(buf[0:n], addr) - if _, ok := star.udpPool[addr.String()]; !ok { - if star.Connected != nil { - go star.Connected(SMsg{UDP: addr, uconn: star.UDPConn, nickName: star.setNickName, getName: star.getName, queue: star.Queue}) - } - } - star.connMu.Lock() - star.udpPool[addr.String()] = addr - star.connMu.Unlock() - } - if err != nil { + id = fmt.Sprintf("%s%d%d", conn.RemoteAddr().String(), time.Now().UnixNano(), rand.Int63()) + s.mu.RLock() + if _, ok := s.clientPool[id]; ok { + s.mu.RUnlock() continue } + s.mu.RUnlock() + break + } + client := ClientConn{ + ClientID: id, + ClientAddr: conn.RemoteAddr(), + tuConn: conn, + server: s, + maxReadTimeout: s.maxReadTimeout, + maxWriteTimeout: s.maxWriteTimeout, + SecretKey: s.SecretKey, + handshakeRsaKey: s.handshakeRsaKey, + msgEn: s.defaultMsgEn, + msgDe: s.defaultMsgDe, + lastHeartBeat: time.Now().Unix(), + } + client.alive.Store(true) + client.status = Status{ + Alive: true, + Reason: "", + Err: nil, } - }() - star.Online = true - return &star, nil + client.stopCtx, client.stopFn = context.WithCancel(context.Background()) + s.mu.Lock() + s.clientPool[id] = &client + s.mu.Unlock() + go client.readTUMessage() + } } -func notudps(netype, value string) (*StarNotifyS, error) { - var err error - var star StarNotifyS - star.starinits() - star.isUDP = false - star.listener, err = net.Listen(netype, value) - if err != nil { - return nil, err - } - go star.notify() - go func() { - <-star.stopSign.Done() - star.connPool.Range(func(a, b interface{}) bool { - k := a.(string) - v := b.(net.Conn) - v.Close() - star.connPool.Delete(a) - for k2, v2 := range star.nickName { - if v2 == k { - star.funcMu.Lock() - delete(star.nickName, k2) - star.funcMu.Unlock() - } - } - return true - }) - star.listener.Close() - star.Online = false - return - }() - go func() { - for { - conn, err := star.listener.Accept() - if err != nil { - select { - case <-star.stopSign.Done(): - star.listener.Close() - return - default: - continue - } +func (s *ServerCommon) loadMessage() { + for { + select { + case <-s.stopCtx.Done(): + var wg sync.WaitGroup + s.mu.RLock() + for _, v := range s.clientPool { + wg.Add(1) + go func() { + defer wg.Done() + v.sayGoodByeForTU() + v.alive.Store(false) + v.status = Status{ + Alive: false, + Reason: "recv stop signal from server", + Err: nil, + } + v.stopFn() + s.removeClient(v) + }() } - if !star.ReadDeadline.IsZero() { - conn.SetReadDeadline(star.ReadDeadline) + s.mu.RUnlock() + select { + case <-time.After(time.Second * 8): + case <-stario.WaitUntilFinished(func() error { + wg.Wait() + return nil + }): } - if !star.WriteDeadline.IsZero() { - conn.SetWriteDeadline(star.WriteDeadline) + if s.listener != nil { + s.listener.Close() } - if !star.Deadline.IsZero() { - conn.SetDeadline(star.Deadline) + s.wg.Wait() + return + case data, ok := <-s.queue.RestoreChan(): + if !ok { + continue } - go func(conn net.Conn) { - for { - buf := make([]byte, 8192) - n, err := conn.Read(buf) - if n != 0 { - star.Queue.ParseMessage(buf[0:n], conn) - } - if err != nil { - conn.Close() - star.connPool.Delete(fmt.Sprint(conn)) - for k, v := range star.nickName { - if v == fmt.Sprint(conn) { - delete(star.nickName, k) - } - } - break - } + s.wg.Add(1) + go func(data starnet.MsgQueue) { + s.mu.RLock() + cc, ok := s.clientPool[data.Conn.(string)] + s.mu.RUnlock() + if !ok { + return } - }(conn) - star.connPool.Store(fmt.Sprint(conn), conn) - if star.Connected != nil { - go star.Connected(SMsg{Conn: conn, nickName: star.setNickName, getName: star.getName, queue: star.Queue}) - } + //fmt.Println("received:", float64(time.Now().UnixNano()-nowd)/1000000) + msg, err := s.sequenceDe(cc.msgDe(cc.SecretKey, data.Msg)) + if err != nil { + fmt.Println("server decode data error", err) + return + } + //fmt.Println("decoded:", float64(time.Now().UnixNano()-nowd)/1000000) + message := Message{ + NetType: NET_SERVER, + ClientConn: cc, + TransferMsg: msg.(TransferMsg), + } + message.Time = time.Now() + + //fmt.Println("dispatch:", float64(time.Now().UnixNano()-nowd)/1000000) + s.dispatchMsg(message) + }(data) + } + } +} +func (s *ServerCommon) sysMsg(message Message) { + switch message.Key { + case "bye": + fmt.Println("recv stop signal from client", message.ClientConn.ClientID) + if message.TransferMsg.Type == MSG_SYS_WAIT { + message.Reply(nil) + } + message.ClientConn.alive.Store(false) + message.ClientConn.status = Status{ + Alive: false, + Reason: "recv stop signal from client", + Err: nil, + } + message.ClientConn.stopFn() + case "heartbeat": + message.ClientConn.lastHeartBeat = time.Now().Unix() + message.Reply(nil) + } +} + +func (s *ServerCommon) dispatchMsg(message Message) { + defer s.wg.Done() + switch message.TransferMsg.Type { + case MSG_SYS_WAIT: + fallthrough + case MSG_SYS: + s.sysMsg(message) + return + case MSG_KEY_CHANGE: + message.ClientConn.rsaDecode(message) + return + case MSG_SYS_REPLY: + fallthrough + case MSG_SYNC_REPLY: + data, ok := s.noFinSyncMsgPool.Load(message.TransferMsg.ID) + if ok { + wait := data.(WaitMsg) + wait.Reply <- message + s.noFinSyncMsgPool.Delete(message.TransferMsg.ID) + return } - }() - star.Online = true - return &star, nil + //just throw + return + //fallthrough + default: + } + callFn := func(fn func(*Message)) { + fn(&message) + } + fn, ok := s.linkFns[message.TransferMsg.Key] + if ok { + callFn(fn) + } + if s.defaultFns != nil { + callFn(s.defaultFns) + } +} +func (s *ServerCommon) send(c *ClientConn, msg TransferMsg) (WaitMsg, error) { + if s.udpListener != nil { + return s.sendUDP(c, msg) + } + return s.sendTU(c, msg) } -func (star *StarNotifyS) GetListenerInfo() net.Listener { - return star.listener +func (s *ServerCommon) sendTU(c *ClientConn, msg TransferMsg) (WaitMsg, error) { + var wait WaitMsg + if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { + msg.ID = atomic.AddUint64(&s.msgID, 1) + } + data, err := s.sequenceEn(msg) + if err != nil { + return WaitMsg{}, err + } + data = c.msgEn(c.SecretKey, data) + data = s.queue.BuildMessage(data) + if c.maxWriteTimeout.Seconds() != 0 { + c.tuConn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)) + } + _, err = c.tuConn.Write(data) + //fmt.Println("resend:", float64(time.Now().UnixNano()-nowd)/1000000) + if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_SYS_WAIT) { + wait.Time = time.Now() + wait.TransferMsg = msg + wait.Reply = make(chan Message, 1) + s.noFinSyncMsgPool.Store(msg.ID, wait) + } + return wait, err } -// SetNotify 用于设置通知关键词的调用函数 -func (star *StarNotifyS) setNickName(name string, conn string) error { - if _, ok := star.connPool.Load(conn); !ok { - if _, ok := star.udpPool[conn]; !ok { - return errors.New("Conn Not Found") +func (s *ServerCommon) Send(c *ClientConn, key string, value MsgVal) error { + _, err := s.send(c, TransferMsg{ + Key: key, + Value: value, + Type: MSG_ASYNC, + }) + return err +} +func (s *ServerCommon) sendWait(c *ClientConn, msg TransferMsg, timeout time.Duration) (Message, error) { + data, err := s.send(c, msg) + if err != nil { + return Message{}, err + } + if timeout.Seconds() == 0 { + msg, ok := <-data.Reply + if !ok { + return msg, os.ErrInvalid } + return msg, nil } - for k, v := range star.nickName { - if v == conn { - delete(star.nickName, k) + select { + case <-time.After(timeout): + close(data.Reply) + s.noFinSyncMsgPool.Delete(data.TransferMsg.ID) + return Message{}, os.ErrDeadlineExceeded + case msg, ok := <-data.Reply: + if !ok { + return msg, os.ErrInvalid } + return msg, nil } - star.funcMu.Lock() - star.nickName[name] = conn - star.funcMu.Unlock() - return nil } -// SetNotify 用于设置通知关键词的调用函数 -func (star *StarNotifyS) SetNotify(name string, data func(SMsg) string) { - star.funcMu.Lock() - defer star.funcMu.Unlock() - if data == nil { - if _, ok := star.FuncLists[name]; ok { - delete(star.FuncLists, name) +func (s *ServerCommon) SendCtx(ctx context.Context, c *ClientConn, key string, value MsgVal) (Message, error) { + return s.sendCtx(c, TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, ctx) +} + +func (s *ServerCommon) sendCtx(c *ClientConn, msg TransferMsg, ctx context.Context) (Message, error) { + data, err := s.send(c, msg) + if err != nil { + return Message{}, err + } + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + close(data.Reply) + s.noFinSyncMsgPool.Delete(data.TransferMsg.ID) + return Message{}, os.ErrClosed + case msg, ok := <-data.Reply: + if !ok { + return msg, os.ErrInvalid } - return + return msg, nil } - star.FuncLists[name] = data +} +func (s *ServerCommon) SendWait(c *ClientConn, key string, value MsgVal, timeout time.Duration) (Message, error) { + return s.sendWait(c, TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, timeout) } -// SetDefaultNotify 用于设置默认关键词的调用函数 -func (star *StarNotifyS) SetDefaultNotify(data func(SMsg) string) { - star.defaultFunc = data +func (s *ServerCommon) SendObjCtx(ctx context.Context, c *ClientConn, key string, val interface{}) (Message, error) { + data, err := s.sequenceEn(val) + if err != nil { + return Message{}, err + } + return s.sendCtx(c, TransferMsg{ + Key: key, + Value: data, + Type: MSG_SYNC_ASK, + }, ctx) } -func (star *StarNotifyS) trim(name string) string { - var slash bool = false - var key []byte - for _, v := range []byte(name) { - if v == byte(92) && !slash { - slash = true - continue - } - slash = false - key = append(key, v) +func (s *ServerCommon) SendObj(c *ClientConn, key string, val interface{}) error { + data, err := encode(val) + if err != nil { + return err + } + _, err = s.send(c, TransferMsg{ + Key: key, + Value: data, + Type: MSG_ASYNC, + }) + return err +} + +func (s *ServerCommon) Reply(m Message, value MsgVal) error { + return m.Reply(value) +} + +//for udp below + +func (s *ServerCommon) ListenUDP(network string, addr string) error { + udpAddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return err + } + listener, err := net.ListenUDP(network, udpAddr) + if err != nil { + return err } - return string(key) + s.alive.Store(true) + s.status.Alive = true + s.udpListener = listener + go s.accept() + go s.monitorPool() + go s.loadMessage() + return nil } -func (star *StarNotifyS) notify() { +func (s *ServerCommon) acceptUDP() { for { select { - case <-star.stopSign.Done(): + case <-s.stopCtx.Done(): return default: } - data, err := star.Queue.RestoreOne() - if err != nil { - time.Sleep(time.Millisecond * 500) - continue - } - mode, key, value := star.analyseData(string(data.Msg)) - if mode == key && mode == value && mode == "" { - continue + if s.maxReadTimeout.Seconds() > 0 { + s.udpListener.SetReadDeadline(time.Now().Add(s.maxReadTimeout)) } - var rmsg SMsg - if !star.isUDP { - rmsg = SMsg{data.Conn.(net.Conn), key, value, nil, nil, mode, nil, star.setNickName, star.getName, star.Queue} - } else { - rmsg = SMsg{nil, key, value, data.Conn.(*net.UDPAddr), star.UDPConn, mode, nil, star.setNickName, star.getName, star.Queue} - if key == "b612ryzstop" { - star.connMu.Lock() - delete(star.udpPool, rmsg.UDP.String()) - star.connMu.Unlock() - for k, v := range star.nickName { - if v == rmsg.UDP.String() { - delete(star.nickName, k) - } - } - continue + data := make([]byte, 4096) + num, addr, err := s.udpListener.ReadFromUDP(data) + id := addr.String() + //fmt.Println("s recv udp:", float64(time.Now().UnixNano()-nowd)/1000000) + s.mu.RLock() + if _, ok := s.clientPool[id]; !ok { + s.mu.RUnlock() + client := ClientConn{ + ClientID: id, + ClientAddr: addr, + server: s, + maxReadTimeout: s.maxReadTimeout, + maxWriteTimeout: s.maxWriteTimeout, + SecretKey: s.SecretKey, + handshakeRsaKey: s.handshakeRsaKey, + msgEn: s.defaultMsgEn, + msgDe: s.defaultMsgDe, + lastHeartBeat: time.Now().Unix(), } + client.stopCtx, client.stopFn = context.WithCancel(context.Background()) + s.mu.Lock() + s.clientPool[id] = &client + s.mu.Unlock() + } else { + s.mu.RUnlock() } - replyFunc := func(key string, rmsg SMsg) { - if msg, ok := star.FuncLists[key]; ok { - sdata := msg(rmsg) - if sdata == "" { - return - } - rmsg.Reply(sdata) - } else { - if star.defaultFunc != nil { - sdata := star.defaultFunc(rmsg) - if sdata == "" { - return - } - rmsg.Reply(sdata) - } + if err == os.ErrDeadlineExceeded { + if num != 0 { + s.pushMessage(data[:num], id) } + continue } - if mode[0:2] != "sr" { - if !star.Sync { - go replyFunc(key, rmsg) - } else { - replyFunc(key, rmsg) - } - } else { - if sa, ok := star.lockPool[mode]; ok { - rmsg.wait = sa.wait - star.lockMu.Lock() - star.lockPool[mode] = rmsg - star.lockPool[mode].wait <- 1 - star.lockMu.Unlock() - } else { - if !star.Sync { - go replyFunc(key, rmsg) - } else { - replyFunc(key, rmsg) - } - } + if err != nil { + continue } + s.pushMessage(data[:num], id) } } -func (star *StarNotifyS) analyseData(msg string) (mode, key, value string) { - slice := strings.SplitN(msg, "||", 3) - if len(slice) < 3 { - return "", "", "" +func (s *ServerCommon) sendUDP(c *ClientConn, msg TransferMsg) (WaitMsg, error) { + var wait WaitMsg + if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { + msg.ID = uint64(time.Now().UnixNano()) + rand.Uint64() + rand.Uint64() + } + data, err := s.sequenceEn(msg) + if err != nil { + return WaitMsg{}, err } - return slice[0], star.trim(slice[1]), slice[2] + data = c.msgEn(c.SecretKey, data) + data = s.queue.BuildMessage(data) + if c.maxWriteTimeout.Seconds() != 0 { + s.udpListener.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)) + } + _, err = s.udpListener.WriteTo(data, c.ClientAddr) + if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_SYS_WAIT) { + wait.Time = time.Now() + wait.TransferMsg = msg + wait.Reply = make(chan Message, 1) + s.noFinSyncMsgPool.Store(msg.ID, wait) + } + return wait, err +} + +func (s *ServerCommon) StopMonitorChan() <-chan struct{} { + return s.stopCtx.Done() +} + +func (s *ServerCommon) Status() Status { + return s.status +} + +func (s *ServerCommon) GetSecretKey() []byte { + return s.SecretKey +} +func (s *ServerCommon) SetSecretKey(key []byte) { + s.SecretKey = key +} +func (s *ServerCommon) RsaPrivKey() []byte { + return s.handshakeRsaKey +} +func (s *ServerCommon) SetRsaPrivKey(key []byte) { + s.handshakeRsaKey = key +} + +func (s *ServerCommon) GetClient(id string) *ClientConn { + s.mu.RLock() + defer s.mu.RUnlock() + c, ok := s.clientPool[id] + if !ok { + return nil + } + return c +} +func (s *ServerCommon) GetClientLists() []*ClientConn { + s.mu.RLock() + defer s.mu.RUnlock() + var list []*ClientConn = make([]*ClientConn, 0, len(s.clientPool)) + for _, v := range s.clientPool { + list = append(list, v) + } + return list +} + +func (s *ServerCommon) GetClientAddrs() []net.Addr { + s.mu.RLock() + defer s.mu.RUnlock() + var list = make([]net.Addr, 0, len(s.clientPool)) + for _, v := range s.clientPool { + list = append(list, v.ClientAddr) + } + return list +} + +func (s *ServerCommon) GetSequenceEn() func(interface{}) ([]byte, error) { + return s.sequenceEn +} +func (s *ServerCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) { + s.sequenceEn = fn +} +func (s *ServerCommon) GetSequenceDe() func([]byte) (interface{}, error) { + return s.sequenceDe +} +func (s *ServerCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) { + s.sequenceDe = fn +} + +func (s *ServerCommon) HeartbeatTimeoutSec() int64 { + return s.maxHeartbeatLostSeconds } -// ServerStop 用于终止Server端运行 -func (star *StarNotifyS) ServerStop() { - star.cancel() +func (s *ServerCommon) SetHeartbeatTimeoutSec(sec int64) { + s.maxHeartbeatLostSeconds = sec } diff --git a/servertype.go b/servertype.go new file mode 100644 index 0000000..1185447 --- /dev/null +++ b/servertype.go @@ -0,0 +1,46 @@ +package notify + +import ( + "context" + "net" + "time" +) + +type Server interface { + SetDefaultCommEncode(func([]byte, []byte) []byte) + SetDefaultCommDecode(func([]byte, []byte) []byte) + SetDefaultLink(func(message *Message)) + SetLink(string, func(*Message)) + send(c *ClientConn, msg TransferMsg) (WaitMsg, error) + sendWait(c *ClientConn, msg TransferMsg, timeout time.Duration) (Message, error) + SendObjCtx(ctx context.Context, c *ClientConn, key string, val interface{}) (Message, error) + SendObj(c *ClientConn, key string, val interface{}) error + Send(c *ClientConn, key string, value MsgVal) error + SendWait(c *ClientConn, key string, value MsgVal, timeout time.Duration) (Message, error) + SendCtx(ctx context.Context, c *ClientConn, key string, value MsgVal) (Message, error) + Reply(m Message, value MsgVal) error + pushMessage([]byte, string) + removeClient(client *ClientConn) + Listen(network string, addr string) error + Stop() error + StopMonitorChan() <-chan struct{} + Status() Status + + GetSecretKey() []byte + SetSecretKey(key []byte) + RsaPrivKey() []byte + SetRsaPrivKey([]byte) + + GetClient(id string) *ClientConn + GetClientLists() []*ClientConn + GetClientAddrs() []net.Addr + + GetSequenceEn() func(interface{}) ([]byte,error) + SetSequenceEn(func(interface{}) ([]byte,error)) + GetSequenceDe() func([]byte) (interface{}, error) + SetSequenceDe(func([]byte) (interface{}, error)) + + HeartbeatTimeoutSec()int64 + SetHeartbeatTimeoutSec(int64) + +} diff --git a/starnotify/define.go b/starnotify/define.go index c9d8161..4b14e59 100644 --- a/starnotify/define.go +++ b/starnotify/define.go @@ -1,101 +1,107 @@ package starnotify import ( - "errors" - "time" - "b612.me/notify" + "errors" + "sync" ) var ( - starClient map[string]*notify.StarNotifyC - starServer map[string]*notify.StarNotifyS + cmu sync.RWMutex + smu sync.RWMutex + starClient map[string]notify.Client + starServer map[string]notify.Server ) func init() { - starClient = make(map[string]*notify.StarNotifyC) - starServer = make(map[string]*notify.StarNotifyS) -} - -func NewClient(key, netype, value string) (*notify.StarNotifyC, error) { - client, err := notify.NewNotifyC(netype, value) - if err != nil { - return client, err - } - starClient[key] = client - return client, err + starClient = make(map[string]notify.Client) + starServer = make(map[string]notify.Server) } -func NewClientWithTimeout(key, netype, value string, timeout time.Duration) (*notify.StarNotifyC, error) { - client, err := notify.NewNotifyCWithTimeOut(netype, value, timeout) - if err != nil { - return client, err - } +func NewClient(key string) notify.Client { + client := notify.NewClient() + cmu.Lock() starClient[key] = client - return client, err + cmu.Unlock() + return client } -func DeleteClient(key string) error { +func DeleteClient(key string) (err error) { + cmu.RLock() client, ok := starClient[key] + cmu.RUnlock() if !ok { return errors.New("Not Exists Yet!") } - if client.Online { - client.ClientStop() + if client.Status().Alive { + err = client.Stop() } client = nil + cmu.Lock() delete(starClient, key) - return nil + cmu.Unlock() + return err } -func NewServer(key, netype, value string) (*notify.StarNotifyS, error) { - server, err := notify.NewNotifyS(netype, value) - if err != nil { - return server, err - } +func NewServer(key string) notify.Server { + server := notify.NewServer() + smu.Lock() starServer[key] = server - return server, err + smu.Unlock() + return server } func DeleteServer(key string) error { + smu.RLock() server, ok := starServer[key] + smu.RUnlock() if !ok { return errors.New("Not Exists Yet!") } - if server.Online { - server.ServerStop() + if server.Status().Alive { + server.Stop() } server = nil + smu.Lock() delete(starServer, key) + smu.Unlock() return nil } -func S(key string) *notify.StarNotifyS { +func S(key string) notify.Server { + smu.RLock() server, ok := starServer[key] + smu.RUnlock() if !ok { return nil } return server } -func C(key string) *notify.StarNotifyC { +func C(key string) notify.Client { + cmu.RLock() client, ok := starClient[key] + cmu.RUnlock() if !ok { return nil } return client } -func Server(key string) (*notify.StarNotifyS, error) { +func Server(key string) (notify.Server, error) { + smu.RLock() server, ok := starServer[key] + smu.RUnlock() if !ok { return nil, errors.New("Not Exists Yet") } return server, nil } -func Client(key string) (*notify.StarNotifyC, error) { +func Client(key string) (notify.Client, error) { + cmu.RLock() client, ok := starClient[key] + cmu.RUnlock() if !ok { return nil, errors.New("Not Exists Yet") } diff --git a/v1/client.go b/v1/client.go new file mode 100644 index 0000000..2402ae8 --- /dev/null +++ b/v1/client.go @@ -0,0 +1,394 @@ +package notify + +import ( + "context" + "errors" + "fmt" + "math/rand" + "net" + "strings" + "sync" + "time" + + "b612.me/starcrypto" + "b612.me/starnet" +) + +// StarNotifyC 为Client端 +type StarNotifyC struct { + Connc net.Conn + dialTimeout time.Duration + clientSign map[string]chan string + mu sync.Mutex + // FuncLists 当不使用channel时,使用此记录调用函数 + FuncLists map[string]func(CMsg) + stopSign context.Context + cancel context.CancelFunc + defaultFunc func(CMsg) + // UseChannel 是否使用channel作为信息传递 + UseChannel bool + isUDP bool + Sync bool + // Queue 是用来处理收发信息的简单消息队列 + Queue *starnet.StarQueue + // Online 当前链接是否处于活跃状态 + Online bool + lockPool map[string]CMsg + aesKey []byte +} + +// CMsg 指明当前客户端被通知的关键字 +type CMsg struct { + Key string + Value string + mode string + wait chan int +} + +func WriteToUDP(local *net.UDPConn, remote *net.UDPAddr, data []byte) error { + var MAX_RECV_LEN = 8192 + var haveErr error + end := len(data) + for i := 0; i < end; i += MAX_RECV_LEN { + step := i + MAX_RECV_LEN + if step > end { + step = end + } + _, err := local.WriteToUDP(data[i:step], remote) + if err != nil { + haveErr = err + } + } + return haveErr +} + +func (star *StarNotifyC) starinitc() { + builder := starnet.NewQueue() + builder.EncodeFunc = encodeFunc + builder.DecodeFunc = decodeFunc + builder.Encode = true + star.stopSign, star.cancel = context.WithCancel(context.Background()) + star.Queue = builder + star.FuncLists = make(map[string]func(CMsg)) + star.UseChannel = false + star.clientSign = make(map[string]chan string) + star.Online = false + star.lockPool = make(map[string]CMsg) + star.Queue.RestoreDuration(time.Millisecond * 50) +} + +func (star *StarNotifyC) SetAesKey(key []byte) { + star.aesKey = key + star.Queue.EncodeFunc = func(data []byte) []byte { + return starcrypto.AesEncryptCFB(data, key) + } + star.Queue.DecodeFunc = func(data []byte) []byte { + return starcrypto.AesDecryptCFB(data, key) + } +} + +func (star *StarNotifyC) GetAesKey() []byte { + if len(star.aesKey) == 0 { + return aesKey + } + return star.aesKey +} + +// Notify 用于获取一个通知 +func (star *StarNotifyC) Notify(key string) chan string { + if _, ok := star.clientSign[key]; !ok { + ch := make(chan string, 20) + star.mu.Lock() + star.clientSign[key] = ch + star.mu.Unlock() + } + return star.clientSign[key] +} + +func (star *StarNotifyC) store(key, value string) { + if _, ok := star.clientSign[key]; !ok { + ch := make(chan string, 20) + ch <- value + star.mu.Lock() + star.clientSign[key] = ch + star.mu.Unlock() + return + } + star.clientSign[key] <- value +} +func NewNotifyCWithTimeOut(netype, value string, timeout time.Duration) (*StarNotifyC, error) { + var err error + var star StarNotifyC + star.starinitc() + star.isUDP = false + if strings.Index(netype, "udp") >= 0 { + star.isUDP = true + } + star.Connc, err = net.DialTimeout(netype, value, timeout) + if err != nil { + return nil, err + } + star.dialTimeout = timeout + go star.cnotify() + go func() { + <-star.stopSign.Done() + star.Connc.Close() + star.Online = false + return + }() + go func() { + for { + buf := make([]byte, 8192) + n, err := star.Connc.Read(buf) + if n != 0 { + star.Queue.ParseMessage(buf[0:n], star.Connc) + } + if err != nil { + star.Connc.Close() + star.ClientStop() + //star, _ = NewNotifyC(netype, value) + star.Online = false + return + } + } + }() + star.Online = true + return &star, nil +} + +// NewNotifyC 用于新建一个Client端进程 +func NewNotifyC(netype, value string) (*StarNotifyC, error) { + var err error + var star StarNotifyC + star.starinitc() + star.isUDP = false + if strings.Index(netype, "udp") >= 0 { + star.isUDP = true + } + star.Connc, err = net.Dial(netype, value) + if err != nil { + return nil, err + } + go star.cnotify() + go func() { + <-star.stopSign.Done() + star.Connc.Close() + star.Online = false + return + }() + go func() { + for { + buf := make([]byte, 8192) + n, err := star.Connc.Read(buf) + if n != 0 { + star.Queue.ParseMessage(buf[0:n], star.Connc) + } + if err != nil { + star.Connc.Close() + star.ClientStop() + //star, _ = NewNotifyC(netype, value) + star.Online = false + return + } + } + }() + star.Online = true + return &star, nil +} + +// Send 用于向Server端发送数据 +func (star *StarNotifyC) Send(name string) error { + return star.SendValue(name, "") +} + +func (star *StarNotifyC) Stoped() <-chan struct{} { + return star.stopSign.Done() +} + +func (star *StarNotifyC) SendValueRaw(key string, msg interface{}) error { + encodeData, err := encode(msg) + if err != nil { + return err + } + return star.SendValue(key, string(encodeData)) +} + +// SendValue 用于向Server端发送key-value类型数据 +func (star *StarNotifyC) SendValue(name, value string) error { + var err error + var key []byte + for _, v := range []byte(name) { + if v == byte(124) || v == byte(92) { + key = append(key, byte(92)) + } + key = append(key, v) + } + _, err = star.Connc.Write(star.Queue.BuildMessage([]byte("pa" + "||" + string(key) + "||" + value))) + return err +} + +func (star *StarNotifyC) trim(name string) string { + var slash bool = false + var key []byte + for _, v := range []byte(name) { + if v == byte(92) && !slash { + slash = true + continue + } + slash = false + key = append(key, v) + } + return string(key) +} +func (star *StarNotifyC) SendValueWaitRaw(key string, msg interface{}, tmout time.Duration) (CMsg, error) { + encodeData, err := encode(msg) + if err != nil { + return CMsg{}, err + } + return star.SendValueWait(key, string(encodeData), tmout) +} + +// SendValueWait 用于向Server端发送key-value类型数据并等待结果返回,此结果不会通过标准返回流程处理 +func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration) (CMsg, error) { + var err error + var tmceed <-chan time.Time + if star.UseChannel { + return CMsg{}, errors.New("Do Not Use UseChannel Mode!") + } + rand.Seed(time.Now().UnixNano()) + mode := "cr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999)) + var key []byte + for _, v := range []byte(name) { + if v == byte(124) || v == byte(92) { + key = append(key, byte(92)) + } + key = append(key, v) + } + _, err = star.Connc.Write(star.Queue.BuildMessage([]byte(mode + "||" + string(key) + "||" + value))) + if err != nil { + return CMsg{}, err + } + if int64(tmout) > 0 { + tmceed = time.After(tmout) + } + var source CMsg + source.wait = make(chan int, 2) + star.mu.Lock() + star.lockPool[mode] = source + star.mu.Unlock() + select { + case <-source.wait: + res := star.lockPool[mode] + star.mu.Lock() + delete(star.lockPool, mode) + star.mu.Unlock() + return res, nil + case <-tmceed: + return CMsg{}, errors.New("Time Exceed") + } +} + +// ReplyMsg 用于向Server端Reply信息 +func (star *StarNotifyC) ReplyMsg(data CMsg, name, value string) error { + var err error + var key []byte + for _, v := range []byte(name) { + if v == byte(124) || v == byte(92) { + key = append(key, byte(92)) + } + key = append(key, v) + } + _, err = star.Connc.Write(star.Queue.BuildMessage([]byte(data.mode + "||" + string(key) + "||" + value))) + return err +} + +func (star *StarNotifyC) cnotify() { + for { + select { + case <-star.stopSign.Done(): + return + default: + } + data, err := star.Queue.RestoreOne() + if err != nil { + time.Sleep(time.Millisecond * 500) + continue + } + if string(data.Msg) == "b612ryzstop" { + star.ClientStop() + star.Online = false + return + } + strs := strings.SplitN(string(data.Msg), "||", 3) + if len(strs) < 3 { + continue + } + strs[1] = star.trim(strs[1]) + if star.UseChannel { + go star.store(strs[1], strs[2]) + } else { + mode, key, value := strs[0], strs[1], strs[2] + if mode[0:2] != "cr" { + if msg, ok := star.FuncLists[key]; ok { + if star.Sync { + msg(CMsg{key, value, mode, nil}) + } else { + go msg(CMsg{key, value, mode, nil}) + } + } else { + if star.defaultFunc != nil { + if star.Sync { + star.defaultFunc(CMsg{key, value, mode, nil}) + } else { + go star.defaultFunc(CMsg{key, value, mode, nil}) + } + } + } + } else { + if sa, ok := star.lockPool[mode]; ok { + sa.Key = key + sa.Value = value + sa.mode = mode + star.mu.Lock() + star.lockPool[mode] = sa + star.mu.Unlock() + sa.wait <- 1 + } else { + if msg, ok := star.FuncLists[key]; ok { + if star.Sync { + msg(CMsg{key, value, mode, nil}) + } else { + go msg(CMsg{key, value, mode, nil}) + } + } else { + if star.defaultFunc != nil { + if star.Sync { + star.defaultFunc(CMsg{key, value, mode, nil}) + } else { + go star.defaultFunc(CMsg{key, value, mode, nil}) + } + } + } + } + } + } + } +} + +// ClientStop 终止client端运行 +func (star *StarNotifyC) ClientStop() { + if star.isUDP { + star.Send("b612ryzstop") + } + star.cancel() +} + +// SetNotify 用于设置关键词的调用函数 +func (star *StarNotifyC) SetNotify(name string, data func(CMsg)) { + star.FuncLists[name] = data +} + +// SetDefaultNotify 用于设置默认关键词的调用函数 +func (star *StarNotifyC) SetDefaultNotify(data func(CMsg)) { + star.defaultFunc = data +} diff --git a/client_test.go b/v1/client_test.go similarity index 100% rename from client_test.go rename to v1/client_test.go diff --git a/v1/serialization.go b/v1/serialization.go new file mode 100644 index 0000000..c10d889 --- /dev/null +++ b/v1/serialization.go @@ -0,0 +1,37 @@ +package notify + +import ( + "bytes" + "encoding/gob" +) + +func Register(data interface{}) { + gob.Register(data) +} + +func RegisterAll(data []interface{}) { + for _, v := range data { + gob.Register(v) + } +} +func encode(src interface{}) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(&src) + return buf.Bytes(), err +} + +func Decode(src []byte) (interface{}, error) { + dec := gob.NewDecoder(bytes.NewReader(src)) + var dst interface{} + err := dec.Decode(&dst) + return dst, err +} + +func (nmsg *SMsg) Decode() (interface{}, error) { + return Decode([]byte(nmsg.Value)) +} + +func (nmsg *CMsg) Decode() (interface{}, error) { + return Decode([]byte(nmsg.Value)) +} diff --git a/v1/server.go b/v1/server.go new file mode 100644 index 0000000..3149e57 --- /dev/null +++ b/v1/server.go @@ -0,0 +1,534 @@ +// Package notify is a package which provide common tcp/udp/unix socket service +package notify + +import ( + "context" + "errors" + "fmt" + "math/rand" + "net" + "strings" + "sync" + "time" + + "b612.me/starcrypto" + + "b612.me/starnet" +) + +var aesKey = []byte{0x19, 0x96, 0x11, 0x27, 228, 187, 187, 231, 142, 137, 230, 179, 189, 229, 184, 133} + +func encodeFunc(data []byte) []byte { + return starcrypto.AesEncryptCFB(data, aesKey) +} + +func decodeFunc(data []byte) []byte { + return starcrypto.AesDecryptCFB(data, aesKey) +} + +// StarNotifyS 为Server端 +type StarNotifyS struct { + // Queue 是用来处理收发信息的简单消息队列 + Queue *starnet.StarQueue + // FuncLists 记录了被通知项所记录的函数 + aesKey []byte + FuncLists map[string]func(SMsg) string + funcMu sync.Mutex + defaultFunc func(SMsg) string + Connected func(SMsg) + nickName map[string]string + stopSign context.Context + cancel context.CancelFunc + connPool sync.Map + connMu sync.Mutex + lockPool map[string]SMsg + lockMu sync.Mutex + udpPool map[string]*net.UDPAddr + listener net.Listener + isUDP bool + Sync bool + // UDPConn UDP监听 + UDPConn *net.UDPConn + // Online 当前链接是否处于活跃状态 + Online bool + // ReadDeadline tcp/unix中读超时设置,udp请直接调用UDPConn + ReadDeadline time.Time + // WriteDeadline tcp/unix中写超时设置,udp请直接调用UDPConn + WriteDeadline time.Time + + // Deadline tcp/unix中超时设置,udp请直接调用UDPConn + Deadline time.Time +} + +// SMsg 指明当前服务端被通知的关键字 +type SMsg struct { + Conn net.Conn + Key string + Value string + UDP *net.UDPAddr + Uconn *net.UDPConn + mode string + wait chan int + nickName func(string, string) error + getName func(string) string + queue *starnet.StarQueue +} + +func (star *StarNotifyS) SetAesKey(key []byte) { + star.aesKey = key + star.Queue.EncodeFunc = func(data []byte) []byte { + return starcrypto.AesEncryptCFB(data, key) + } + star.Queue.DecodeFunc = func(data []byte) []byte { + return starcrypto.AesDecryptCFB(data, key) + } +} + +func (star *StarNotifyS) GetAesKey() []byte { + if len(star.aesKey) == 0 { + return aesKey + } + return star.aesKey +} + +func (star *StarNotifyS) getName(conn string) string { + for k, v := range star.nickName { + if v == conn { + return k + } + } + return "" +} +func (star *StarNotifyS) Stoped() <-chan struct{} { + return star.stopSign.Done() +} + +// GetConnPool 获取所有Client端信息 +func (star *StarNotifyS) GetConnPool() []SMsg { + var result []SMsg + star.connPool.Range(func(k, val interface{}) bool { + v := val.(net.Conn) + result = append(result, SMsg{Conn: v, mode: "pa", nickName: star.setNickName, getName: star.getName, queue: star.Queue}) + return true + }) + for _, v := range star.udpPool { + result = append(result, SMsg{UDP: v, Uconn: star.UDPConn, mode: "pa0", nickName: star.setNickName, getName: star.getName, queue: star.Queue}) + } + return result +} + +// GetClient 获取所有Client端信息 +func (star *StarNotifyS) GetClient(name string) (SMsg, error) { + if str, ok := star.nickName[name]; ok { + if tmp, ok := star.connPool.Load(str); ok { + conn := tmp.(net.Conn) + return SMsg{Conn: conn, mode: "pa", nickName: star.setNickName, getName: star.getName, queue: star.Queue}, nil + } + if conn, ok := star.udpPool[str]; ok { + return SMsg{UDP: conn, Uconn: star.UDPConn, mode: "pa0", nickName: star.setNickName, getName: star.getName, queue: star.Queue}, nil + } + } + return SMsg{}, errors.New("Not Found") +} + +func (nmsg *SMsg) GetName() string { + if nmsg.Uconn != nil { + return nmsg.getName(nmsg.UDP.String()) + } + return nmsg.getName(fmt.Sprint(nmsg.Conn)) +} + +func (nmsg *SMsg) SetName(name string) error { + if nmsg.Uconn != nil { + return nmsg.nickName(name, nmsg.UDP.String()) + } + return nmsg.nickName(name, fmt.Sprint(nmsg.Conn)) +} + +func (nmsg *SMsg) addSlash(name string) string { + var key []byte + for _, v := range []byte(name) { + if v == byte(124) || v == byte(92) { + key = append(key, byte(92)) + } + key = append(key, v) + } + return string(key) +} + +func (nmsg *SMsg) ReplyRaw(msg interface{}) error { + encodeData, err := encode(msg) + if err != nil { + return err + } + return nmsg.Reply(string(encodeData)) +} + +// Reply 用于向client端回复数据 +func (nmsg *SMsg) Reply(msg string) error { + var err error + if nmsg.Uconn == nil { + _, err = nmsg.Conn.Write(nmsg.queue.BuildMessage([]byte(nmsg.mode + "||" + nmsg.addSlash(nmsg.Key) + "||" + msg))) + } else { + err = WriteToUDP(nmsg.Uconn, nmsg.UDP, nmsg.queue.BuildMessage([]byte(nmsg.mode+"||"+nmsg.addSlash(nmsg.Key)+"||"+msg))) + } + return err +} + +// Send 用于向client端发送key-value数据 +func (nmsg *SMsg) Send(key, value string) error { + var err error + if nmsg.Uconn == nil { + _, err = nmsg.Conn.Write(nmsg.queue.BuildMessage([]byte("pa||" + nmsg.addSlash(key) + "||" + value))) + } else { + err = WriteToUDP(nmsg.Uconn, nmsg.UDP, nmsg.queue.BuildMessage([]byte("pa||"+nmsg.addSlash(key)+"||"+value))) + } + return err +} + +func (nmsg *SMsg) SendRaw(key string, msg interface{}) error { + encodeData, err := encode(msg) + if err != nil { + return err + } + return nmsg.Send(key, string(encodeData)) +} + +func (star *StarNotifyS) SendWaitRaw(source SMsg, key string, msg interface{}, tmout time.Duration) (SMsg, error) { + encodeData, err := encode(msg) + if err != nil { + return SMsg{}, err + } + return star.SendWait(source, key, string(encodeData), tmout) +} + +// SendWait 用于向client端发送key-value数据,并等待 +func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Duration) (SMsg, error) { + var err error + var tmceed <-chan time.Time + rand.Seed(time.Now().UnixNano()) + mode := "sr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999)) + if source.Uconn == nil { + _, err = source.Conn.Write(star.Queue.BuildMessage([]byte(mode + "||" + source.addSlash(key) + "||" + value))) + } else { + err = WriteToUDP(source.Uconn, source.UDP, star.Queue.BuildMessage([]byte(mode+"||"+source.addSlash(key)+"||"+value))) + } + if err != nil { + return SMsg{}, err + } + if int64(tmout) > 0 { + tmceed = time.After(tmout) + } + source.wait = make(chan int, 2) + star.lockMu.Lock() + star.lockPool[mode] = source + star.lockMu.Unlock() + select { + case <-source.wait: + star.lockMu.Lock() + res := star.lockPool[mode] + delete(star.lockPool, mode) + star.lockMu.Unlock() + return res, nil + case <-tmceed: + return SMsg{}, errors.New("Time Exceed") + } +} + +func (star *StarNotifyS) starinits() { + builder := starnet.NewQueue() + builder.EncodeFunc = encodeFunc + builder.DecodeFunc = decodeFunc + builder.Encode = true + star.stopSign, star.cancel = context.WithCancel(context.Background()) + star.Queue = builder + star.udpPool = make(map[string]*net.UDPAddr) + star.FuncLists = make(map[string]func(SMsg) string) + star.nickName = make(map[string]string) + star.lockPool = make(map[string]SMsg) + star.Online = false + star.Queue.RestoreDuration(time.Millisecond * 50) +} + +// NewNotifyS 开启一个新的Server端通知 +func NewNotifyS(netype, value string) (*StarNotifyS, error) { + if netype[0:3] != "udp" { + return notudps(netype, value) + } + return doudps(netype, value) +} + +func doudps(netype, value string) (*StarNotifyS, error) { + var star StarNotifyS + star.starinits() + star.isUDP = true + udpaddr, err := net.ResolveUDPAddr(netype, value) + if err != nil { + return nil, err + } + star.UDPConn, err = net.ListenUDP(netype, udpaddr) + if err != nil { + return nil, err + } + go star.notify() + go func() { + <-star.stopSign.Done() + for k, v := range star.udpPool { + WriteToUDP(star.UDPConn, v, star.Queue.BuildMessage([]byte("b612ryzstop"))) + star.connMu.Lock() + delete(star.udpPool, k) + star.connMu.Unlock() + for k2, v2 := range star.nickName { + if v2 == k { + delete(star.nickName, k2) + } + } + } + star.UDPConn.Close() + star.Online = false + return + }() + go func() { + for { + buf := make([]byte, 81920) + n, addr, err := star.UDPConn.ReadFromUDP(buf) + if n != 0 { + star.Queue.ParseMessage(buf[0:n], addr) + if _, ok := star.udpPool[addr.String()]; !ok { + if star.Connected != nil { + go star.Connected(SMsg{UDP: addr, Uconn: star.UDPConn, nickName: star.setNickName, getName: star.getName, queue: star.Queue}) + } + } + star.connMu.Lock() + star.udpPool[addr.String()] = addr + star.connMu.Unlock() + } + if err != nil { + continue + } + } + }() + star.Online = true + return &star, nil +} + +func notudps(netype, value string) (*StarNotifyS, error) { + var err error + var star StarNotifyS + star.starinits() + star.isUDP = false + star.listener, err = net.Listen(netype, value) + if err != nil { + return nil, err + } + go star.notify() + go func() { + <-star.stopSign.Done() + star.connPool.Range(func(a, b interface{}) bool { + k := a.(string) + v := b.(net.Conn) + v.Close() + star.connPool.Delete(a) + for k2, v2 := range star.nickName { + if v2 == k { + star.funcMu.Lock() + delete(star.nickName, k2) + star.funcMu.Unlock() + } + } + return true + }) + star.listener.Close() + star.Online = false + return + }() + go func() { + for { + conn, err := star.listener.Accept() + if err != nil { + select { + case <-star.stopSign.Done(): + star.listener.Close() + return + default: + continue + } + } + if !star.ReadDeadline.IsZero() { + conn.SetReadDeadline(star.ReadDeadline) + } + if !star.WriteDeadline.IsZero() { + conn.SetWriteDeadline(star.WriteDeadline) + } + if !star.Deadline.IsZero() { + conn.SetDeadline(star.Deadline) + } + go func(conn net.Conn) { + for { + buf := make([]byte, 8192) + n, err := conn.Read(buf) + if n != 0 { + star.Queue.ParseMessage(buf[0:n], conn) + } + if err != nil { + conn.Close() + star.connPool.Delete(fmt.Sprint(conn)) + for k, v := range star.nickName { + if v == fmt.Sprint(conn) { + delete(star.nickName, k) + } + } + break + } + } + }(conn) + star.connPool.Store(fmt.Sprint(conn), conn) + if star.Connected != nil { + go star.Connected(SMsg{Conn: conn, nickName: star.setNickName, getName: star.getName, queue: star.Queue}) + } + } + }() + star.Online = true + return &star, nil +} + +func (star *StarNotifyS) GetListenerInfo() net.Listener { + return star.listener +} + +// SetNotify 用于设置通知关键词的调用函数 +func (star *StarNotifyS) setNickName(name string, conn string) error { + if _, ok := star.connPool.Load(conn); !ok { + if _, ok := star.udpPool[conn]; !ok { + return errors.New("Conn Not Found") + } + } + for k, v := range star.nickName { + if v == conn { + delete(star.nickName, k) + } + } + star.funcMu.Lock() + star.nickName[name] = conn + star.funcMu.Unlock() + return nil +} + +// SetNotify 用于设置通知关键词的调用函数 +func (star *StarNotifyS) SetNotify(name string, data func(SMsg) string) { + star.funcMu.Lock() + defer star.funcMu.Unlock() + if data == nil { + if _, ok := star.FuncLists[name]; ok { + delete(star.FuncLists, name) + } + return + } + star.FuncLists[name] = data +} + +// SetDefaultNotify 用于设置默认关键词的调用函数 +func (star *StarNotifyS) SetDefaultNotify(data func(SMsg) string) { + star.defaultFunc = data +} + +func (star *StarNotifyS) trim(name string) string { + var slash bool = false + var key []byte + for _, v := range []byte(name) { + if v == byte(92) && !slash { + slash = true + continue + } + slash = false + key = append(key, v) + } + return string(key) +} + +func (star *StarNotifyS) notify() { + for { + select { + case <-star.stopSign.Done(): + return + default: + } + data, err := star.Queue.RestoreOne() + if err != nil { + time.Sleep(time.Millisecond * 500) + continue + } + mode, key, value := star.analyseData(string(data.Msg)) + if mode == key && mode == value && mode == "" { + continue + } + var rmsg SMsg + if !star.isUDP { + rmsg = SMsg{data.Conn.(net.Conn), key, value, nil, nil, mode, nil, star.setNickName, star.getName, star.Queue} + } else { + rmsg = SMsg{nil, key, value, data.Conn.(*net.UDPAddr), star.UDPConn, mode, nil, star.setNickName, star.getName, star.Queue} + if key == "b612ryzstop" { + star.connMu.Lock() + delete(star.udpPool, rmsg.UDP.String()) + star.connMu.Unlock() + for k, v := range star.nickName { + if v == rmsg.UDP.String() { + delete(star.nickName, k) + } + } + continue + } + } + replyFunc := func(key string, rmsg SMsg) { + if msg, ok := star.FuncLists[key]; ok { + sdata := msg(rmsg) + if sdata == "" { + return + } + rmsg.Reply(sdata) + } else { + if star.defaultFunc != nil { + sdata := star.defaultFunc(rmsg) + if sdata == "" { + return + } + rmsg.Reply(sdata) + } + } + } + if mode[0:2] != "sr" { + if !star.Sync { + go replyFunc(key, rmsg) + } else { + replyFunc(key, rmsg) + } + } else { + if sa, ok := star.lockPool[mode]; ok { + rmsg.wait = sa.wait + star.lockMu.Lock() + star.lockPool[mode] = rmsg + star.lockPool[mode].wait <- 1 + star.lockMu.Unlock() + } else { + if !star.Sync { + go replyFunc(key, rmsg) + } else { + replyFunc(key, rmsg) + } + } + } + } +} + +func (star *StarNotifyS) analyseData(msg string) (mode, key, value string) { + slice := strings.SplitN(msg, "||", 3) + if len(slice) < 3 { + return "", "", "" + } + return slice[0], star.trim(slice[1]), slice[2] +} + +// ServerStop 用于终止Server端运行 +func (star *StarNotifyS) ServerStop() { + star.cancel() +} diff --git a/v1/starnotify/define.go b/v1/starnotify/define.go new file mode 100644 index 0000000..f6f9b7b --- /dev/null +++ b/v1/starnotify/define.go @@ -0,0 +1,103 @@ +package starnotify + +import ( + "errors" + "time" + + "b612.me/notify/v1" +) + +var ( + starClient map[string]*notify.StarNotifyC + starServer map[string]*notify.StarNotifyS +) + +func init() { + starClient = make(map[string]*notify.StarNotifyC) + starServer = make(map[string]*notify.StarNotifyS) +} + +func NewClient(key, netype, value string) (*notify.StarNotifyC, error) { + client, err := notify.NewNotifyC(netype, value) + if err != nil { + return client, err + } + starClient[key] = client + return client, err +} + +func NewClientWithTimeout(key, netype, value string, timeout time.Duration) (*notify.StarNotifyC, error) { + client, err := notify.NewNotifyCWithTimeOut(netype, value, timeout) + if err != nil { + return client, err + } + starClient[key] = client + return client, err +} + +func DeleteClient(key string) error { + client, ok := starClient[key] + if !ok { + return errors.New("Not Exists Yet!") + } + if client.Online { + client.ClientStop() + } + client = nil + delete(starClient, key) + return nil +} + +func NewServer(key, netype, value string) (*notify.StarNotifyS, error) { + server, err := notify.NewNotifyS(netype, value) + if err != nil { + return server, err + } + starServer[key] = server + return server, err +} + +func DeleteServer(key string) error { + server, ok := starServer[key] + if !ok { + return errors.New("Not Exists Yet!") + } + if server.Online { + server.ServerStop() + } + server = nil + delete(starServer, key) + return nil +} + +func S(key string) *notify.StarNotifyS { + server, ok := starServer[key] + if !ok { + return nil + } + return server +} + +func C(key string) *notify.StarNotifyC { + client, ok := starClient[key] + if !ok { + return nil + } + return client +} + +func Server(key string) (*notify.StarNotifyS, error) { + server, ok := starServer[key] + if !ok { + return nil, errors.New("Not Exists Yet") + } + return server, nil +} + +func Client(key string) (*notify.StarNotifyC, error) { + client, ok := starClient[key] + if !ok { + return nil, errors.New("Not Exists Yet") + } + return client, nil +} diff --git a/v1/v2cs_test.go b/v1/v2cs_test.go new file mode 100644 index 0000000..812a242 --- /dev/null +++ b/v1/v2cs_test.go @@ -0,0 +1,51 @@ +package notify + +import ( + "fmt" + "sync/atomic" + "testing" + "time" +) + +func Test_ServerTuAndClientCommon(t *testing.T) { + server, err := NewNotifyS("tcp", "127.0.0.1:12345") + if err != nil { + panic(err) + } + server.SetNotify("notify", notify) + for i := 1; i <= 1; i++ { + go func() { + + client, err := NewNotifyC("tcp", "127.0.0.1:12345") + if err != nil { + time.Sleep(time.Second * 2) + panic(err) + } + for { + //nowd = time.Now().UnixNano() + client.SendValueWait("notify", "client hello", time.Second*50) + //time.Sleep(time.Millisecond) + //fmt.Println("finished:", float64(time.Now().UnixNano()-nowd)/1000000) + //client.Send("notify", []byte("client hello")) + } + }() + } + go func() { + time.Sleep(time.Second * 10) + server.ServerStop() + }() + <-server.Stoped() + //time.Sleep(time.Second * 5) + fmt.Println(count2) + +} + +var count2 int64 + +func notify(msg SMsg) string { + //fmt.Println(string(msg.Msg.Value)) + //fmt.Println("called:", float64(time.Now().UnixNano()-nowd)/1000000) + + go atomic.AddInt64(&count2, 1) + return "ok" +} diff --git a/v2cs_test.go b/v2cs_test.go new file mode 100644 index 0000000..1bb00c6 --- /dev/null +++ b/v2cs_test.go @@ -0,0 +1,146 @@ +package notify + +import ( + "fmt" + "net" + //_ "net/http/pprof" + "sync/atomic" + "testing" + "time" +) + +func Test_ServerTuAndClientCommon(t *testing.T) { + //go http.ListenAndServe("0.0.0.0:8888", nil) + noEn := func(key, bn []byte) []byte { + return bn + } + server := NewServer() + server.SetDefaultCommDecode(noEn) + server.SetDefaultCommEncode(noEn) + err := server.Listen("tcp", "127.0.0.1:12345") + if err != nil { + panic(err) + } + server.SetLink("notify", notify) + for i := 1; i <= 5000; i++ { + go func() { + client := NewClient() + client.SetMsgEn(noEn) + client.SetMsgDe(noEn) + client.SetSkipExchangeKey(true) + err = client.Connect("tcp", "127.0.0.1:12345") + if err != nil { + time.Sleep(time.Second * 2) + return + } + //client.SetLink("notify", notify) + for { + + //nowd = time.Now().UnixNano() + client.SendWait("notify", []byte("client hello"),time.Second*15) + //time.Sleep(time.Millisecond) + //fmt.Println("finished:", float64(time.Now().UnixNano()-nowd)/1000000) + //client.Send("notify", []byte("client")) + } + }() + } + time.Sleep(time.Second) + go func() { + time.Sleep(time.Second * 10) + server.Stop() + }() + <-server.StopMonitorChan() + fmt.Println(count2) +} + +var count2 int64 + +func notify(msg *Message) { + //fmt.Println(string(msg.Msg.Value)) + //fmt.Println("called:", float64(time.Now().UnixNano()-nowd)/1000000) + if msg.NetType == NET_SERVER { + atomic.AddInt64(&count2, 1) + msg.Reply([]byte("server reply")) + } +} + +func Test_normal(t *testing.T) { + server, _ := net.Listen("udp", "127.0.0.1:12345") + go func() { + for { + conn, err := server.Accept() + if err == nil { + go func() { + for { + buf := make([]byte, 1024) + _, err := conn.Read(buf) + //fmt.Println("S RECV", string(buf[:i])) + atomic.AddInt64(&count2, 1) + if err == nil { + conn.Write([]byte("hello world server")) + } + } + }() + } + } + }() + time.Sleep(time.Second * 5) + for i := 1; i <= 100; i++ { + go func() { + conn, err := net.Dial("udp", "127.0.0.1:12345") + if err != nil { + panic(err) + } + for { + //nowd = time.Now().UnixNano() + _, err := conn.Write([]byte("hello world client")) + if err != nil { + fmt.Println(err) + } + buf := make([]byte, 1024) + conn.Read(buf) + + continue + } + }() + } + time.Sleep(time.Second * 10) + fmt.Println(count2) +} + +func Test_normal_udp(t *testing.T) { + ludp, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12345") + conn, _ := net.ListenUDP("udp", ludp) + go func() { + for { + buf := make([]byte, 1024) + _, addr, err := conn.ReadFromUDP(buf) + fmt.Println(time.Now(), "S RECV", addr.String()) + atomic.AddInt64(&count2, 1) + if err == nil { + conn.WriteToUDP([]byte("hello world server"), addr) + } + } + }() + for i := 1; i <= 100; i++ { + go func() { + conn, err := net.Dial("udp", "127.0.0.1:12345") + if err != nil { + panic(err) + } + for { + //nowd = time.Now().UnixNano() + _, err := conn.Write([]byte("hello world client")) + if err != nil { + fmt.Println(err) + } + buf := make([]byte, 1024) + conn.Read(buf) + fmt.Println(time.Now(), "C RECV") + continue + } + }() + } + time.Sleep(time.Second * 10) + fmt.Println(count2) +}