diff --git a/que.go b/que.go index 5fcf929..95d2193 100644 --- a/que.go +++ b/que.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "errors" + "os" "sync" "time" ) @@ -21,10 +22,11 @@ type MsgQueue struct { // StarQueue 为流数据中的消息队列分发 type StarQueue struct { + count int64 Encode bool Reserve uint16 Msgid uint16 - MsgPool []MsgQueue + MsgPool chan MsgQueue UnFinMsg sync.Map LastID int //= -1 ctx context.Context @@ -32,18 +34,30 @@ type StarQueue struct { duration time.Duration EncodeFunc func([]byte) []byte DecodeFunc func([]byte) []byte - //parseMu sync.Mutex - restoreMu sync.Mutex + //restoreMu sync.Mutex } -// NewQueue 建立一个新消息队列 -func NewQueue() *StarQueue { +func NewQueueCtx(ctx context.Context, count int64) *StarQueue { var que StarQueue que.Encode = false - que.ctx, que.cancel = context.WithCancel(context.Background()) + que.count = count + que.MsgPool = make(chan MsgQueue, count) + if ctx == nil { + que.ctx, que.cancel = context.WithCancel(context.Background()) + } else { + que.ctx, que.cancel = context.WithCancel(ctx) + } que.duration = 0 return &que } +func NewQueueWithCount(count int64) *StarQueue { + return NewQueueCtx(nil, count) +} + +// NewQueue 建立一个新消息队列 +func NewQueue() *StarQueue { + return NewQueueWithCount(32) +} // Uint32ToByte 4位uint32转byte func Uint32ToByte(src uint32) []byte { @@ -112,8 +126,17 @@ type unFinMsg struct { RecvMsg []byte } +func (que *StarQueue) push2list(msg MsgQueue) { + que.MsgPool <- msg +} + // ParseMessage 用于解析收到的msg信息 func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error { + return que.parseMessage(msg, conn) +} + +// parseMessage 用于解析收到的msg信息 +func (que *StarQueue) parseMessage(msg []byte, conn interface{}) error { tmp, ok := que.UnFinMsg.Load(conn) if ok { //存在未完成的信息 lastMsg := tmp.(*unFinMsg) @@ -136,7 +159,7 @@ func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error { if len(msg) == 0 { return nil } - return que.ParseMessage(msg, conn) + return que.parseMessage(msg, conn) } //获得本数据包长度 lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12]) @@ -156,38 +179,40 @@ func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error { lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg) } msg = msg[lastMsg.LengthRecv:] - stroeMsg := MsgQueue{ + storeMsg := MsgQueue{ ID: lastMsg.ID, Msg: lastMsg.RecvMsg, Conn: conn, } - que.MsgPool = append(que.MsgPool, stroeMsg) + //que.restoreMu.Lock() + que.push2list(storeMsg) + //que.restoreMu.Unlock() que.UnFinMsg.Delete(conn) - return que.ParseMessage(msg, conn) + return que.parseMessage(msg, conn) } } else { lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg) if lastID < 0 { que.UnFinMsg.Delete(conn) - return que.ParseMessage(msg, conn) + return que.parseMessage(msg, conn) } if len(msg) >= lastID { lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID]) if que.Encode { lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg) } - stroeMsg := MsgQueue{ + storeMsg := MsgQueue{ ID: lastMsg.ID, Msg: lastMsg.RecvMsg, Conn: conn, } - que.MsgPool = append(que.MsgPool, stroeMsg) + que.push2list(storeMsg) que.UnFinMsg.Delete(conn) if len(msg) == lastID { return nil } msg = msg[lastID:] - return que.ParseMessage(msg, conn) + return que.parseMessage(msg, conn) } lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg) que.UnFinMsg.Store(conn, lastMsg) @@ -204,7 +229,7 @@ func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error { msg = msg[start:] lastMsg := unFinMsg{} que.UnFinMsg.Store(conn, &lastMsg) - return que.ParseMessage(msg, conn) + return que.parseMessage(msg, conn) } func checkHeader(msg []byte) bool { @@ -250,38 +275,31 @@ func bytesMerge(src ...[]byte) []byte { } // Restore 获取收到的信息 -func (que *StarQueue) Restore(n int) ([]MsgQueue, error) { - que.restoreMu.Lock() - defer que.restoreMu.Unlock() - var res []MsgQueue - dura := time.Duration(0) - for len(que.MsgPool) < n { +func (que *StarQueue) Restore() (MsgQueue, error) { + if que.duration.Seconds() == 0 { + que.duration = 86400 * time.Second + } + for { select { case <-que.ctx.Done(): - return res, errors.New("Stoped By External Function Call") - default: - time.Sleep(time.Millisecond * 20) - dura = time.Millisecond*20 + dura - if que.duration != 0 && dura > que.duration { - return res, errors.New("Time Exceed") + return MsgQueue{}, errors.New("Stoped By External Function Call") + case <-time.After(que.duration): + if que.duration != 0 { + return MsgQueue{}, os.ErrDeadlineExceeded + } + case data, ok := <-que.MsgPool: + if !ok { + return MsgQueue{}, os.ErrClosed } + return data, nil } } - if len(que.MsgPool) < n { - return res, errors.New("Result Not Enough") - } - res = que.MsgPool[0:n] - que.MsgPool = que.MsgPool[n:] - return res, nil } // RestoreOne 获取收到的一个信息 +//兼容性修改 func (que *StarQueue) RestoreOne() (MsgQueue, error) { - data, err := que.Restore(1) - if len(data) == 1 { - return data[0], err - } - return MsgQueue{}, err + return que.Restore() } // Stop 立即停止Restore @@ -293,3 +311,7 @@ func (que *StarQueue) Stop() { func (que *StarQueue) RestoreDuration(tm time.Duration) { que.duration = tm } + +func (que *StarQueue) RestoreChan() <-chan MsgQueue { + return que.MsgPool +} diff --git a/que_test.go b/que_test.go new file mode 100644 index 0000000..a1a9bf1 --- /dev/null +++ b/que_test.go @@ -0,0 +1,42 @@ +package starnet + +import ( + "fmt" + "testing" + "time" +) + +func Test_QueSpeed(t *testing.T) { + que := NewQueueWithCount(0) + stop := make(chan struct{}, 1) + que.RestoreDuration(time.Second * 10) + var count int64 + go func() { + for { + select { + case <-stop: + //fmt.Println(count) + return + default: + } + _, err := que.RestoreOne() + if err == nil { + count++ + } + } + }() + cp := 0 + stoped := time.After(time.Second * 10) + data := que.BuildMessage([]byte("hello")) + for { + select { + case <-stoped: + fmt.Println(count, cp) + stop <- struct{}{} + return + default: + que.ParseMessage(data, "lala") + cp++ + } + } +}