package net import ( "b612.me/starlog" "bytes" "context" "crypto/sha256" "encoding/hex" "fmt" "io" "net" "strings" "sync" "sync/atomic" "time" ) // MSG_CMD_HELLO 控制链路主动链接参头 16byte var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220") var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014") // MSG_NEW_CONN_HELLO 交链路主动连接头 16byte var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612") // MSG_HEARTBEAT 心跳报文 16byte var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008") type NatServer struct { sync.RWMutex cmdTCPConn net.Conn listenTcp net.Listener listenUDP *net.UDPConn udpConnMap sync.Map udpPairMap sync.Map udpCmdAddr *net.UDPAddr ListenAddr string lastTCPHeart int64 lastUDPHeart int64 Passwd string NetTimeout int64 UDPTimeout int64 running int32 tcpConnPool chan net.Conn udpConnPool chan addionData stopCtx context.Context stopFn context.CancelFunc enableTCP bool enableUDP bool } func (n *NatServer) Run() error { if n.running != 0 { return fmt.Errorf("Server Already Run") } n.stopCtx, n.stopFn = context.WithCancel(context.Background()) if n.NetTimeout == 0 { n.NetTimeout = 10000 } if n.Passwd != "" { MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16] } var wg sync.WaitGroup if n.enableUDP { wg.Add(1) go func() { defer wg.Done() n.runUdpListen() }() } if n.enableTCP { wg.Add(1) go func() { defer wg.Done() n.runTcpListen() }() } wg.Wait() return nil } func (n *NatServer) runTcpListen() error { var err error n.tcpConnPool = make(chan net.Conn, 128) atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) starlog.Infoln("nat server tcp listener start run") n.listenTcp, err = net.Listen("tcp", n.ListenAddr) if err != nil { starlog.Errorln("nat server tcp listener start failed:", err) return err } msgChan := make(chan []byte, 16) for { conn, err := n.listenTcp.Accept() if err != nil { continue } var ok bool if n.cmdTCPConn == nil { if conn, ok = n.checkIsTcpControlConn(conn); ok { n.cmdTCPConn = conn conn.Write(MSG_CMD_HELLO_REPLY) go n.handleTcpControlConn(conn, msgChan) continue } } if conn, ok = n.checkIsTcpNewConn(conn); ok { starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String()) n.tcpConnPool <- conn continue } starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String()) go func() { msgChan <- MSG_NEW_CONN_HELLO }() go n.pairNewClientConn(conn) } } func (n *NatServer) runUdpListen() error { var err error atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) starlog.Infoln("nat server udp listener start run") if n.UDPTimeout == 0 { n.UDPTimeout = 120 } n.udpConnPool = make(chan addionData, 128) udpListenAddr, err := net.ResolveUDPAddr("udp", n.ListenAddr) if err != nil { starlog.Errorln("nat server udp listener start failed:", err) return err } n.listenUDP, err = net.ListenUDP("udp", udpListenAddr) if err != nil { starlog.Errorln("nat server tcp listener start failed:", err) return err } go func() { for { select { case <-n.stopCtx.Done(): if n.listenUDP != nil { n.listenUDP.Close() } case <-time.After(time.Second * 30): if time.Now().Unix()-n.lastUDPHeart > n.UDPTimeout { if n.udpCmdAddr != nil { n.udpCmdAddr = nil } } if n.udpCmdAddr != nil { n.listenUDP.WriteToUDP(MSG_HEARTBEAT, n.udpCmdAddr) } n.udpConnMap.Range(func(key, value interface{}) bool { if time.Now().Unix()-value.(addionData).lastHeartbeat > n.UDPTimeout { if taregt, ok := n.udpPairMap.Load(key); ok { n.udpConnMap.Delete(taregt) n.udpPairMap.Delete(taregt) } n.udpConnMap.Delete(key) n.udpPairMap.Delete(key) } return true }) } } }() for { data := make([]byte, 8192) c, udpAddr, err := n.listenUDP.ReadFromUDP(data) if err != nil { continue } n.handleUdpData(udpAddr, data[:c]) } } type addionData struct { lastHeartbeat int64 Addr *net.UDPAddr MsgFrom []byte } func (n *NatServer) handleUdpData(addr *net.UDPAddr, data []byte) { starlog.Infoln("handle udp data from:", addr.String()) if addr.String() == n.udpCmdAddr.String() && len(data) >= 16 { if bytes.Equal(data[:16], MSG_HEARTBEAT) { starlog.Infoln("recv udp cmd heartbeat") n.lastUDPHeart = time.Now().Unix() } return } if n.udpCmdAddr == nil { if len(data) >= 16 && bytes.Equal(data[:16], MSG_CMD_HELLO) { starlog.Infof("recv udp cmd hello from %v\n", addr.String()) n.udpCmdAddr = addr n.lastUDPHeart = time.Now().Unix() n.listenUDP.WriteToUDP(MSG_CMD_HELLO_REPLY, addr) return } } if _, ok := n.udpConnMap.Load(addr.IP.String()); ok { if target, ok := n.udpPairMap.Load(addr.IP.String()); ok { starlog.Infof("found udp pair data %v <=====> %v\n", addr.String(), target.(*net.UDPAddr).String()) rmt := target.(*net.UDPAddr) if _, ok := n.udpConnMap.Load(rmt.IP.String()); !ok { n.udpConnMap.Delete(addr.IP.String()) n.udpPairMap.Delete(addr.IP.String()) n.udpPairMap.Delete(rmt.IP.String()) starlog.Errorf("udp pair data %v <=====> %v fail,remote not found\n", addr.String(), rmt.String()) return } tmp, _ := n.udpConnMap.Load(addr.IP.String()) current := tmp.(addionData) current.lastHeartbeat = time.Now().Unix() n.udpConnMap.Store(addr.IP.String(), current) return } } if len(data) >= 16 { if bytes.Equal(data[:16], MSG_NEW_CONN_HELLO) { starlog.Infof("recv new udp conn hello from %v\n", addr.String()) if len(data) < 16 { data = data[16:] } else { data = []byte{} } n.udpConnMap.Store(addr.IP.String(), addionData{ lastHeartbeat: time.Now().Unix(), Addr: addr, }) n.udpConnPool <- addionData{ lastHeartbeat: time.Now().Unix(), Addr: addr, MsgFrom: data, } return } } starlog.Infof("wait pair udp conn %v\n", addr.String()) if n.udpCmdAddr == nil { starlog.Infof("wait pair udp conn %v fail,cmd addr is nil\n", addr.String()) return } else { n.listenUDP.WriteToUDP(MSG_NEW_CONN_HELLO, n.udpCmdAddr) } go func() { pairAddr := <-n.udpConnPool n.udpConnMap.Store(addr.String(), addionData{ lastHeartbeat: time.Now().Unix(), Addr: addr, }) n.udpPairMap.Store(addr.IP.String(), pairAddr.Addr) n.udpPairMap.Store(pairAddr.Addr.String(), addr.IP) starlog.Infof("pair udp conn %v <=====> %v\n", addr.String(), pairAddr.Addr.String()) if len(pairAddr.MsgFrom) > 0 { n.listenUDP.WriteToUDP(pairAddr.MsgFrom, addr) } n.listenUDP.WriteToUDP(data, pairAddr.Addr) }() } func (n *NatServer) pairNewClientConn(conn net.Conn) { log := starlog.Std.NewFlag() log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String()) select { case <-time.After(time.Millisecond * time.Duration(n.NetTimeout)): log.Errorln("pair new conn fail,wait timeout,conn is:", conn) conn.Close() return case nconn := <-n.tcpConnPool: log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String()) go func() { defer nconn.Close() defer conn.Close() io.Copy(nconn, conn) }() go func() { defer nconn.Close() defer conn.Close() io.Copy(conn, nconn) }() return } } func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) { go func() { for { select { case data := <-msg: _, err := conn.Write(data) if err != nil { conn.Close() n.cmdTCPConn = nil return } case <-time.After(time.Minute): _, err := conn.Write(MSG_HEARTBEAT) if err != nil { conn.Close() n.cmdTCPConn = nil return } } } }() for { header := make([]byte, 16) _, err := io.ReadFull(conn, header) if err != nil { conn.Close() n.cmdTCPConn = nil return } if bytes.Equal(header, MSG_HEARTBEAT) { n.lastTCPHeart = time.Now().Unix() } continue } } func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) { log := starlog.Std.NewFlag() log.Noticef("start check tcp cmd conn %v\n", conn.RemoteAddr().String()) header := make([]byte, 16) conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200)) count, err := io.ReadFull(conn, header) conn.SetReadDeadline(time.Time{}) if err == nil { if bytes.Equal(header, MSG_CMD_HELLO) { log.Infof("check tcp cmd conn success:%v\n", conn.RemoteAddr().String()) return conn, true } } log.Infof("check tcp cmd conn fail:%v %v\n", conn.RemoteAddr().String(), err) return NewCensorConn(header[:count], conn), false } func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) { if n.cmdTCPConn == nil { return conn, false } remoteIp := strings.Split(n.cmdTCPConn.RemoteAddr().String(), ":")[0] newConnIp := strings.Split(conn.RemoteAddr().String(), ":")[0] if remoteIp != newConnIp { return conn, false } header := make([]byte, 16) conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200)) read, err := io.ReadFull(conn, header) conn.SetReadDeadline(time.Time{}) if err == nil { if bytes.Equal(header, MSG_NEW_CONN_HELLO) { return conn, true } } return NewCensorConn(header[:read], conn), false } type censorConn struct { reader io.Reader conn net.Conn } func NewCensorConn(header []byte, conn net.Conn) censorConn { return censorConn{ reader: io.MultiReader(bytes.NewReader(header), conn), conn: conn, } } func (c censorConn) Read(p []byte) (int, error) { return c.reader.Read(p) } func (c censorConn) Write(p []byte) (int, error) { return c.conn.Write(p) } func (c censorConn) Close() error { return c.conn.Close() } func (c censorConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } func (c censorConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } func (c censorConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } func (c censorConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } func (c censorConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }