You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
star/net/natserver.go

212 lines
5.6 KiB
Go

package net
import (
"b612.me/starlog"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
)
// Protocol markers exchanged between the NAT server and its client. Every
// marker is exactly 16 bytes. The decode errors are discarded because the
// literals are fixed, well-formed hex strings.
var (
	// MSG_CMD_HELLO is the 16-byte header a peer must send first to be
	// accepted as the control connection.
	MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220")

	// MSG_CMD_HELLO_REPLY is written back by the server once a control
	// connection has been accepted.
	MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014")

	// MSG_NEW_CONN_HELLO is the 16-byte header that opens a data (exchange)
	// connection. The server also pushes it over the control connection to
	// ask the client for a new data connection.
	MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")

	// MSG_HEARTBEAT is the keep-alive frame exchanged on the control
	// connection.
	MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
)
// NatServer is the server side of the NAT tunnel. It listens on ListenAddr,
// keeps a single control connection to the client, and pairs externally
// accepted connections with data connections opened by that client.
type NatServer struct {
// Embedded lock; NOTE(review): no visible method takes it, although
// cmdTCPConn is touched from several goroutines.
sync.RWMutex
// cmdTCPConn is the currently registered control connection, or nil.
cmdTCPConn net.Conn
// listenTcp is the TCP listener created by runTcpListen.
listenTcp net.Listener
// listenUDP is not referenced in the visible part of the file.
listenUDP *net.UDPConn
// ListenAddr is the address handed to net.Listen("tcp", ...).
ListenAddr string
// lastTCPHeart is the Unix time (seconds) of the last heartbeat frame read
// from the control connection.
lastTCPHeart int64
// lastUDPHeart is not referenced in the visible part of the file.
lastUDPHeart int64
// Passwd, when non-empty, is fed into the MSG_CMD_HELLO derivation in Run.
Passwd string
// NetTimeout is the pairing timeout in milliseconds; Run defaults it to
// 10000 when it is zero.
NetTimeout int64
// UDPTimeout is not referenced in the visible part of the file.
UDPTimeout int64
// running counts active listener goroutines; it is updated with sync/atomic.
running int32
// tcpConnPool queues client-opened data connections waiting to be paired.
tcpConnPool chan net.Conn
// stopCtx and stopFn are created by Run.
stopCtx context.Context
stopFn context.CancelFunc
// enableTCP gates the TCP listener in Run.
enableTCP bool
// enableUDP is not referenced in the visible part of the file.
enableUDP bool
}
// Run starts the server. It creates the stop context, applies the default
// NetTimeout (10000 ms), derives the control-link hello header from Passwd
// when one is set, and launches the TCP listener in a background goroutine
// when enableTCP is true.
//
// Run returns immediately. It returns an error only when a listener is
// already running; failures of the background listener are logged by
// runTcpListen and are not reported here.
func (n *NatServer) Run() error {
	// running is modified with sync/atomic by the listener goroutine, so it
	// must be read atomically as well to avoid a data race.
	if atomic.LoadInt32(&n.running) != 0 {
		return fmt.Errorf("Server Already Run")
	}
	n.stopCtx, n.stopFn = context.WithCancel(context.Background())
	if n.NetTimeout == 0 {
		n.NetTimeout = 10000
	}
	if n.Passwd != "" {
		// NOTE(review): hash.Hash.Sum(b) appends the digest of the data
		// written so far (nothing here) to b; it does not hash b. Since
		// MSG_CMD_HELLO is 16 bytes long, the [:16] slice is the original
		// header again, so Passwd currently has no effect on the handshake.
		// The expression is kept byte-for-byte because the peer presumably
		// derives its header the same way; confirm against the client and
		// change both sides together (sha256.Sum256 is the hashing call).
		MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
	}
	if n.enableTCP {
		go n.runTcpListen()
	}
	return nil
}
// runTcpListen opens the TCP listener on ListenAddr and classifies every
// accepted connection:
//
//  1. While no control connection is registered, a peer whose first 16 bytes
//     equal MSG_CMD_HELLO becomes the control connection.
//  2. A peer accepted by checkIsTcpNewConn is a client data connection and is
//     queued in tcpConnPool.
//  3. Anything else is an external connection: a MSG_NEW_CONN_HELLO request
//     is queued for the control connection and the connection waits to be
//     paired with a data connection.
//
// The listener is closed when the function returns or when stopCtx is
// cancelled. The function returns the listen error, or the first
// non-temporary Accept error (for example after the listener was closed).
//
// NOTE(review): cmdTCPConn is read here and written by the control-connection
// goroutines without synchronization.
func (n *NatServer) runTcpListen() error {
	n.tcpConnPool = make(chan net.Conn, 128)
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)
	starlog.Infoln("nat server tcp listener start run")

	listener, err := net.Listen("tcp", n.ListenAddr)
	n.listenTcp = listener
	if err != nil {
		starlog.Errorln("nat server tcp listener start failed:", err)
		return err
	}
	defer listener.Close()

	// Close the listener on cancellation so the blocking Accept below wakes up.
	if ctx := n.stopCtx; ctx != nil {
		finished := make(chan struct{})
		defer close(finished)
		go func() {
			select {
			case <-ctx.Done():
				listener.Close()
			case <-finished:
			}
		}()
	}

	msgChan := make(chan []byte, 16)
	for {
		conn, err := listener.Accept()
		if err != nil {
			// Retry only transient failures; a closed listener would fail
			// forever and spin this loop.
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				starlog.Errorln("nat server tcp accept failed, retrying:", err)
				time.Sleep(50 * time.Millisecond)
				continue
			}
			starlog.Errorln("nat server tcp listener stopped:", err)
			return err
		}

		var ok bool
		if n.cmdTCPConn == nil {
			if conn, ok = n.checkIsTcpControlConn(conn); ok {
				// Register the control connection only after the handshake
				// reply has actually been delivered.
				if _, err := conn.Write(MSG_CMD_HELLO_REPLY); err != nil {
					starlog.Errorln("nat server reply to tcp cmd conn failed:", err)
					conn.Close()
					continue
				}
				n.cmdTCPConn = conn
				go n.handleTcpControlConn(conn, msgChan)
				continue
			}
		}
		if conn, ok = n.checkIsTcpNewConn(conn); ok {
			starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String())
			// Never block the accept loop: the pool is drained only by
			// goroutines this loop spawns, so a blocking send on a full pool
			// could never complete.
			select {
			case n.tcpConnPool <- conn:
			default:
				starlog.Errorln("tcp conn pool is full, drop client conn:", conn.RemoteAddr().String())
				conn.Close()
			}
			continue
		}
		starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String())
		// Ask the client, through the control connection, for a data connection.
		go func() {
			msgChan <- MSG_NEW_CONN_HELLO
		}()
		go n.pairNewClientConn(conn)
	}
}
// pairNewClientConn waits up to NetTimeout milliseconds for a client data
// connection from tcpConnPool and then relays bytes in both directions
// between it and conn. If no data connection arrives in time, conn is closed.
//
// When either relay direction ends, both connections are closed so the
// opposite direction stops too and no socket or goroutine is left behind.
func (n *NatServer) pairNewClientConn(conn net.Conn) {
	log := starlog.Std.NewFlag()
	log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())

	// An explicit timer is released as soon as pairing succeeds.
	wait := time.NewTimer(time.Millisecond * time.Duration(n.NetTimeout))
	defer wait.Stop()

	select {
	case <-wait.C:
		log.Errorln("pair new conn fail,wait timeout,conn is:", conn)
		conn.Close()
	case nconn := <-n.tcpConnPool:
		log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String())
		relay := func(dst, src net.Conn) {
			// The copy error is not actionable here: whatever the reason the
			// stream ended, the pair is torn down.
			io.Copy(dst, src)
			dst.Close()
			src.Close()
		}
		go relay(nconn, conn)
		go relay(conn, nconn)
	}
}
// handleTcpControlConn services the control connection until it fails.
//
// A writer goroutine forwards every message received on msg to conn and
// sends MSG_HEARTBEAT after one minute without outgoing messages. The calling
// goroutine reads 16-byte frames from conn and records the arrival time of
// heartbeat frames in lastTCPHeart.
//
// Both halves share one shutdown path that runs once: it stops the other
// half, closes conn, and clears n.cmdTCPConn only if it still refers to this
// connection, so a stale handler cannot unregister a newer control connection
// or keep consuming messages meant for it.
func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) {
	done := make(chan struct{})
	var once sync.Once
	shutdown := func() {
		once.Do(func() {
			close(done)
			conn.Close()
			if n.cmdTCPConn == conn {
				n.cmdTCPConn = nil
			}
		})
	}
	defer shutdown()

	go func() {
		defer shutdown()
		for {
			var data []byte
			select {
			case <-done:
				return
			case data = <-msg:
			case <-time.After(time.Minute):
				data = MSG_HEARTBEAT
			}
			if _, err := conn.Write(data); err != nil {
				return
			}
		}
	}()

	// The frame buffer is not retained between iterations, so one allocation
	// is enough.
	header := make([]byte, 16)
	for {
		if _, err := io.ReadFull(conn, header); err != nil {
			return
		}
		if bytes.Equal(header, MSG_HEARTBEAT) {
			n.lastTCPHeart = time.Now().Unix()
		}
	}
}
// checkIsTcpControlConn reads the first 16 bytes of conn, waiting at most
// 1200 ms, and reports whether they equal MSG_CMD_HELLO.
//
// On a match it returns conn itself and true. Otherwise it returns false
// together with a wrapper that replays the bytes already consumed, so the
// caller still sees the stream from its beginning.
func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) {
	const handshakeWait = 1200 * time.Millisecond

	logger := starlog.Std.NewFlag()
	logger.Noticef("start check tcp cmd conn %v\n", conn.RemoteAddr().String())

	probe := make([]byte, 16)
	conn.SetReadDeadline(time.Now().Add(handshakeWait))
	got, readErr := io.ReadFull(conn, probe)
	// Remove the deadline again so later reads on the connection can block.
	conn.SetReadDeadline(time.Time{})

	if readErr == nil && bytes.Equal(probe, MSG_CMD_HELLO) {
		logger.Infof("check tcp cmd conn success:%v\n", conn.RemoteAddr().String())
		return conn, true
	}

	logger.Infof("check tcp cmd conn fail:%v %v\n", conn.RemoteAddr().String(), readErr)
	return NewCensorConn(probe[:got], conn), false
}
// checkIsTcpNewConn reports whether conn is a data connection opened by the
// client that owns the control connection.
//
// The connection qualifies only when a control connection exists, conn comes
// from the same host as that control connection, and its first 16 bytes
// (read with a 1200 ms deadline) equal MSG_NEW_CONN_HELLO. On success it
// returns conn and true. If the header was read but did not match, or the
// read failed, it returns false with a wrapper that replays the bytes already
// consumed. When no read was attempted it returns conn unchanged and false.
func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) {
	// Snapshot the field once: the control-connection goroutines reset it to
	// nil, and reading it again after the nil check could dereference nil.
	cmdConn := n.cmdTCPConn
	if cmdConn == nil {
		return conn, false
	}

	// Strip only the trailing ":port". Splitting at the first colon would
	// reduce every IPv6 address such as "[2001:db8::1]:4000" to "[", making
	// all IPv6 peers look like the same host.
	hostOf := func(addr net.Addr) string {
		full := addr.String()
		if idx := strings.LastIndex(full, ":"); idx >= 0 {
			return full[:idx]
		}
		return full
	}
	if hostOf(cmdConn.RemoteAddr()) != hostOf(conn.RemoteAddr()) {
		return conn, false
	}

	header := make([]byte, 16)
	conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
	read, err := io.ReadFull(conn, header)
	// Remove the deadline again so later reads on the connection can block.
	conn.SetReadDeadline(time.Time{})
	if err == nil && bytes.Equal(header, MSG_NEW_CONN_HELLO) {
		return conn, true
	}
	return NewCensorConn(header[:read], conn), false
}
// censorConn is a net.Conn whose reads first replay a buffered prefix and
// then continue from the wrapped connection. Every other operation is passed
// straight through to the wrapped connection.
type censorConn struct {
	reader io.Reader
	conn   net.Conn
}

// NewCensorConn wraps conn so that header is returned by the first reads,
// followed by whatever conn itself delivers. It is used to hand a connection
// on after some of its leading bytes have already been consumed.
func NewCensorConn(header []byte, conn net.Conn) censorConn {
	var wrapped censorConn
	wrapped.conn = conn
	wrapped.reader = io.MultiReader(bytes.NewReader(header), conn)
	return wrapped
}

// Read returns the replayed prefix first, then data from the connection.
func (cc censorConn) Read(p []byte) (int, error) {
	return cc.reader.Read(p)
}

// Write sends p on the wrapped connection.
func (cc censorConn) Write(p []byte) (int, error) {
	return cc.conn.Write(p)
}

// Close closes the wrapped connection.
func (cc censorConn) Close() error {
	return cc.conn.Close()
}

// LocalAddr returns the local address of the wrapped connection.
func (cc censorConn) LocalAddr() net.Addr {
	return cc.conn.LocalAddr()
}

// RemoteAddr returns the remote address of the wrapped connection.
func (cc censorConn) RemoteAddr() net.Addr {
	return cc.conn.RemoteAddr()
}

// SetDeadline sets the read and write deadlines of the wrapped connection.
func (cc censorConn) SetDeadline(t time.Time) error {
	return cc.conn.SetDeadline(t)
}

// SetReadDeadline sets the read deadline of the wrapped connection.
func (cc censorConn) SetReadDeadline(t time.Time) error {
	return cc.conn.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline of the wrapped connection.
func (cc censorConn) SetWriteDeadline(t time.Time) error {
	return cc.conn.SetWriteDeadline(t)
}