You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

666 lines
16 KiB
Go

2 years ago
package notify
import (
"b612.me/stario"
"b612.me/starnet"
"context"
"errors"
"fmt"
"math/rand"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
// ServerCommon is the server side of the notify protocol. It accepts clients
// over a stream listener (ListenTU) or a UDP socket (ListenUDP), decodes their
// framed messages and dispatches them to registered handlers.
type ServerCommon struct {
	// msgID is the counter used by sendTU to assign outgoing message IDs
	// via atomic.AddUint64. NOTE: keep it as the FIRST field; 64-bit atomic
	// operations require 64-bit alignment on 32-bit platforms, which Go only
	// guarantees for the first word of an allocated struct.
	msgID uint64
	// alive holds a bool: true while the server is listening.
	alive atomic.Value
	// status is the value reported by Status(); Stop writes it under mu.
	status Status
	// listener is the stream listener created by ListenTU.
	listener net.Listener
	// udpListener is the socket created by ListenUDP; when non-nil, the
	// accept/send paths use the UDP variants.
	udpListener *net.UDPConn
	// queue frames outgoing data (BuildMessage) and parses incoming data
	// (ParseMessage); it is recreated by Listen for each run.
	queue *starnet.StarQueue
	// stopFn cancels stopCtx; stopCtx signals every server goroutine to exit.
	stopFn  context.CancelFunc
	stopCtx context.Context
	// maxReadTimeout / maxWriteTimeout are copied into each new client;
	// maxReadTimeout is also used as the UDP read deadline when > 0.
	maxReadTimeout  time.Duration
	maxWriteTimeout time.Duration
	// parallelNum is set to 0 by NewServer and not otherwise read in this file.
	parallelNum int
	// wg counts in-flight message-dispatch goroutines (Add in loadMessage,
	// Done in dispatchMsg); loadMessage waits on it during shutdown.
	wg stario.WaitGroup
	// clientPool maps ClientID to the connected client; guarded by mu.
	clientPool map[string]*ClientConn
	// mu guards clientPool and linkFns, plus status in Stop and showError in
	// ShowError.
	mu sync.RWMutex
	// handshakeRsaKey is copied into each new client and returned by
	// RsaPrivKey; defaults to defaultRsaKey.
	handshakeRsaKey []byte
	// SecretKey is the default symmetric key copied into each new client;
	// defaults to defaultAesKey.
	SecretKey []byte
	// defaultMsgEn / defaultMsgDe are the (key, data) transforms copied into
	// each new client as its msgEn / msgDe.
	defaultMsgEn func([]byte, []byte) []byte
	defaultMsgDe func([]byte, []byte) []byte
	// linkFns maps a message Key to its handler; guarded by mu.
	linkFns map[string]func(message *Message)
	// defaultFns is invoked for every non-system message after any
	// key-specific handler.
	defaultFns func(message *Message)
	// noFinSyncMsgPool maps an outgoing message ID (uint64) to the WaitMsg
	// awaiting its reply.
	noFinSyncMsgPool sync.Map
	// noFinSyncMsgMaxKeepSeconds, when > 0, makes monitorPool drop pending
	// requests older than this many seconds.
	noFinSyncMsgMaxKeepSeconds int64
	// maxHeartbeatLostSeconds, when non-zero, makes monitorPool disconnect
	// clients whose last heartbeat is older than this many seconds.
	maxHeartbeatLostSeconds int64
	// sequenceDe / sequenceEn (de)serialize TransferMsg values and objects.
	sequenceDe func([]byte) (interface{}, error)
	sequenceEn func(interface{}) ([]byte, error)
	// showError enables printing of internal errors to stdout.
	showError bool
}
// NewServer returns a stopped server configured with the package defaults:
// the default AES/RSA keys, the default message codec and serializer, a
// 300-second heartbeat timeout and a no-op default handler. Call Listen to
// start it.
func NewServer() Server {
	srv := &ServerCommon{
		wg:                         stario.NewWaitGroup(0),
		parallelNum:                0,
		noFinSyncMsgMaxKeepSeconds: 0,
		maxHeartbeatLostSeconds:    300,
		SecretKey:                  defaultAesKey,
		handshakeRsaKey:            defaultRsaKey,
		clientPool:                 make(map[string]*ClientConn),
		defaultMsgEn:               defaultMsgEn,
		defaultMsgDe:               defaultMsgDe,
		sequenceEn:                 encode,
		sequenceDe:                 Decode,
		linkFns:                    make(map[string]func(*Message)),
		defaultFns:                 func(*Message) {},
	}
	srv.stopCtx, srv.stopFn = context.WithCancel(context.Background())
	srv.alive.Store(false)
	return srv
}
// ShowError enables (true) or disables (false) printing of internal errors
// to stdout.
func (s *ServerCommon) ShowError(std bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.showError = std
}
// Stop marks the server as no longer alive, records the user-requested stop
// in its status and cancels stopCtx so the server goroutines shut down.
// Calling Stop on a server that is not running is a no-op. It always
// returns nil.
func (s *ServerCommon) Stop() error {
	if running := s.alive.Load().(bool); running {
		s.alive.Store(false)

		s.mu.Lock()
		s.status = Status{Alive: false, Reason: "recv stop signal from user", Err: nil}
		s.mu.Unlock()

		s.stopFn()
	}
	return nil
}
// Listen starts the server on addr. A network name containing "udp"
// (case-insensitive) selects the UDP transport; anything else is passed to
// net.Listen. Each call gets a fresh stop context and framing queue. It
// fails if the server is already running.
func (s *ServerCommon) Listen(network string, addr string) error {
	if s.alive.Load().(bool) {
		return errors.New("server already run")
	}

	s.stopCtx, s.stopFn = context.WithCancel(context.Background())
	s.queue = starnet.NewQueueCtx(s.stopCtx, 128)

	isUDP := strings.Contains(strings.ToLower(network), "udp")
	if !isUDP {
		return s.ListenTU(network, addr)
	}
	return s.ListenUDP(network, addr)
}
// ListenTU starts a stream listener (e.g. "tcp" or "unix") on addr and
// launches the accept, housekeeping and dispatch goroutines.
//
// Listen normally calls this after resetting stopCtx and the framing queue;
// calling it directly assumes those are already initialised for this run.
func (s *ServerCommon) ListenTU(network string, addr string) error {
	listener, err := net.Listen(network, addr)
	if err != nil {
		return err
	}
	s.alive.Store(true)
	// Replace the whole status under the lock: Stop writes it under mu, and
	// a restart must not keep the previous run's stop Reason/Err.
	s.mu.Lock()
	s.status = Status{Alive: true, Reason: "", Err: nil}
	s.mu.Unlock()
	s.listener = listener
	go s.accept()
	go s.monitorPool()
	go s.loadMessage()
	return nil
}
// monitorPool is the server's housekeeping goroutine. Every 30 seconds it
// expires pending synchronous requests older than noFinSyncMsgMaxKeepSeconds
// (when that is > 0) and disconnects clients whose last heartbeat is older
// than maxHeartbeatLostSeconds (when that is non-zero). When stopCtx is
// cancelled it closes the reply channel of every still-pending request, so
// blocked waiters wake up, and then returns.
func (s *ServerCommon) monitorPool() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopCtx.Done():
			s.noFinSyncMsgPool.Range(func(k, _ interface{}) bool {
				// Only the party that removes an entry may close its channel;
				// this prevents a double close racing with a timed-out waiter.
				if old, loaded := s.noFinSyncMsgPool.LoadAndDelete(k); loaded {
					close(old.(WaitMsg).Reply)
				}
				return true
			})
			return
		case <-ticker.C:
		}

		now := time.Now()
		if keep := s.noFinSyncMsgMaxKeepSeconds; keep > 0 {
			maxAge := time.Duration(keep) * time.Second
			s.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
				if v.(WaitMsg).Time.Add(maxAge).Before(now) {
					if old, loaded := s.noFinSyncMsgPool.LoadAndDelete(k); loaded {
						close(old.(WaitMsg).Reply)
					}
				}
				return true
			})
		}

		if limit := s.maxHeartbeatLostSeconds; limit != 0 {
			// Snapshot under the read lock: ranging over clientPool unlocked
			// races with the accept loops inserting new clients. Removal
			// happens afterwards because removeClient takes the write lock.
			var stale []*ClientConn
			s.mu.RLock()
			for _, c := range s.clientPool {
				if now.Unix()-c.lastHeartBeat > limit {
					stale = append(stale, c)
				}
			}
			s.mu.RUnlock()
			for _, c := range stale {
				c.stopFn()
				s.removeClient(c)
			}
		}
	}
}
// SetDefaultCommEncode sets the (key, data) encoder copied into each newly
// accepted client as its msgEn; clients that are already connected keep
// their current encoder.
func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) { s.defaultMsgEn = fn }
// SetDefaultCommDecode sets the (key, data) decoder copied into each newly
// accepted client as its msgDe; clients that are already connected keep
// their current decoder.
func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) { s.defaultMsgDe = fn }
// SetDefaultLink sets the catch-all handler that dispatchMsg invokes for
// every non-system message after any key-specific handler.
func (s *ServerCommon) SetDefaultLink(fn func(message *Message)) { s.defaultFns = fn }
// SetLink registers fn as the handler for messages whose Key equals key,
// replacing any handler previously registered for that key.
func (s *ServerCommon) SetLink(key string, fn func(*Message)) {
	s.mu.Lock()
	s.linkFns[key] = fn
	s.mu.Unlock()
}
// pushMessage hands raw bytes received from the client identified by source
// to the framing queue for parsing.
func (s *ServerCommon) pushMessage(data []byte, source string) { s.queue.ParseMessage(data, source) }
// removeClient drops client from the client pool (keyed by its ClientID).
func (s *ServerCommon) removeClient(client *ClientConn) {
	id := client.ClientID
	s.mu.Lock()
	delete(s.clientPool, id)
	s.mu.Unlock()
}
// accept runs the accept loop for the active transport: the UDP read loop
// when a UDP socket is open, otherwise the stream accept loop. Exactly one
// loop runs; the UDP path returns instead of falling through into the
// stream loop, which would dereference a listener that was never created.
func (s *ServerCommon) accept() {
	if s.udpListener != nil {
		s.acceptUDP()
		return
	}
	s.acceptTU()
}
// acceptTU accepts stream connections until stopCtx is cancelled. Each
// connection is registered in clientPool under a unique ID (remote address +
// timestamp + random number), marked alive, and served by its own
// readTUMessage goroutine.
func (s *ServerCommon) acceptTU() {
	var backoff time.Duration
	for {
		select {
		case <-s.stopCtx.Done():
			return
		default:
		}

		conn, err := s.listener.Accept()
		if err != nil {
			if s.showError {
				fmt.Println("error accept:", err)
			}
			// Back off on repeated failures (e.g. file-descriptor exhaustion)
			// so a persistent error does not spin the CPU, while still
			// reacting to shutdown immediately.
			backoff *= 2
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			}
			if backoff > time.Second {
				backoff = time.Second
			}
			select {
			case <-s.stopCtx.Done():
				return
			case <-time.After(backoff):
			}
			continue
		}
		backoff = 0

		var id string
		for {
			id = fmt.Sprintf("%s%d%d", conn.RemoteAddr().String(), time.Now().UnixNano(), rand.Int63())
			s.mu.RLock()
			_, taken := s.clientPool[id]
			s.mu.RUnlock()
			if !taken {
				break
			}
		}

		client := &ClientConn{
			ClientID:        id,
			ClientAddr:      conn.RemoteAddr(),
			tuConn:          conn,
			server:          s,
			maxReadTimeout:  s.maxReadTimeout,
			maxWriteTimeout: s.maxWriteTimeout,
			SecretKey:       s.SecretKey,
			handshakeRsaKey: s.handshakeRsaKey,
			msgEn:           s.defaultMsgEn,
			msgDe:           s.defaultMsgDe,
			lastHeartBeat:   time.Now().Unix(),
		}
		client.alive.Store(true)
		client.status = Status{
			Alive:  true,
			Reason: "",
			Err:    nil,
		}
		client.stopCtx, client.stopFn = context.WithCancel(context.Background())

		s.mu.Lock()
		s.clientPool[id] = client
		s.mu.Unlock()
		go client.readTUMessage()
	}
}
// loadMessage is the server's dispatch loop. For every parsed frame from the
// queue it looks up the sending client, decrypts and deserializes the frame,
// and dispatches it on its own goroutine (tracked by s.wg).
//
// On shutdown it says goodbye to every client in parallel (waiting at most
// 8 seconds), closes the listening sockets, waits for in-flight dispatches
// and returns.
func (s *ServerCommon) loadMessage() {
	for {
		select {
		case <-s.stopCtx.Done():
			var wg sync.WaitGroup
			s.mu.RLock()
			for _, v := range s.clientPool {
				wg.Add(1)
				// Pass v explicitly: before Go 1.22 the range variable is
				// shared by all iterations, so capturing it would make every
				// goroutine act on the same (last) client.
				go func(v *ClientConn) {
					defer wg.Done()
					v.sayGoodByeForTU()
					v.alive.Store(false)
					v.status = Status{
						Alive:  false,
						Reason: "recv stop signal from server",
						Err:    nil,
					}
					v.stopFn()
					s.removeClient(v)
				}(v)
			}
			s.mu.RUnlock()
			select {
			case <-time.After(time.Second * 8):
			case <-stario.WaitUntilFinished(func() error {
				wg.Wait()
				return nil
			}):
			}
			if s.listener != nil {
				s.listener.Close()
			}
			// Closing the UDP socket unblocks acceptUDP's ReadFromUDP (which
			// otherwise blocks forever without a read timeout) and frees the port.
			if s.udpListener != nil {
				s.udpListener.Close()
			}
			s.wg.Wait()
			return
		case data, ok := <-s.queue.RestoreChan():
			if !ok {
				continue
			}
			s.wg.Add(1)
			go func(data starnet.MsgQueue) {
				// dispatchMsg normally calls s.wg.Done; every path that skips
				// it must release the slot itself, or s.wg.Wait above hangs.
				s.mu.RLock()
				cc, ok := s.clientPool[data.Conn.(string)]
				s.mu.RUnlock()
				if !ok {
					s.wg.Done()
					return
				}
				msg, err := s.sequenceDe(cc.msgDe(cc.SecretKey, data.Msg))
				if err != nil {
					if s.showError {
						fmt.Println("server decode data error", err)
					}
					s.wg.Done()
					return
				}
				transfer, ok := msg.(TransferMsg)
				if !ok {
					if s.showError {
						fmt.Printf("server decode data error: unexpected message type %T\n", msg)
					}
					s.wg.Done()
					return
				}
				s.dispatchMsg(Message{
					NetType:     NET_SERVER,
					ClientConn:  cc,
					TransferMsg: transfer,
					Time:        time.Now(),
				})
			}(data)
		}
	}
}
// sysMsg handles protocol-level system messages:
//   - "heartbeat": refreshes the client's last-heartbeat time and replies.
//   - "bye": acknowledges when the sender is waiting (MSG_SYS_WAIT), then
//     marks the client dead and cancels its context.
//
// Other keys are ignored.
func (s *ServerCommon) sysMsg(message Message) {
	switch message.Key {
	case "heartbeat":
		message.ClientConn.lastHeartBeat = time.Now().Unix()
		message.Reply(nil)
	case "bye":
		if message.TransferMsg.Type == MSG_SYS_WAIT {
			message.Reply(nil)
		}
		cc := message.ClientConn
		cc.alive.Store(false)
		cc.status = Status{Alive: false, Reason: "recv stop signal from client", Err: nil}
		cc.stopFn()
	}
}
// dispatchMsg routes one decoded message and releases its s.wg slot when done.
//   - System messages go to sysMsg.
//   - Key-exchange messages go to the client's rsaDecode.
//   - Replies are delivered to the waiter registered under the message ID.
//     Replies with no waiter are dropped.
//   - Everything else goes to the key-specific handler (if any) and then the
//     default handler (if set). Both receive a pointer to the same Message.
func (s *ServerCommon) dispatchMsg(message Message) {
	defer s.wg.Done()
	switch message.TransferMsg.Type {
	case MSG_SYS_WAIT, MSG_SYS:
		s.sysMsg(message)
		return
	case MSG_KEY_CHANGE:
		message.ClientConn.rsaDecode(message)
		return
	case MSG_SYS_REPLY, MSG_SYNC_REPLY:
		// Claim the entry atomically: the winner is the only sender on the
		// 1-slot channel. This stops a duplicate reply from blocking forever
		// and avoids sending on a channel a timed-out waiter already closed.
		if data, ok := s.noFinSyncMsgPool.LoadAndDelete(message.TransferMsg.ID); ok {
			data.(WaitMsg).Reply <- message
		}
		return
	}

	// linkFns is written by SetLink under s.mu; read it under the same lock.
	s.mu.RLock()
	fn, ok := s.linkFns[message.TransferMsg.Key]
	defaultFn := s.defaultFns
	s.mu.RUnlock()

	if ok {
		fn(&message)
	}
	if defaultFn != nil {
		defaultFn(&message)
	}
}
// send writes msg to client c over the active transport (stream unless a
// UDP socket is open) and returns the WaitMsg registered for sync requests.
func (s *ServerCommon) send(c *ClientConn, msg TransferMsg) (WaitMsg, error) {
	if s.udpListener == nil {
		return s.sendTU(c, msg)
	}
	return s.sendUDP(c, msg)
}
// sendTU serializes, encrypts, frames and writes msg to c's stream
// connection.
//
// A fresh ID is assigned unless msg is a reply/key-change that already
// carries one. For MSG_SYNC_ASK and MSG_SYS_WAIT the returned WaitMsg holds
// the channel on which the reply will be delivered. On error the returned
// WaitMsg is zero.
func (s *ServerCommon) sendTU(c *ClientConn, msg TransferMsg) (WaitMsg, error) {
	isReply := msg.Type == MSG_SYNC_REPLY || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_REPLY
	if !isReply || msg.ID == 0 {
		msg.ID = atomic.AddUint64(&s.msgID, 1)
	}
	data, err := s.sequenceEn(msg)
	if err != nil {
		return WaitMsg{}, err
	}
	data = s.queue.BuildMessage(c.msgEn(c.SecretKey, data))

	var wait WaitMsg
	expectReply := msg.Type == MSG_SYNC_ASK || msg.Type == MSG_SYS_WAIT
	if expectReply {
		wait.Time = time.Now()
		wait.TransferMsg = msg
		wait.Reply = make(chan Message, 1)
		// Register BEFORE writing: the peer may answer, and dispatchMsg may
		// look the ID up, before Write returns.
		s.noFinSyncMsgPool.Store(msg.ID, wait)
	}

	if c.maxWriteTimeout.Seconds() != 0 {
		c.tuConn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
	}
	if _, err = c.tuConn.Write(data); err != nil {
		if expectReply {
			s.noFinSyncMsgPool.Delete(msg.ID)
		}
		return WaitMsg{}, err
	}
	return wait, nil
}
// Send delivers value to client c under key as a fire-and-forget
// (MSG_ASYNC) message.
func (s *ServerCommon) Send(c *ClientConn, key string, value MsgVal) error {
	msg := TransferMsg{Key: key, Value: value, Type: MSG_ASYNC}
	if _, err := s.send(c, msg); err != nil {
		return err
	}
	return nil
}
// sendWait sends msg to c and blocks until its reply arrives.
//   - A zero timeout waits indefinitely. It returns os.ErrInvalid if the
//     pending request is dropped, e.g. on shutdown.
//   - Otherwise it returns os.ErrDeadlineExceeded when the timeout elapses,
//     or a "Service shutdown" error when the server stops first.
func (s *ServerCommon) sendWait(c *ClientConn, msg TransferMsg, timeout time.Duration) (Message, error) {
	data, err := s.send(c, msg)
	if err != nil {
		return Message{}, err
	}
	if timeout.Seconds() == 0 {
		reply, ok := <-data.Reply
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-timer.C:
		// Only whoever removes the pool entry may close the channel; this
		// prevents a double close (monitorPool) or a send on a closed channel
		// (dispatchMsg) racing with the timeout.
		if old, loaded := s.noFinSyncMsgPool.LoadAndDelete(data.TransferMsg.ID); loaded {
			close(old.(WaitMsg).Reply)
		}
		return Message{}, os.ErrDeadlineExceeded
	case <-s.stopCtx.Done():
		return Message{}, errors.New("Service shutdown")
	case reply, ok := <-data.Reply:
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}
}
// SendCtx sends value to c under key as a synchronous request and waits for
// the reply until ctx is done or the server stops.
func (s *ServerCommon) SendCtx(ctx context.Context, c *ClientConn, key string, value MsgVal) (Message, error) {
	req := TransferMsg{Key: key, Value: value, Type: MSG_SYNC_ASK}
	return s.sendCtx(c, req, ctx)
}
// sendCtx sends msg to c and waits for its reply.
//   - It returns os.ErrClosed when ctx is done, treating a nil ctx as
//     context.Background().
//   - It returns a "Service shutdown" error when the server stops.
//   - It returns os.ErrInvalid if the pending request is dropped.
func (s *ServerCommon) sendCtx(c *ClientConn, msg TransferMsg, ctx context.Context) (Message, error) {
	data, err := s.send(c, msg)
	if err != nil {
		return Message{}, err
	}
	if ctx == nil {
		ctx = context.Background()
	}
	select {
	case <-ctx.Done():
		// Claim the pool entry before closing so a concurrent reply or
		// monitorPool cannot double-close or send on the closed channel.
		if old, loaded := s.noFinSyncMsgPool.LoadAndDelete(data.TransferMsg.ID); loaded {
			close(old.(WaitMsg).Reply)
		}
		return Message{}, os.ErrClosed
	case <-s.stopCtx.Done():
		return Message{}, errors.New("Service shutdown")
	case reply, ok := <-data.Reply:
		if !ok {
			return reply, os.ErrInvalid
		}
		return reply, nil
	}
}
// SendWait sends value to c under key as a synchronous request and waits up
// to timeout for the reply (zero waits indefinitely).
func (s *ServerCommon) SendWait(c *ClientConn, key string, value MsgVal, timeout time.Duration) (Message, error) {
	req := TransferMsg{Key: key, Value: value, Type: MSG_SYNC_ASK}
	return s.sendWait(c, req, timeout)
}
// SendWaitObj serializes value with the server's sequence encoder and sends
// it as a synchronous request, waiting up to timeout for the reply.
func (s *ServerCommon) SendWaitObj(c *ClientConn, key string, value interface{}, timeout time.Duration) (Message, error) {
	payload, encErr := s.sequenceEn(value)
	if encErr != nil {
		return Message{}, encErr
	}
	return s.SendWait(c, key, payload, timeout)
}
// SendObjCtx serializes val with the server's sequence encoder and sends it
// as a synchronous request, waiting for the reply until ctx is done.
func (s *ServerCommon) SendObjCtx(ctx context.Context, c *ClientConn, key string, val interface{}) (Message, error) {
	payload, encErr := s.sequenceEn(val)
	if encErr != nil {
		return Message{}, encErr
	}
	return s.SendCtx(ctx, c, key, payload)
}
// SendObj serializes val and sends it to c under key as a fire-and-forget
// (MSG_ASYNC) message.
//
// It uses the server's configured sequence encoder (see SetSequenceEn), the
// same as SendWaitObj and SendObjCtx; previously it called the package
// default encoder directly and ignored any custom encoder.
func (s *ServerCommon) SendObj(c *ClientConn, key string, val interface{}) error {
	data, err := s.sequenceEn(val)
	if err != nil {
		return err
	}
	return s.Send(c, key, data)
}
// Reply answers message m with value; it simply delegates to m.Reply.
func (s *ServerCommon) Reply(m Message, value MsgVal) error { return m.Reply(value) }
// UDP transport support below.
// ListenUDP opens a UDP socket on addr and launches the accept, housekeeping
// and dispatch goroutines.
//
// Listen normally calls this after resetting stopCtx and the framing queue;
// calling it directly assumes those are already initialised for this run.
func (s *ServerCommon) ListenUDP(network string, addr string) error {
	udpAddr, err := net.ResolveUDPAddr(network, addr)
	if err != nil {
		return err
	}
	listener, err := net.ListenUDP(network, udpAddr)
	if err != nil {
		return err
	}
	s.alive.Store(true)
	// Replace the whole status under the lock: Stop writes it under mu, and
	// a restart must not keep the previous run's stop Reason/Err.
	s.mu.Lock()
	s.status = Status{Alive: true, Reason: "", Err: nil}
	s.mu.Unlock()
	s.udpListener = listener
	go s.accept()
	go s.monitorPool()
	go s.loadMessage()
	return nil
}
// acceptUDP reads datagrams until stopCtx is cancelled. The first datagram
// from a new remote address registers a client keyed by that address. Every
// successfully read datagram is pushed to the framing queue. When
// maxReadTimeout > 0 it is used as a read deadline, so the loop re-checks
// stopCtx periodically.
func (s *ServerCommon) acceptUDP() {
	for {
		select {
		case <-s.stopCtx.Done():
			return
		default:
		}
		if s.maxReadTimeout.Seconds() > 0 {
			s.udpListener.SetReadDeadline(time.Now().Add(s.maxReadTimeout))
		}
		buf := make([]byte, 4096)
		num, addr, err := s.udpListener.ReadFromUDP(buf)
		// Check the error BEFORE touching addr. On a failed read (including
		// a deadline timeout, which arrives wrapped in *net.OpError) addr is
		// nil; registering it created a phantom "<nil>" client. A datagram
		// read either succeeds or yields no data, so nothing is lost here.
		if err != nil || addr == nil {
			continue
		}

		id := addr.String()
		s.mu.RLock()
		_, known := s.clientPool[id]
		s.mu.RUnlock()
		if !known {
			client := &ClientConn{
				ClientID:        id,
				ClientAddr:      addr,
				server:          s,
				maxReadTimeout:  s.maxReadTimeout,
				maxWriteTimeout: s.maxWriteTimeout,
				SecretKey:       s.SecretKey,
				handshakeRsaKey: s.handshakeRsaKey,
				msgEn:           s.defaultMsgEn,
				msgDe:           s.defaultMsgDe,
				lastHeartBeat:   time.Now().Unix(),
			}
			// Initialise liveness the same way as stream clients so that
			// code reading alive/status sees a consistent state.
			client.alive.Store(true)
			client.status = Status{Alive: true, Reason: "", Err: nil}
			client.stopCtx, client.stopFn = context.WithCancel(context.Background())
			s.mu.Lock()
			s.clientPool[id] = client
			s.mu.Unlock()
		}
		s.pushMessage(buf[:num], id)
	}
}
// sendUDP serializes, encrypts, frames and writes msg to c's address over
// the shared UDP socket.
//
// A fresh (time + random) ID is assigned unless msg is a reply/key-change
// that already carries one. For MSG_SYNC_ASK and MSG_SYS_WAIT the returned
// WaitMsg holds the channel on which the reply will be delivered. On error
// the returned WaitMsg is zero.
func (s *ServerCommon) sendUDP(c *ClientConn, msg TransferMsg) (WaitMsg, error) {
	isReply := msg.Type == MSG_SYNC_REPLY || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_REPLY
	if !isReply || msg.ID == 0 {
		msg.ID = uint64(time.Now().UnixNano()) + rand.Uint64() + rand.Uint64()
	}
	data, err := s.sequenceEn(msg)
	if err != nil {
		return WaitMsg{}, err
	}
	data = s.queue.BuildMessage(c.msgEn(c.SecretKey, data))

	var wait WaitMsg
	expectReply := msg.Type == MSG_SYNC_ASK || msg.Type == MSG_SYS_WAIT
	if expectReply {
		wait.Time = time.Now()
		wait.TransferMsg = msg
		wait.Reply = make(chan Message, 1)
		// Register BEFORE writing: the peer may answer, and dispatchMsg may
		// look the ID up, before WriteTo returns.
		s.noFinSyncMsgPool.Store(msg.ID, wait)
	}

	if c.maxWriteTimeout.Seconds() != 0 {
		s.udpListener.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
	}
	if _, err = s.udpListener.WriteTo(data, c.ClientAddr); err != nil {
		if expectReply {
			s.noFinSyncMsgPool.Delete(msg.ID)
		}
		return WaitMsg{}, err
	}
	return wait, nil
}
// StopMonitorChan returns a channel that is closed when the current run's
// stop context is cancelled.
func (s *ServerCommon) StopMonitorChan() <-chan struct{} { return s.stopCtx.Done() }
// Status returns a snapshot of the server's status. It reads under s.mu
// because Stop updates the status under the same lock; an unlocked read is
// a data race.
func (s *ServerCommon) Status() Status {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.status
}
// GetSecretKey returns the default symmetric key handed to new clients.
func (s *ServerCommon) GetSecretKey() []byte { return s.SecretKey }
// SetSecretKey sets the default symmetric key copied into newly accepted
// clients; existing clients keep their own key.
func (s *ServerCommon) SetSecretKey(key []byte) { s.SecretKey = key }
// RsaPrivKey returns the handshake RSA key copied into new clients.
func (s *ServerCommon) RsaPrivKey() []byte { return s.handshakeRsaKey }
// SetRsaPrivKey sets the handshake RSA key copied into newly accepted
// clients.
func (s *ServerCommon) SetRsaPrivKey(key []byte) { s.handshakeRsaKey = key }
// GetClient returns the connected client with the given ID, or nil if no
// such client is registered.
func (s *ServerCommon) GetClient(id string) *ClientConn {
	s.mu.RLock()
	client := s.clientPool[id]
	s.mu.RUnlock()
	return client
}
// GetClientLists returns a snapshot of all connected clients in unspecified
// order. The slice is never nil.
func (s *ServerCommon) GetClientLists() []*ClientConn {
	s.mu.RLock()
	defer s.mu.RUnlock()
	clients := make([]*ClientConn, len(s.clientPool))
	i := 0
	for _, c := range s.clientPool {
		clients[i] = c
		i++
	}
	return clients
}
// GetClientAddrs returns a snapshot of the remote addresses of all connected
// clients in unspecified order. The slice is never nil.
func (s *ServerCommon) GetClientAddrs() []net.Addr {
	s.mu.RLock()
	defer s.mu.RUnlock()
	addrs := make([]net.Addr, len(s.clientPool))
	i := 0
	for _, c := range s.clientPool {
		addrs[i] = c.ClientAddr
		i++
	}
	return addrs
}
// GetSequenceEn returns the serializer used for outgoing messages.
func (s *ServerCommon) GetSequenceEn() func(interface{}) ([]byte, error) { return s.sequenceEn }
// SetSequenceEn replaces the serializer used by sendTU/sendUDP and by
// SendWaitObj/SendObjCtx.
func (s *ServerCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) { s.sequenceEn = fn }
// GetSequenceDe returns the deserializer applied to incoming frames.
func (s *ServerCommon) GetSequenceDe() func([]byte) (interface{}, error) { return s.sequenceDe }
// SetSequenceDe replaces the deserializer applied to incoming frames.
func (s *ServerCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) { s.sequenceDe = fn }
// HeartbeatTimeoutSec returns how many seconds a client may go without a
// heartbeat before monitorPool disconnects it (0 disables the check).
func (s *ServerCommon) HeartbeatTimeoutSec() int64 { return s.maxHeartbeatLostSeconds }
// SetHeartbeatTimeoutSec sets how many seconds a client may go without a
// heartbeat before monitorPool disconnects it (0 disables the check).
func (s *ServerCommon) SetHeartbeatTimeoutSec(sec int64) { s.maxHeartbeatLostSeconds = sec }