package net

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"b612.me/starlog"
)

var (
	// MSG_CMD_HELLO is the 16-byte header a peer sends first when opening the
	// control link.
	MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220")
	// MSG_CMD_HELLO_REPLY is written back by the server once a connection has
	// been accepted as the control link.
	MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014")
	// MSG_NEW_CONN_HELLO is the 16-byte header that opens a new data link. The
	// server also sends it over the control link to request one.
	MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")
	// MSG_HEARTBEAT is the 16-byte keepalive frame exchanged on the control link.
	MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
)

// NatServer accepts TCP connections on ListenAddr and relays each external
// connection through a data connection supplied by a single controlling peer.
// The peer holds one long-lived control connection (opened with MSG_CMD_HELLO).
// Over it the server asks for data connections (MSG_NEW_CONN_HELLO) and
// exchanges heartbeats.
type NatServer struct {
	// The embedded RWMutex guards cmdTCPConn and lastTCPHeart.
	sync.RWMutex
	cmdTCPConn   net.Conn // current control connection; nil when none is attached
	listenTcp    net.Listener
	listenUDP    *net.UDPConn
	ListenAddr   string
	lastTCPHeart int64 // Unix seconds of the last MSG_HEARTBEAT received
	lastUDPHeart int64
	Passwd       string
	NetTimeout   int64 // milliseconds to wait for a data connection; Run defaults it to 10000
	UDPTimeout   int64
	running      int32 // number of active listener goroutines (updated atomically)
	tcpConnPool  chan net.Conn
	stopCtx      context.Context // created by Run; cancelling it stops the TCP listener
	stopFn       context.CancelFunc
	enableTCP    bool
	enableUDP    bool
}

// Run starts the enabled listeners in the background and returns immediately.
// It returns an error if a listener goroutine from a previous Run is still
// active.
func (n *NatServer) Run() error {
	// NOTE(review): running is incremented inside the listener goroutine, so
	// two Run calls in quick succession can both pass this check.
	if atomic.LoadInt32(&n.running) != 0 {
		return fmt.Errorf("Server Already Run")
	}
	n.stopCtx, n.stopFn = context.WithCancel(context.Background())
	if n.NetTimeout == 0 {
		n.NetTimeout = 10000
	}
	if n.Passwd != "" {
		// NOTE(review): this does NOT mix the password into the handshake.
		// sha256.New().Sum(b) appends the digest of the *empty* input to b, so
		// the first 16 bytes of the result are the original MSG_CMD_HELLO.
		// Passwd therefore has no effect. The expression is kept unchanged
		// because the connecting peer presumably computes the same value;
		// fixing it (e.g. sha256.Sum256 over header+password) must be done on
		// both ends together — verify against the client code.
		MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
	}
	if n.enableTCP {
		go n.runTcpListen()
	}
	return nil
}

// cmdConn returns the current control connection, or nil if none is attached.
func (n *NatServer) cmdConn() net.Conn {
	n.RLock()
	defer n.RUnlock()
	return n.cmdTCPConn
}

// setCmdConn installs conn as the current control connection.
func (n *NatServer) setCmdConn(conn net.Conn) {
	n.Lock()
	n.cmdTCPConn = conn
	n.Unlock()
}

// clearCmdConn detaches conn only if it is still the current control
// connection, so a handler for an old link cannot detach a newer one.
func (n *NatServer) clearCmdConn(conn net.Conn) {
	n.Lock()
	if n.cmdTCPConn == conn {
		n.cmdTCPConn = nil
	}
	n.Unlock()
}

// runTcpListen listens on ListenAddr and classifies every accepted connection.
//   - If no control link is attached and the connection starts with
//     MSG_CMD_HELLO, it becomes the control link.
//   - If it comes from the control peer's host and starts with
//     MSG_NEW_CONN_HELLO, it is queued in tcpConnPool as a data connection.
//   - Otherwise it is an external connection. A data connection is requested
//     over the control link and the two are paired.
//
// It returns when the listener cannot be created, is closed, or stopCtx is
// cancelled.
//
// NOTE(review): classification runs on the accept goroutine and can wait up to
// ~2.4s per connection on the header read deadlines, delaying later accepts.
func (n *NatServer) runTcpListen() error {
	var err error
	n.tcpConnPool = make(chan net.Conn, 128)
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)
	starlog.Infoln("nat server tcp listener start run")
	n.listenTcp, err = net.Listen("tcp", n.ListenAddr)
	if err != nil {
		starlog.Errorln("nat server tcp listener start failed:", err)
		return err
	}
	listener := n.listenTcp
	defer listener.Close()

	ctx := n.stopCtx
	if ctx == nil {
		ctx = context.Background()
	}
	// Close the listener when ctx is cancelled so the blocked Accept returns.
	exited := make(chan struct{})
	defer close(exited)
	go func() {
		select {
		case <-ctx.Done():
			listener.Close()
		case <-exited:
		}
	}()

	msgChan := make(chan []byte, 16)
	for {
		conn, err := listener.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return nil
			}
			if errors.Is(err, net.ErrClosed) {
				starlog.Infoln("nat server tcp listener closed:", err)
				return err
			}
			// Back off on persistent errors (e.g. too many open files)
			// instead of spinning.
			starlog.Errorln("nat server tcp accept failed:", err)
			time.Sleep(100 * time.Millisecond)
			continue
		}
		var ok bool
		if n.cmdConn() == nil {
			if conn, ok = n.checkIsTcpControlConn(conn); ok {
				n.setCmdConn(conn)
				// A failed write surfaces as a read error in the handler,
				// which then closes and detaches the link.
				conn.Write(MSG_CMD_HELLO_REPLY)
				go n.handleTcpControlConn(conn, msgChan)
				continue
			}
		}
		if conn, ok = n.checkIsTcpNewConn(conn); ok {
			starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String())
			n.tcpConnPool <- conn
			continue
		}
		starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String())
		// Ask the control peer for a data connection. Give up once the pairing
		// below would have timed out, so this goroutine cannot block forever
		// when no control link is attached.
		timeout := time.Millisecond * time.Duration(n.NetTimeout)
		go func() {
			select {
			case msgChan <- MSG_NEW_CONN_HELLO:
			case <-time.After(timeout):
			case <-ctx.Done():
			}
		}()
		go n.pairNewClientConn(conn)
	}
}

// pairNewClientConn waits up to NetTimeout milliseconds for a data connection
// from tcpConnPool and then relays bytes between it and conn in both
// directions. On timeout conn is closed. When either relay direction ends,
// both connections are closed so the opposite goroutine also returns.
func (n *NatServer) pairNewClientConn(conn net.Conn) {
	log := starlog.Std.NewFlag()
	log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())
	select {
	case <-time.After(time.Millisecond * time.Duration(n.NetTimeout)):
		log.Errorln("pair new conn fail,wait timeout,conn is:", conn)
		conn.Close()
	case nconn := <-n.tcpConnPool:
		log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String())
		closeBoth := func() {
			conn.Close()
			nconn.Close()
		}
		go func() {
			io.Copy(nconn, conn)
			closeBoth()
		}()
		go func() {
			io.Copy(conn, nconn)
			closeBoth()
		}()
	}
}

// handleTcpControlConn services the control connection until it fails.
//
// A writer goroutine forwards messages queued on msg to the peer. It sends
// MSG_HEARTBEAT after a minute with no other message. The calling goroutine
// reads 16-byte frames and records when MSG_HEARTBEAT is received. A failure
// on either side closes conn, detaches it from n, and stops the other side.
func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) {
	done := make(chan struct{})
	var once sync.Once
	shutdown := func() {
		once.Do(func() {
			close(done)
			conn.Close()
			n.clearCmdConn(conn)
		})
	}
	defer shutdown()

	go func() {
		for {
			// Check done on its own first. select chooses randomly among ready
			// cases, and a dead link must not consume (and lose) messages that
			// the next control link should deliver.
			select {
			case <-done:
				return
			default:
			}
			var data []byte
			select {
			case <-done:
				return
			case data = <-msg:
			case <-time.After(time.Minute):
				data = MSG_HEARTBEAT
			}
			if _, err := conn.Write(data); err != nil {
				shutdown()
				return
			}
		}
	}()

	header := make([]byte, 16)
	for {
		if _, err := io.ReadFull(conn, header); err != nil {
			return
		}
		if bytes.Equal(header, MSG_HEARTBEAT) {
			n.Lock()
			n.lastTCPHeart = time.Now().Unix()
			n.Unlock()
		}
	}
}

// checkIsTcpControlConn reads the first 16 bytes of conn (waiting at most
// 1.2s) and reports whether they equal MSG_CMD_HELLO. On success the original
// conn is returned with the header consumed. Otherwise the bytes already read
// are replayed in front of conn via a censorConn, so no data is lost.
func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) {
	log := starlog.Std.NewFlag()
	log.Noticef("start check tcp cmd conn %v\n", conn.RemoteAddr().String())
	header := make([]byte, 16)
	conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
	count, err := io.ReadFull(conn, header)
	conn.SetReadDeadline(time.Time{})
	if err == nil && bytes.Equal(header, MSG_CMD_HELLO) {
		log.Infof("check tcp cmd conn success:%v\n", conn.RemoteAddr().String())
		return conn, true
	}
	log.Infof("check tcp cmd conn fail:%v %v\n", conn.RemoteAddr().String(), err)
	return NewCensorConn(header[:count], conn), false
}

// checkIsTcpNewConn reports whether conn is a data connection from the control
// peer. It must come from the same host as the current control connection and
// start with MSG_NEW_CONN_HELLO, read with a 1.2s deadline. If conn is not a
// data connection, any header bytes already read are replayed in front of it
// via a censorConn.
func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) {
	cmd := n.cmdConn()
	if cmd == nil {
		return conn, false
	}
	// hostOf strips the trailing ":port". Unlike splitting on the first ':',
	// this also handles bracketed IPv6 literals such as "[::1]:80".
	hostOf := func(addr net.Addr) string {
		s := addr.String()
		if i := strings.LastIndex(s, ":"); i >= 0 {
			s = s[:i]
		}
		return strings.TrimSuffix(strings.TrimPrefix(s, "["), "]")
	}
	if hostOf(cmd.RemoteAddr()) != hostOf(conn.RemoteAddr()) {
		return conn, false
	}
	header := make([]byte, 16)
	conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
	read, err := io.ReadFull(conn, header)
	conn.SetReadDeadline(time.Time{})
	if err == nil && bytes.Equal(header, MSG_NEW_CONN_HELLO) {
		return conn, true
	}
	return NewCensorConn(header[:read], conn), false
}

// censorConn is a net.Conn whose reads first return a saved header and then
// continue from the underlying connection. All other operations delegate to
// conn.
type censorConn struct {
	reader io.Reader
	conn   net.Conn
}

var _ net.Conn = censorConn{}

// NewCensorConn returns a connection that yields header before any further
// data read from conn. header is not copied, so callers must not modify it
// afterwards.
func NewCensorConn(header []byte, conn net.Conn) censorConn {
	return censorConn{
		reader: io.MultiReader(bytes.NewReader(header), conn),
		conn:   conn,
	}
}

// Read returns the saved header first, then data from the underlying conn.
func (c censorConn) Read(p []byte) (int, error) { return c.reader.Read(p) }

// Write writes directly to the underlying conn.
func (c censorConn) Write(p []byte) (int, error) { return c.conn.Write(p) }

// Close closes the underlying conn.
func (c censorConn) Close() error { return c.conn.Close() }

// LocalAddr returns the underlying conn's local address.
func (c censorConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }

// RemoteAddr returns the underlying conn's remote address.
func (c censorConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }

// SetDeadline sets the underlying conn's read and write deadlines.
func (c censorConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }

// SetReadDeadline sets the underlying conn's read deadline.
func (c censorConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }

// SetWriteDeadline sets the underlying conn's write deadline.
func (c censorConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }