package notify

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"os"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"b612.me/stario"
	"b612.me/starnet"
)

// ServerCommon is the Server implementation returned by NewServer. It accepts
// clients over a stream listener (ListenTU) or a UDP socket (ListenUDP),
// decodes their messages and dispatches them to the registered handlers.
type ServerCommon struct {
	// msgID is the last ID assigned to an outgoing message by sendTU; it is
	// advanced with atomic.AddUint64.
	msgID uint64
	// alive stores a bool: true once a listener is up, false after Stop.
	alive atomic.Value
	// status is the last published server status; guarded by mu.
	status Status

	// listener is set by ListenTU, udpListener by ListenUDP. A non-nil
	// udpListener selects the UDP code paths in accept and send.
	listener    net.Listener
	udpListener *net.UDPConn

	// queue frames outgoing messages and hands back received ones.
	queue *starnet.StarQueue

	// stopFn cancels stopCtx; both are recreated on every Listen call.
	stopFn  context.CancelFunc
	stopCtx context.Context

	// Timeouts copied into every new ClientConn; a zero value means no
	// deadline is applied by the code in this file.
	maxReadTimeout  time.Duration
	maxWriteTimeout time.Duration

	parallelNum int
	// wg tracks in-flight message handler goroutines started by loadMessage.
	wg stario.WaitGroup

	// clientPool maps ClientID to its connection; guarded by mu.
	clientPool map[string]*ClientConn
	mu         sync.RWMutex

	// Key material and codecs copied into every new ClientConn.
	handshakeRsaKey []byte
	SecretKey       []byte
	defaultMsgEn    func([]byte, []byte) []byte
	defaultMsgDe    func([]byte, []byte) []byte

	// linkFns maps a message key to its handler; guarded by mu.
	linkFns map[string]func(message *Message)
	// defaultFns is invoked for every dispatched user message, whether or
	// not a keyed handler also matched.
	defaultFns func(message *Message)

	// noFinSyncMsgPool maps a message ID to the WaitMsg of a request that is
	// still awaiting its reply. Whoever removes an entry with LoadAndDelete
	// owns the Reply channel (and is the only one allowed to send on or
	// close it).
	noFinSyncMsgPool sync.Map
	// noFinSyncMsgMaxKeepSeconds is how long monitorPool keeps an unanswered
	// request; zero (or negative) disables expiry.
	noFinSyncMsgMaxKeepSeconds int64
	// maxHeartbeatLostSeconds is how long a client may stay silent before
	// monitorPool drops it; zero disables the check.
	maxHeartbeatLostSeconds int64

	// sequenceDe / sequenceEn convert between wire bytes and values.
	sequenceDe func([]byte) (interface{}, error)
	sequenceEn func(interface{}) ([]byte, error)

	// showError enables printing of accept/decode errors to stdout.
	showError bool
}

// NewServer returns a Server with the package default keys and codecs, a
// 300 second heartbeat timeout and a no-op default handler. It is not
// listening yet; call Listen, ListenTU or ListenUDP.
func NewServer() Server {
	var server ServerCommon
	server.wg = stario.NewWaitGroup(0)
	server.parallelNum = 0
	server.noFinSyncMsgMaxKeepSeconds = 0
	server.maxHeartbeatLostSeconds = 300
	server.stopCtx, server.stopFn = context.WithCancel(context.Background())
	server.SecretKey = defaultAesKey
	server.handshakeRsaKey = defaultRsaKey
	server.clientPool = make(map[string]*ClientConn)
	server.defaultMsgEn = defaultMsgEn
	server.defaultMsgDe = defaultMsgDe
	server.sequenceEn = encode
	server.sequenceDe = Decode
	server.alive.Store(false)
	server.linkFns = make(map[string]func(*Message))
	server.defaultFns = func(message *Message) {}
	return &server
}

// ShowError enables or disables printing of accept and decode errors.
func (s *ServerCommon) ShowError(std bool) {
	s.mu.Lock()
	s.showError = std
	s.mu.Unlock()
}

// Stop marks the server as not alive and cancels its stop context, which
// makes the background goroutines shut down. It is a no-op when the server is
// not running and always returns nil.
func (s *ServerCommon) Stop() error {
	if !s.alive.Load().(bool) {
		return nil
	}
	s.alive.Store(false)
	s.mu.Lock()
	s.status = Status{
		Alive:  false,
		Reason: "recv stop signal from user",
		Err:    nil,
	}
	s.mu.Unlock()
	s.stopFn()
	return nil
}

// Listen starts the server on addr. Any network name containing "udp"
// (case-insensitive) is served by ListenUDP, everything else by ListenTU.
// It fails if the server is already running.
func (s *ServerCommon) Listen(network string, addr string) error {
	if s.alive.Load().(bool) {
		return errors.New("server already run")
	}
	s.stopCtx, s.stopFn = context.WithCancel(context.Background())
	s.queue = starnet.NewQueueCtx(s.stopCtx, 128)
	if strings.Contains(strings.ToLower(network), "udp") {
		return s.ListenUDP(network, addr)
	}
	return s.ListenTU(network, addr)
}

// ListenTU opens a stream listener via net.Listen and starts the accept,
// monitor and message-loading goroutines.
func (s *ServerCommon) ListenTU(network string, addr string) error {
	listener, err := net.Listen(network, addr)
	if err != nil {
		return err
	}
	s.alive.Store(true)
	s.mu.Lock()
	s.status.Alive = true
	s.mu.Unlock()
	s.listener = listener
	// accept and send choose the UDP path whenever udpListener is non-nil,
	// so a socket left over from an earlier UDP run must not linger.
	s.udpListener = nil
	go s.accept()
	go s.monitorPool()
	go s.loadMessage()
	return nil
}

// monitorPool runs until the server stops. Every 30 seconds it expires
// unanswered sync requests and drops clients whose heartbeat is overdue; on
// stop it closes the Reply channel of every request still pending.
func (s *ServerCommon) monitorPool() {
	for {
		select {
		case <-s.stopCtx.Done():
			s.noFinSyncMsgPool.Range(func(k, _ interface{}) bool {
				// Only the goroutine that removes the entry may close the
				// channel; this prevents double close and send-on-closed.
				if w, loaded := s.noFinSyncMsgPool.LoadAndDelete(k); loaded {
					close(w.(WaitMsg).Reply)
				}
				return true
			})
			return
		case <-time.After(time.Second * 30):
		}
		now := time.Now()
		if s.noFinSyncMsgMaxKeepSeconds > 0 {
			keep := time.Duration(s.noFinSyncMsgMaxKeepSeconds) * time.Second
			s.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
				if !v.(WaitMsg).Time.Add(keep).Before(now) {
					return true
				}
				if w, loaded := s.noFinSyncMsgPool.LoadAndDelete(k); loaded {
					close(w.(WaitMsg).Reply)
				}
				return true
			})
		}
		if s.maxHeartbeatLostSeconds != 0 {
			// Snapshot under the read lock: removeClient takes the write
			// lock and mutates the map, so it must not run inside the range.
			var lost []*ClientConn
			s.mu.RLock()
			for _, c := range s.clientPool {
				if now.Unix()-c.lastHeartBeat > s.maxHeartbeatLostSeconds {
					lost = append(lost, c)
				}
			}
			s.mu.RUnlock()
			for _, c := range lost {
				c.stopFn()
				s.removeClient(c)
			}
		}
	}
}

// SetDefaultCommEncode sets the message encoder given to new clients.
func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) {
	s.defaultMsgEn = fn
}

// SetDefaultCommDecode sets the message decoder given to new clients.
func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) {
	s.defaultMsgDe = fn
}

// SetDefaultLink sets the handler that is called for every user message.
func (s *ServerCommon) SetDefaultLink(fn func(message *Message)) {
	s.defaultFns = fn
}

// SetLink registers fn as the handler for messages whose key equals key,
// replacing any previous handler for that key.
func (s *ServerCommon) SetLink(key string, fn func(*Message)) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.linkFns[key] = fn
}

// pushMessage feeds raw bytes received from source into the message queue.
func (s *ServerCommon) pushMessage(data []byte, source string) {
	s.queue.ParseMessage(data, source)
}

// removeClient deletes client from the pool. It must not be called while the
// caller holds s.mu.
func (s *ServerCommon) removeClient(client *ClientConn) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.clientPool, client.ClientID)
}

// accept runs the accept loop that matches the active listener.
func (s *ServerCommon) accept() {
	if s.udpListener != nil {
		s.acceptUDP()
		return
	}
	s.acceptTU()
}

// acceptTU accepts stream connections until the server stops, registering a
// ClientConn for each one and starting its read goroutine.
func (s *ServerCommon) acceptTU() {
	for {
		select {
		case <-s.stopCtx.Done():
			return
		default:
		}
		conn, err := s.listener.Accept()
		if err != nil {
			if s.showError {
				fmt.Println("error accept:", err)
			}
			continue
		}
		// Generate IDs until one is not already present in the pool.
		var id string
		for {
			id = fmt.Sprintf("%s%d%d", conn.RemoteAddr().String(), time.Now().UnixNano(), rand.Int63())
			s.mu.RLock()
			_, taken := s.clientPool[id]
			s.mu.RUnlock()
			if !taken {
				break
			}
		}
		client := ClientConn{
			ClientID:        id,
			ClientAddr:      conn.RemoteAddr(),
			tuConn:          conn,
			server:          s,
			maxReadTimeout:  s.maxReadTimeout,
			maxWriteTimeout: s.maxWriteTimeout,
			SecretKey:       s.SecretKey,
			handshakeRsaKey: s.handshakeRsaKey,
			msgEn:           s.defaultMsgEn,
			msgDe:           s.defaultMsgDe,
			lastHeartBeat:   time.Now().Unix(),
		}
		client.alive.Store(true)
		client.status = Status{
			Alive:  true,
			Reason: "",
			Err:    nil,
		}
		client.stopCtx, client.stopFn = context.WithCancel(context.Background())
		s.mu.Lock()
		s.clientPool[id] = &client
		s.mu.Unlock()
		go client.readTUMessage()
	}
}

// loadMessage pulls messages from the queue and handles each one in its own
// goroutine. When the server stops it says goodbye to every client (waiting
// at most 8 seconds), closes the listeners, waits for in-flight handlers and
// returns.
func (s *ServerCommon) loadMessage() {
	for {
		select {
		case <-s.stopCtx.Done():
			var wg sync.WaitGroup
			s.mu.RLock()
			for _, v := range s.clientPool {
				wg.Add(1)
				// Pass the client as an argument: before Go 1.22 the loop
				// variable is shared by all iterations.
				go func(c *ClientConn) {
					defer wg.Done()
					c.sayGoodByeForTU()
					c.alive.Store(false)
					c.status = Status{
						Alive:  false,
						Reason: "recv stop signal from server",
						Err:    nil,
					}
					c.stopFn()
					s.removeClient(c)
				}(v)
			}
			s.mu.RUnlock()
			select {
			case <-time.After(time.Second * 8):
			case <-stario.WaitUntilFinished(func() error {
				wg.Wait()
				return nil
			}):
			}
			if s.listener != nil {
				s.listener.Close()
			}
			// Closing the UDP socket unblocks acceptUDP, which may be parked
			// in ReadFromUDP with no read deadline.
			if s.udpListener != nil {
				s.udpListener.Close()
			}
			s.wg.Wait()
			return
		case data, ok := <-s.queue.RestoreChan():
			if !ok {
				continue
			}
			s.wg.Add(1)
			go func(data starnet.MsgQueue) {
				// dispatchMsg calls s.wg.Done; every path that returns
				// before reaching it must do so itself or shutdown hangs.
				s.mu.RLock()
				cc, ok := s.clientPool[data.Conn.(string)]
				s.mu.RUnlock()
				if !ok {
					s.wg.Done()
					return
				}
				msg, err := s.sequenceDe(cc.msgDe(cc.SecretKey, data.Msg))
				if err != nil {
					if s.showError {
						fmt.Println("server decode data error", err)
					}
					s.wg.Done()
					return
				}
				transfer, ok := msg.(TransferMsg)
				if !ok {
					if s.showError {
						fmt.Println("server decode data error", fmt.Errorf("unexpected message type %T", msg))
					}
					s.wg.Done()
					return
				}
				message := Message{
					NetType:     NET_SERVER,
					ClientConn:  cc,
					TransferMsg: transfer,
				}
				message.Time = time.Now()
				s.dispatchMsg(message)
			}(data)
		}
	}
}

// sysMsg handles system messages: "bye" marks the client as stopped (replying
// first when the sender waits for an answer) and "heartbeat" refreshes the
// client's last-seen time and replies.
func (s *ServerCommon) sysMsg(message Message) {
	switch message.Key {
	case "bye":
		if message.TransferMsg.Type == MSG_SYS_WAIT {
			message.Reply(nil)
		}
		message.ClientConn.alive.Store(false)
		message.ClientConn.status = Status{
			Alive:  false,
			Reason: "recv stop signal from client",
			Err:    nil,
		}
		message.ClientConn.stopFn()
	case "heartbeat":
		message.ClientConn.lastHeartBeat = time.Now().Unix()
		message.Reply(nil)
	}
}

// dispatchMsg routes one decoded message: system messages go to sysMsg, key
// exchange to the client, replies to the goroutine waiting for them (replies
// nobody waits for are dropped), and everything else to the keyed handler (if
// any) followed by the default handler. It calls s.wg.Done when finished.
func (s *ServerCommon) dispatchMsg(message Message) {
	defer s.wg.Done()
	switch message.TransferMsg.Type {
	case MSG_SYS_WAIT, MSG_SYS:
		s.sysMsg(message)
		return
	case MSG_KEY_CHANGE:
		message.ClientConn.rsaDecode(message)
		return
	case MSG_SYS_REPLY, MSG_SYNC_REPLY:
		// LoadAndDelete makes this goroutine the sole owner of the channel,
		// so the buffered send can neither block nor hit a closed channel.
		if data, ok := s.noFinSyncMsgPool.LoadAndDelete(message.TransferMsg.ID); ok {
			data.(WaitMsg).Reply <- message
		}
		return
	}
	s.mu.RLock()
	fn, ok := s.linkFns[message.TransferMsg.Key]
	s.mu.RUnlock()
	if ok {
		fn(&message)
	}
	if s.defaultFns != nil {
		s.defaultFns(&message)
	}
}

// send writes msg to c over whichever transport the server is using.
func (s *ServerCommon) send(c *ClientConn, msg TransferMsg) (WaitMsg, error) {
	if s.udpListener != nil {
		return s.sendUDP(c, msg)
	}
	return s.sendTU(c, msg)
}

// sendTU encodes msg and writes it to the client's stream connection. For
// MSG_SYNC_ASK and MSG_SYS_WAIT it returns a WaitMsg whose Reply channel
// receives the answer; for other types the returned WaitMsg is zero.
func (s *ServerCommon) sendTU(c *ClientConn, msg TransferMsg) (WaitMsg, error) {
	var wait WaitMsg
	// Replies and key-change messages keep the ID they already carry; any
	// other message (or one without an ID) gets a fresh one.
	if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
		msg.ID = atomic.AddUint64(&s.msgID, 1)
	}
	data, err := s.sequenceEn(msg)
	if err != nil {
		return WaitMsg{}, err
	}
	data = c.msgEn(c.SecretKey, data)
	data = s.queue.BuildMessage(data)
	// Register the waiter before writing so a reply that arrives
	// immediately still finds it.
	needReply := msg.Type == MSG_SYNC_ASK || msg.Type == MSG_SYS_WAIT
	if needReply {
		wait.Time = time.Now()
		wait.TransferMsg = msg
		wait.Reply = make(chan Message, 1)
		s.noFinSyncMsgPool.Store(msg.ID, wait)
	}
	if c.maxWriteTimeout.Seconds() != 0 {
		c.tuConn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
	}
	if _, err = c.tuConn.Write(data); err != nil {
		if needReply {
			s.noFinSyncMsgPool.Delete(msg.ID)
		}
		return WaitMsg{}, err
	}
	return wait, nil
}

// Send sends value under key to c without waiting for a reply.
func (s *ServerCommon) Send(c *ClientConn, key string, value MsgVal) error {
	_, err := s.send(c, TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_ASYNC,
	})
	return err
}

// sendWait sends msg and waits for its reply. A zero timeout waits until the
// reply arrives or the pending request is closed. It returns
// os.ErrDeadlineExceeded on timeout, os.ErrInvalid when the request was
// closed without a reply, and a "Service shutdown" error when the server
// stops first.
func (s *ServerCommon) sendWait(c *ClientConn, msg TransferMsg, timeout time.Duration) (Message, error) {
	data, err := s.send(c, msg)
	if err != nil {
		return Message{}, err
	}
	if timeout.Seconds() == 0 {
		reply, ok := <-data.Reply
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-timer.C:
		// Close only if this goroutine removed the entry; otherwise another
		// goroutine owns the channel and may be sending on or closing it.
		if _, loaded := s.noFinSyncMsgPool.LoadAndDelete(data.TransferMsg.ID); loaded {
			close(data.Reply)
		}
		return Message{}, os.ErrDeadlineExceeded
	case <-s.stopCtx.Done():
		return Message{}, errors.New("Service shutdown")
	case reply, ok := <-data.Reply:
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}
}

// SendCtx sends value under key to c and waits for the reply until ctx is
// done.
func (s *ServerCommon) SendCtx(ctx context.Context, c *ClientConn, key string, value MsgVal) (Message, error) {
	return s.sendCtx(c, TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_SYNC_ASK,
	}, ctx)
}

// sendCtx sends msg and waits for its reply until ctx is done (a nil ctx
// waits indefinitely). It returns os.ErrClosed when ctx ends first,
// os.ErrInvalid when the request was closed without a reply, and a
// "Service shutdown" error when the server stops first.
func (s *ServerCommon) sendCtx(c *ClientConn, msg TransferMsg, ctx context.Context) (Message, error) {
	data, err := s.send(c, msg)
	if err != nil {
		return Message{}, err
	}
	if ctx == nil {
		ctx = context.Background()
	}
	select {
	case <-ctx.Done():
		// Same ownership rule as in sendWait.
		if _, loaded := s.noFinSyncMsgPool.LoadAndDelete(data.TransferMsg.ID); loaded {
			close(data.Reply)
		}
		return Message{}, os.ErrClosed
	case <-s.stopCtx.Done():
		return Message{}, errors.New("Service shutdown")
	case reply, ok := <-data.Reply:
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}
}

// SendWait sends value under key to c and waits up to timeout for the reply;
// a zero timeout waits without a deadline.
func (s *ServerCommon) SendWait(c *ClientConn, key string, value MsgVal, timeout time.Duration) (Message, error) {
	return s.sendWait(c, TransferMsg{
		Key:   key,
		Value: value,
		Type:  MSG_SYNC_ASK,
	}, timeout)
}

// SendWaitObj serializes value with the configured sequence encoder and then
// behaves like SendWait.
func (s *ServerCommon) SendWaitObj(c *ClientConn, key string, value interface{}, timeout time.Duration) (Message, error) {
	data, err := s.sequenceEn(value)
	if err != nil {
		return Message{}, err
	}
	return s.SendWait(c, key, data, timeout)
}

// SendObjCtx serializes val with the configured sequence encoder and then
// behaves like SendCtx.
func (s *ServerCommon) SendObjCtx(ctx context.Context, c *ClientConn, key string, val interface{}) (Message, error) {
	data, err := s.sequenceEn(val)
	if err != nil {
		return Message{}, err
	}
	return s.sendCtx(c, TransferMsg{
		Key:   key,
		Value: data,
		Type:  MSG_SYNC_ASK,
	}, ctx)
}

// SendObj serializes val with the configured sequence encoder and sends it
// under key to c without waiting for a reply.
func (s *ServerCommon) SendObj(c *ClientConn, key string, val interface{}) error {
	// Use the configurable encoder, as SendWaitObj and SendObjCtx do, so a
	// custom SetSequenceEn is honoured here as well.
	data, err := s.sequenceEn(val)
	if err != nil {
		return err
	}
	_, err = s.send(c, TransferMsg{
		Key:   key,
		Value: data,
		Type:  MSG_ASYNC,
	})
	return err
}

// Reply answers message m with value.
func (s *ServerCommon) Reply(m Message, value MsgVal) error {
	return m.Reply(value)
}

// UDP transport below.

// ListenUDP opens a UDP socket on addr and starts the accept, monitor and
// message-loading goroutines.
func (s *ServerCommon) ListenUDP(network string, addr string) error {
	udpAddr, err := net.ResolveUDPAddr(network, addr)
	if err != nil {
		return err
	}
	listener, err := net.ListenUDP(network, udpAddr)
	if err != nil {
		return err
	}
	s.alive.Store(true)
	s.mu.Lock()
	s.status.Alive = true
	s.mu.Unlock()
	s.udpListener = listener
	// Drop any stream listener reference left over from an earlier run.
	s.listener = nil
	go s.accept()
	go s.monitorPool()
	go s.loadMessage()
	return nil
}

// acceptUDP reads datagrams until the server stops. The sender's address is
// used as the client ID; unknown senders are registered as new clients and
// every payload is pushed into the message queue.
func (s *ServerCommon) acceptUDP() {
	for {
		select {
		case <-s.stopCtx.Done():
			return
		default:
		}
		if s.maxReadTimeout.Seconds() > 0 {
			s.udpListener.SetReadDeadline(time.Now().Add(s.maxReadTimeout))
		}
		data := make([]byte, 4096)
		num, addr, err := s.udpListener.ReadFromUDP(data)
		if addr == nil {
			// A failed read (timeout, closed socket) has no sender; do not
			// register a bogus "<nil>" client for it.
			continue
		}
		id := addr.String()
		s.mu.RLock()
		_, known := s.clientPool[id]
		s.mu.RUnlock()
		if !known {
			client := ClientConn{
				ClientID:        id,
				ClientAddr:      addr,
				server:          s,
				maxReadTimeout:  s.maxReadTimeout,
				maxWriteTimeout: s.maxWriteTimeout,
				SecretKey:       s.SecretKey,
				handshakeRsaKey: s.handshakeRsaKey,
				msgEn:           s.defaultMsgEn,
				msgDe:           s.defaultMsgDe,
				lastHeartBeat:   time.Now().Unix(),
			}
			client.stopCtx, client.stopFn = context.WithCancel(context.Background())
			s.mu.Lock()
			s.clientPool[id] = &client
			s.mu.Unlock()
		}
		if err != nil {
			// A deadline error may still come with data; keep it.
			if errors.Is(err, os.ErrDeadlineExceeded) && num != 0 {
				s.pushMessage(data[:num], id)
			}
			continue
		}
		s.pushMessage(data[:num], id)
	}
}

// sendUDP encodes msg and writes it to the client's address through the
// shared UDP socket. For MSG_SYNC_ASK and MSG_SYS_WAIT it returns a WaitMsg
// whose Reply channel receives the answer; for other types the returned
// WaitMsg is zero.
func (s *ServerCommon) sendUDP(c *ClientConn, msg TransferMsg) (WaitMsg, error) {
	var wait WaitMsg
	// Replies and key-change messages keep the ID they already carry; any
	// other message (or one without an ID) gets a fresh random-based one.
	if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
		msg.ID = uint64(time.Now().UnixNano()) + rand.Uint64() + rand.Uint64()
	}
	data, err := s.sequenceEn(msg)
	if err != nil {
		return WaitMsg{}, err
	}
	data = c.msgEn(c.SecretKey, data)
	data = s.queue.BuildMessage(data)
	// Register the waiter before writing so a reply that arrives
	// immediately still finds it.
	needReply := msg.Type == MSG_SYNC_ASK || msg.Type == MSG_SYS_WAIT
	if needReply {
		wait.Time = time.Now()
		wait.TransferMsg = msg
		wait.Reply = make(chan Message, 1)
		s.noFinSyncMsgPool.Store(msg.ID, wait)
	}
	if c.maxWriteTimeout.Seconds() != 0 {
		s.udpListener.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
	}
	if _, err = s.udpListener.WriteTo(data, c.ClientAddr); err != nil {
		if needReply {
			s.noFinSyncMsgPool.Delete(msg.ID)
		}
		return WaitMsg{}, err
	}
	return wait, nil
}

// StopMonitorChan returns a channel that is closed when the server stops.
func (s *ServerCommon) StopMonitorChan() <-chan struct{} {
	return s.stopCtx.Done()
}

// Status returns the current server status.
func (s *ServerCommon) Status() Status {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.status
}

// GetSecretKey returns the secret key given to new clients.
func (s *ServerCommon) GetSecretKey() []byte {
	return s.SecretKey
}

// SetSecretKey sets the secret key given to new clients.
func (s *ServerCommon) SetSecretKey(key []byte) {
	s.SecretKey = key
}

// RsaPrivKey returns the handshake RSA key given to new clients.
func (s *ServerCommon) RsaPrivKey() []byte {
	return s.handshakeRsaKey
}

// SetRsaPrivKey sets the handshake RSA key given to new clients.
func (s *ServerCommon) SetRsaPrivKey(key []byte) {
	s.handshakeRsaKey = key
}

// GetClient returns the client registered under id, or nil if there is none.
func (s *ServerCommon) GetClient(id string) *ClientConn {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.clientPool[id]
}

// GetClientLists returns a snapshot of all registered clients.
func (s *ServerCommon) GetClientLists() []*ClientConn {
	s.mu.RLock()
	defer s.mu.RUnlock()
	list := make([]*ClientConn, 0, len(s.clientPool))
	for _, v := range s.clientPool {
		list = append(list, v)
	}
	return list
}

// GetClientAddrs returns a snapshot of the addresses of all registered
// clients.
func (s *ServerCommon) GetClientAddrs() []net.Addr {
	s.mu.RLock()
	defer s.mu.RUnlock()
	list := make([]net.Addr, 0, len(s.clientPool))
	for _, v := range s.clientPool {
		list = append(list, v.ClientAddr)
	}
	return list
}

// GetSequenceEn returns the function used to serialize values.
func (s *ServerCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
	return s.sequenceEn
}

// SetSequenceEn sets the function used to serialize values.
func (s *ServerCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
	s.sequenceEn = fn
}

// GetSequenceDe returns the function used to deserialize values.
func (s *ServerCommon) GetSequenceDe() func([]byte) (interface{}, error) {
	return s.sequenceDe
}

// SetSequenceDe sets the function used to deserialize values.
func (s *ServerCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
	s.sequenceDe = fn
}

// HeartbeatTimeoutSec returns the heartbeat timeout in seconds.
func (s *ServerCommon) HeartbeatTimeoutSec() int64 {
	return s.maxHeartbeatLostSeconds
}

// SetHeartbeatTimeoutSec sets the heartbeat timeout in seconds; zero disables
// the heartbeat check.
func (s *ServerCommon) SetHeartbeatTimeoutSec(sec int64) {
	s.maxHeartbeatLostSeconds = sec
}