From 07e374b83fb684b8c0e4cf700844e25319d6db6c Mon Sep 17 00:00:00 2001 From: 兔子 Date: Wed, 23 Dec 2020 20:50:57 +0800 Subject: [PATCH] bug fix:goroutine security improved --- client.go | 16 ++++++++++-- client_test.go | 4 +-- serialization.go | 8 ++++++ server.go | 65 ++++++++++++++++++++++++++++++++++++------------ 4 files changed, 73 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 506b881..a08ed40 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "math/rand" "net" "strings" + "sync" "time" "b612.me/starnet" @@ -17,6 +18,7 @@ type StarNotifyC struct { Connc net.Conn dialTimeout time.Duration clientSign map[string]chan string + mu sync.Mutex // FuncLists 当不使用channel时,使用此记录调用函数 FuncLists map[string]func(CMsg) stopSign context.Context @@ -62,7 +64,9 @@ func (star *StarNotifyC) starinitc() { func (star *StarNotifyC) Notify(key string) chan string { if _, ok := star.clientSign[key]; !ok { ch := make(chan string, 20) + star.mu.Lock() star.clientSign[key] = ch + star.mu.Unlock() } return star.clientSign[key] } @@ -71,7 +75,9 @@ func (star *StarNotifyC) store(key, value string) { if _, ok := star.clientSign[key]; !ok { ch := make(chan string, 20) ch <- value + star.mu.Lock() star.clientSign[key] = ch + star.mu.Unlock() return } star.clientSign[key] <- value @@ -212,7 +218,7 @@ func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration) return CMsg{}, errors.New("Do Not Use UseChannel Mode!") } rand.Seed(time.Now().UnixNano()) - mode := "cr" + fmt.Sprintf("%05d", rand.Intn(99999)) + mode := "cr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999)) var key []byte for _, v := range []byte(name) { if v == byte(124) || v == byte(92) { @@ -229,11 +235,15 @@ func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration) } var source CMsg source.wait = make(chan int, 2) + star.mu.Lock() star.lockPool[mode] = source + star.mu.Unlock() select { case <-source.wait: res := star.lockPool[mode] + star.mu.Lock() delete(star.lockPool, mode) + star.mu.Unlock() return res, nil case <-tmceed: return CMsg{}, errors.New("Time Exceed") @@ -263,7 +273,7 @@ func (star *StarNotifyC) cnotify() { } data, err := star.Queue.RestoreOne() if err != nil { - time.Sleep(time.Millisecond * 20) + time.Sleep(time.Microsecond * 2) continue } if string(data.Msg) == "b612ryzstop" { @@ -301,7 +311,9 @@ func (star *StarNotifyC) cnotify() { sa.Key = key sa.Value = value sa.mode = mode + star.mu.Lock() star.lockPool[mode] = sa + star.mu.Unlock() sa.wait <- 1 } else { if msg, ok := star.FuncLists[key]; ok { diff --git a/client_test.go b/client_test.go index b3561c9..b21e0b7 100644 --- a/client_test.go +++ b/client_test.go @@ -40,7 +40,7 @@ func Test_usechannel(t *testing.T) { } func Test_nochannel(t *testing.T) { - server, err := NewNotifyS("udp", "127.0.0.1:1926") + server, err := NewNotifyS("tcp", "127.0.0.1:1926") if err != nil { fmt.Println(err) return @@ -53,7 +53,7 @@ func Test_nochannel(t *testing.T) { } return "" }) - client, err := NewNotifyC("udp", "127.0.0.1:1926") + client, err := NewNotifyC("tcp", "127.0.0.1:1926") if err != nil { fmt.Println(err) return diff --git a/serialization.go b/serialization.go index 9871e2d..c10d889 100644 --- a/serialization.go +++ b/serialization.go @@ -27,3 +27,11 @@ func Decode(src []byte) (interface{}, error) { err := dec.Decode(&dst) return dst, err } + +func (nmsg *SMsg) Decode() (interface{}, error) { + return Decode([]byte(nmsg.Value)) +} + +func (nmsg *CMsg) Decode() (interface{}, error) { + return Decode([]byte(nmsg.Value)) +} diff --git a/server.go b/server.go index 9da2169..1d6e12c 100644 --- a/server.go +++ b/server.go @@ -8,6 +8,7 @@ import ( "math/rand" "net" "strings" + "sync" "time" "b612.me/starcrypto" @@ -39,13 +40,16 @@ type StarNotifyS struct { Queue *starnet.StarQueue // FuncLists 记录了被通知项所记录的函数 FuncLists map[string]func(SMsg) string + funcMu sync.Mutex defaultFunc func(SMsg) string Connected func(SMsg) nickName map[string]string stopSign context.Context cancel context.CancelFunc - connPool map[string]net.Conn + connPool sync.Map + connMu sync.Mutex lockPool map[string]SMsg + lockMu sync.Mutex udpPool map[string]*net.UDPAddr listener net.Listener isUDP bool @@ -90,19 +94,22 @@ func (star *StarNotifyS) getName(conn string) string { // GetConnPool 获取所有Client端信息 func (star *StarNotifyS) GetConnPool() []SMsg { var result []SMsg - for _, v := range star.connPool { + star.connPool.Range(func(k, val interface{}) bool { + v := val.(net.Conn) result = append(result, SMsg{Conn: v, mode: "pa", nickName: star.setNickName, getName: star.getName}) - } + return true + }) for _, v := range star.udpPool { result = append(result, SMsg{UDP: v, uconn: star.UDPConn, mode: "pa0", nickName: star.setNickName, getName: star.getName}) } return result } -// GetConnPool 获取所有Client端信息 +// GetClient 获取所有Client端信息 func (star *StarNotifyS) GetClient(name string) (SMsg, error) { if str, ok := star.nickName[name]; ok { - if conn, ok := star.connPool[str]; ok { + if tmp, ok := star.connPool.Load(str); ok { + conn := tmp.(net.Conn) return SMsg{Conn: conn, mode: "pa", nickName: star.setNickName, getName: star.getName}, nil } if conn, ok := star.udpPool[str]; ok { @@ -188,7 +195,7 @@ func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Dur var err error var tmceed <-chan time.Time rand.Seed(time.Now().UnixNano()) - mode := "sr" + fmt.Sprintf("%05d", rand.Intn(99999)) + mode := "sr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999)) if source.uconn == nil { _, err = source.Conn.Write(builder.BuildMessage([]byte(mode + "||" + source.addSlash(key) + "||" + value))) } else { @@ -201,11 +208,15 @@ func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Dur tmceed = time.After(tmout) } source.wait = make(chan int, 2) + star.lockMu.Lock() star.lockPool[mode] = source + star.lockMu.Unlock() select { case <-source.wait: + star.lockMu.Lock() res := star.lockPool[mode] delete(star.lockPool, mode) + star.lockMu.Unlock() return res, nil case <-tmceed: return SMsg{}, errors.New("Time Exceed") @@ -220,7 +231,6 @@ func (star *StarNotifyS) starinits() { star.Queue.Encode = true star.udpPool = make(map[string]*net.UDPAddr) star.FuncLists = make(map[string]func(SMsg) string) - star.connPool = make(map[string]net.Conn) star.nickName = make(map[string]string) star.lockPool = make(map[string]SMsg) star.Stop = make(chan int, 5) @@ -253,7 +263,9 @@ func doudps(netype, value string) (*StarNotifyS, error) { <-star.stopSign.Done() for k, v := range star.udpPool { star.UDPConn.WriteToUDP(star.Queue.BuildMessage([]byte("b612ryzstop")), v) + star.connMu.Lock() delete(star.udpPool, k) + star.connMu.Unlock() for k2, v2 := range star.nickName { if v2 == k { delete(star.nickName, k2) @@ -275,7 +287,9 @@ func doudps(netype, value string) (*StarNotifyS, error) { go star.Connected(SMsg{UDP: addr, uconn: star.UDPConn, nickName: star.setNickName, getName: star.getName}) } } + star.connMu.Lock() star.udpPool[addr.String()] = addr + star.connMu.Unlock() } if err != nil { continue @@ -298,15 +312,20 @@ func notudps(netype, value string) (*StarNotifyS, error) { go star.notify() go func() { <-star.stopSign.Done() - for k, v := range star.connPool { + star.connPool.Range(func(a, b interface{}) bool { + k := a.(string) + v := b.(net.Conn) v.Close() - delete(star.connPool, k) + star.connPool.Delete(a) for k2, v2 := range star.nickName { if v2 == k { + star.funcMu.Lock() delete(star.nickName, k2) + star.funcMu.Unlock() } } - } + return true + }) star.listener.Close() star.Online = false return @@ -341,7 +360,7 @@ func notudps(netype, value string) (*StarNotifyS, error) { } if err != nil { conn.Close() - delete(star.connPool, fmt.Sprint(conn)) + star.connPool.Delete(fmt.Sprint(conn)) for k, v := range star.nickName { if v == fmt.Sprint(conn) { delete(star.nickName, k) @@ -351,7 +370,7 @@ func notudps(netype, value string) (*StarNotifyS, error) { } } }(conn) - star.connPool[fmt.Sprint(conn)] = conn + star.connPool.Store(fmt.Sprint(conn), conn) if star.Connected != nil { go star.Connected(SMsg{Conn: conn, nickName: star.setNickName, getName: star.getName}) } @@ -367,7 +386,7 @@ func (star *StarNotifyS) GetListenerInfo() net.Listener { // SetNotify 用于设置通知关键词的调用函数 func (star *StarNotifyS) setNickName(name string, conn string) error { - if _, ok := star.connPool[conn]; !ok { + if _, ok := star.connPool.Load(conn); !ok { if _, ok := star.udpPool[conn]; !ok { return errors.New("Conn Not Found") } @@ -377,12 +396,22 @@ func (star *StarNotifyS) setNickName(name string, conn string) error { delete(star.nickName, k) } } + star.funcMu.Lock() star.nickName[name] = conn + star.funcMu.Unlock() return nil } // SetNotify 用于设置通知关键词的调用函数 func (star *StarNotifyS) SetNotify(name string, data func(SMsg) string) { + star.funcMu.Lock() + defer star.funcMu.Unlock() + if data == nil { + if _, ok := star.FuncLists[name]; ok { + delete(star.FuncLists, name) + } + return + } star.FuncLists[name] = data } @@ -414,7 +443,7 @@ func (star *StarNotifyS) notify() { } data, err := star.Queue.RestoreOne() if err != nil { - time.Sleep(time.Millisecond * 20) + time.Sleep(time.Microsecond * 2) continue } mode, key, value := star.analyseData(string(data.Msg)) @@ -424,7 +453,9 @@ func (star *StarNotifyS) notify() { } else { rmsg = SMsg{nil, key, value, data.Conn.(*net.UDPAddr), star.UDPConn, mode, nil, star.setNickName, star.getName} if key == "b612ryzstop" { + star.connMu.Lock() delete(star.udpPool, rmsg.UDP.String()) + star.connMu.Unlock() for k, v := range star.nickName { if v == rmsg.UDP.String() { delete(star.nickName, k) @@ -451,7 +482,7 @@ func (star *StarNotifyS) notify() { } } if mode[0:2] != "sr" { - if star.Sync { + if !star.Sync { go replyFunc(key, rmsg) } else { replyFunc(key, rmsg) @@ -459,10 +490,12 @@ func (star *StarNotifyS) notify() { } else { if sa, ok := star.lockPool[mode]; ok { rmsg.wait = sa.wait + star.lockMu.Lock() star.lockPool[mode] = rmsg star.lockPool[mode].wait <- 1 + star.lockMu.Unlock() } else { - if star.Sync { + if !star.Sync { go replyFunc(key, rmsg) } else { replyFunc(key, rmsg)