bug fix:goroutine security improved

master
兔子 3 years ago
parent 79dcaaf249
commit 07e374b83f

@ -7,6 +7,7 @@ import (
"math/rand"
"net"
"strings"
"sync"
"time"
"b612.me/starnet"
@ -17,6 +18,7 @@ type StarNotifyC struct {
Connc net.Conn
dialTimeout time.Duration
clientSign map[string]chan string
mu sync.Mutex
// FuncLists 当不使用channel时使用此记录调用函数
FuncLists map[string]func(CMsg)
stopSign context.Context
@ -62,7 +64,9 @@ func (star *StarNotifyC) starinitc() {
func (star *StarNotifyC) Notify(key string) chan string {
if _, ok := star.clientSign[key]; !ok {
ch := make(chan string, 20)
star.mu.Lock()
star.clientSign[key] = ch
star.mu.Unlock()
}
return star.clientSign[key]
}
@ -71,7 +75,9 @@ func (star *StarNotifyC) store(key, value string) {
if _, ok := star.clientSign[key]; !ok {
ch := make(chan string, 20)
ch <- value
star.mu.Lock()
star.clientSign[key] = ch
star.mu.Unlock()
return
}
star.clientSign[key] <- value
@ -212,7 +218,7 @@ func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration)
return CMsg{}, errors.New("Do Not Use UseChannel Mode!")
}
rand.Seed(time.Now().UnixNano())
mode := "cr" + fmt.Sprintf("%05d", rand.Intn(99999))
mode := "cr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999))
var key []byte
for _, v := range []byte(name) {
if v == byte(124) || v == byte(92) {
@ -229,11 +235,15 @@ func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration)
}
var source CMsg
source.wait = make(chan int, 2)
star.mu.Lock()
star.lockPool[mode] = source
star.mu.Unlock()
select {
case <-source.wait:
res := star.lockPool[mode]
star.mu.Lock()
delete(star.lockPool, mode)
star.mu.Unlock()
return res, nil
case <-tmceed:
return CMsg{}, errors.New("Time Exceed")
@ -263,7 +273,7 @@ func (star *StarNotifyC) cnotify() {
}
data, err := star.Queue.RestoreOne()
if err != nil {
time.Sleep(time.Millisecond * 20)
time.Sleep(time.Microsecond * 2)
continue
}
if string(data.Msg) == "b612ryzstop" {
@ -301,7 +311,9 @@ func (star *StarNotifyC) cnotify() {
sa.Key = key
sa.Value = value
sa.mode = mode
star.mu.Lock()
star.lockPool[mode] = sa
star.mu.Unlock()
sa.wait <- 1
} else {
if msg, ok := star.FuncLists[key]; ok {

@ -40,7 +40,7 @@ func Test_usechannel(t *testing.T) {
}
func Test_nochannel(t *testing.T) {
server, err := NewNotifyS("udp", "127.0.0.1:1926")
server, err := NewNotifyS("tcp", "127.0.0.1:1926")
if err != nil {
fmt.Println(err)
return
@ -53,7 +53,7 @@ func Test_nochannel(t *testing.T) {
}
return ""
})
client, err := NewNotifyC("udp", "127.0.0.1:1926")
client, err := NewNotifyC("tcp", "127.0.0.1:1926")
if err != nil {
fmt.Println(err)
return

@ -27,3 +27,11 @@ func Decode(src []byte) (interface{}, error) {
err := dec.Decode(&dst)
return dst, err
}
func (nmsg *SMsg) Decode() (interface{}, error) {
return Decode([]byte(nmsg.Value))
}
func (nmsg *CMsg) Decode() (interface{}, error) {
return Decode([]byte(nmsg.Value))
}

@ -8,6 +8,7 @@ import (
"math/rand"
"net"
"strings"
"sync"
"time"
"b612.me/starcrypto"
@ -39,13 +40,16 @@ type StarNotifyS struct {
Queue *starnet.StarQueue
// FuncLists 记录了被通知项所记录的函数
FuncLists map[string]func(SMsg) string
funcMu sync.Mutex
defaultFunc func(SMsg) string
Connected func(SMsg)
nickName map[string]string
stopSign context.Context
cancel context.CancelFunc
connPool map[string]net.Conn
connPool sync.Map
connMu sync.Mutex
lockPool map[string]SMsg
lockMu sync.Mutex
udpPool map[string]*net.UDPAddr
listener net.Listener
isUDP bool
@ -90,19 +94,22 @@ func (star *StarNotifyS) getName(conn string) string {
// GetConnPool 获取所有Client端信息
func (star *StarNotifyS) GetConnPool() []SMsg {
var result []SMsg
for _, v := range star.connPool {
star.connPool.Range(func(k, val interface{}) bool {
v := val.(net.Conn)
result = append(result, SMsg{Conn: v, mode: "pa", nickName: star.setNickName, getName: star.getName})
}
return true
})
for _, v := range star.udpPool {
result = append(result, SMsg{UDP: v, uconn: star.UDPConn, mode: "pa0", nickName: star.setNickName, getName: star.getName})
}
return result
}
// GetConnPool 获取所有Client端信息
// GetClient 获取所有Client端信息
func (star *StarNotifyS) GetClient(name string) (SMsg, error) {
if str, ok := star.nickName[name]; ok {
if conn, ok := star.connPool[str]; ok {
if tmp, ok := star.connPool.Load(str); ok {
conn := tmp.(net.Conn)
return SMsg{Conn: conn, mode: "pa", nickName: star.setNickName, getName: star.getName}, nil
}
if conn, ok := star.udpPool[str]; ok {
@ -188,7 +195,7 @@ func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Dur
var err error
var tmceed <-chan time.Time
rand.Seed(time.Now().UnixNano())
mode := "sr" + fmt.Sprintf("%05d", rand.Intn(99999))
mode := "sr" + fmt.Sprintf("%d%06d", time.Now().UnixNano(), rand.Intn(999999))
if source.uconn == nil {
_, err = source.Conn.Write(builder.BuildMessage([]byte(mode + "||" + source.addSlash(key) + "||" + value)))
} else {
@ -201,11 +208,15 @@ func (star *StarNotifyS) SendWait(source SMsg, key, value string, tmout time.Dur
tmceed = time.After(tmout)
}
source.wait = make(chan int, 2)
star.lockMu.Lock()
star.lockPool[mode] = source
star.lockMu.Unlock()
select {
case <-source.wait:
star.lockMu.Lock()
res := star.lockPool[mode]
delete(star.lockPool, mode)
star.lockMu.Unlock()
return res, nil
case <-tmceed:
return SMsg{}, errors.New("Time Exceed")
@ -220,7 +231,6 @@ func (star *StarNotifyS) starinits() {
star.Queue.Encode = true
star.udpPool = make(map[string]*net.UDPAddr)
star.FuncLists = make(map[string]func(SMsg) string)
star.connPool = make(map[string]net.Conn)
star.nickName = make(map[string]string)
star.lockPool = make(map[string]SMsg)
star.Stop = make(chan int, 5)
@ -253,7 +263,9 @@ func doudps(netype, value string) (*StarNotifyS, error) {
<-star.stopSign.Done()
for k, v := range star.udpPool {
star.UDPConn.WriteToUDP(star.Queue.BuildMessage([]byte("b612ryzstop")), v)
star.connMu.Lock()
delete(star.udpPool, k)
star.connMu.Unlock()
for k2, v2 := range star.nickName {
if v2 == k {
delete(star.nickName, k2)
@ -275,7 +287,9 @@ func doudps(netype, value string) (*StarNotifyS, error) {
go star.Connected(SMsg{UDP: addr, uconn: star.UDPConn, nickName: star.setNickName, getName: star.getName})
}
}
star.connMu.Lock()
star.udpPool[addr.String()] = addr
star.connMu.Unlock()
}
if err != nil {
continue
@ -298,15 +312,20 @@ func notudps(netype, value string) (*StarNotifyS, error) {
go star.notify()
go func() {
<-star.stopSign.Done()
for k, v := range star.connPool {
star.connPool.Range(func(a, b interface{}) bool {
k := a.(string)
v := b.(net.Conn)
v.Close()
delete(star.connPool, k)
star.connPool.Delete(a)
for k2, v2 := range star.nickName {
if v2 == k {
star.funcMu.Lock()
delete(star.nickName, k2)
star.funcMu.Unlock()
}
}
}
return true
})
star.listener.Close()
star.Online = false
return
@ -341,7 +360,7 @@ func notudps(netype, value string) (*StarNotifyS, error) {
}
if err != nil {
conn.Close()
delete(star.connPool, fmt.Sprint(conn))
star.connPool.Delete(fmt.Sprint(conn))
for k, v := range star.nickName {
if v == fmt.Sprint(conn) {
delete(star.nickName, k)
@ -351,7 +370,7 @@ func notudps(netype, value string) (*StarNotifyS, error) {
}
}
}(conn)
star.connPool[fmt.Sprint(conn)] = conn
star.connPool.Store(fmt.Sprint(conn), conn)
if star.Connected != nil {
go star.Connected(SMsg{Conn: conn, nickName: star.setNickName, getName: star.getName})
}
@ -367,7 +386,7 @@ func (star *StarNotifyS) GetListenerInfo() net.Listener {
// SetNotify 用于设置通知关键词的调用函数
func (star *StarNotifyS) setNickName(name string, conn string) error {
if _, ok := star.connPool[conn]; !ok {
if _, ok := star.connPool.Load(conn); !ok {
if _, ok := star.udpPool[conn]; !ok {
return errors.New("Conn Not Found")
}
@ -377,12 +396,22 @@ func (star *StarNotifyS) setNickName(name string, conn string) error {
delete(star.nickName, k)
}
}
star.funcMu.Lock()
star.nickName[name] = conn
star.funcMu.Unlock()
return nil
}
// SetNotify 用于设置通知关键词的调用函数
func (star *StarNotifyS) SetNotify(name string, data func(SMsg) string) {
star.funcMu.Lock()
defer star.funcMu.Unlock()
if data == nil {
if _, ok := star.FuncLists[name]; ok {
delete(star.FuncLists, name)
}
return
}
star.FuncLists[name] = data
}
@ -414,7 +443,7 @@ func (star *StarNotifyS) notify() {
}
data, err := star.Queue.RestoreOne()
if err != nil {
time.Sleep(time.Millisecond * 20)
time.Sleep(time.Microsecond * 2)
continue
}
mode, key, value := star.analyseData(string(data.Msg))
@ -424,7 +453,9 @@ func (star *StarNotifyS) notify() {
} else {
rmsg = SMsg{nil, key, value, data.Conn.(*net.UDPAddr), star.UDPConn, mode, nil, star.setNickName, star.getName}
if key == "b612ryzstop" {
star.connMu.Lock()
delete(star.udpPool, rmsg.UDP.String())
star.connMu.Unlock()
for k, v := range star.nickName {
if v == rmsg.UDP.String() {
delete(star.nickName, k)
@ -451,7 +482,7 @@ func (star *StarNotifyS) notify() {
}
}
if mode[0:2] != "sr" {
if star.Sync {
if !star.Sync {
go replyFunc(key, rmsg)
} else {
replyFunc(key, rmsg)
@ -459,10 +490,12 @@ func (star *StarNotifyS) notify() {
} else {
if sa, ok := star.lockPool[mode]; ok {
rmsg.wait = sa.wait
star.lockMu.Lock()
star.lockPool[mode] = rmsg
star.lockPool[mode].wait <- 1
star.lockMu.Unlock()
} else {
if star.Sync {
if !star.Sync {
go replyFunc(key, rmsg)
} else {
replyFunc(key, rmsg)

Loading…
Cancel
Save