diff --git a/httpreverse/service.go b/httpreverse/service.go index 2d0ae60..0f743ba 100644 --- a/httpreverse/service.go +++ b/httpreverse/service.go @@ -26,7 +26,7 @@ func (h *ReverseConfig) Run() error { } for key, proxy := range h.proxy { h.httpmux.HandleFunc(key, func(writer http.ResponseWriter, request *http.Request) { - starlog.Infof("<%s> Req Path:%s Addr:%s UA:%s\n", h.Name, request.URL.Path, request.RemoteAddr, request.Header.Get("User-Agent")) + starlog.Infof("<%s> Req Path:%s ListenAddr:%s UA:%s\n", h.Name, request.URL.Path, request.RemoteAddr, request.Header.Get("User-Agent")) if !h.BasicAuth(writer, request) { h.SetResponseHeader(writer) diff --git a/net/nat_test.go b/net/nat_test.go new file mode 100644 index 0000000..06b1978 --- /dev/null +++ b/net/nat_test.go @@ -0,0 +1,23 @@ +package net + +import ( + "testing" + "time" +) + +func TestNat(t *testing.T) { + var s = NatServer{ + ListenAddr: "0.0.0.0:10020", + enableTCP: true, + } + var c = NatClient{ + ServiceTarget: "139.199.163.65:80", + CmdTarget: "127.0.0.1:10020", + enableTCP: true, + } + go s.Run() + go c.Run() + for { + time.Sleep(time.Second * 20) + } +} diff --git a/net/natclient.go b/net/natclient.go index 2416d8b..e564159 100644 --- a/net/natclient.go +++ b/net/natclient.go @@ -1,27 +1,138 @@ package net import ( + "b612.me/starlog" + "bytes" + "context" + "crypto/sha256" + "io" "net" "sync" + "time" ) -type SimpleNatClient struct { +type NatClient struct { mu sync.RWMutex cmdTCPConn net.Conn cmdUDPConn *net.UDPAddr ServiceTarget string CmdTarget string tcpAlived bool + DialTimeout int + enableTCP bool + enableUDP bool + Passwd string + stopCtx context.Context + stopFn context.CancelFunc } -func (s *SimpleNatClient) tcpCmdConn() net.Conn { +func (s *NatClient) tcpCmdConn() net.Conn { s.mu.RLock() defer s.mu.RUnlock() return s.cmdTCPConn } -func (s *SimpleNatClient) tcpCmdConnAlived() bool { +func (s *NatClient) tcpCmdConnAlived() bool { s.mu.RLock() defer s.mu.RUnlock() return s.tcpAlived } + +func (s *NatClient) setTcpCmdConnAlived(v bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.tcpAlived = v +} + +func (s *NatClient) Run() { + s.stopCtx, s.stopFn = context.WithCancel(context.Background()) + if s.DialTimeout == 0 { + s.DialTimeout = 10000 + } + if s.Passwd != "" { + MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16] + } + if s.enableTCP { + s.runTcp() + } +} + +func (s *NatClient) runTcp() error { + var err error + starlog.Noticeln("nat client tcp module start run") + for { + select { + case <-s.stopCtx.Done(): + if s.cmdTCPConn != nil { + s.setTcpCmdConnAlived(false) + s.cmdTCPConn.Close() + return nil + } + case <-time.After(time.Millisecond * 1500): + } + if s.cmdTCPConn != nil && s.tcpCmdConnAlived() { + continue + } + s.cmdTCPConn, err = net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout)) + if err != nil { + starlog.Errorf("dail remote tcp cmd server %v fail:%v;will retry\n", s.CmdTarget, err) + time.Sleep(time.Second * 2) + s.cmdTCPConn = nil + continue + } + starlog.Infoln("dail remote tcp cmd server ok,remote:", s.CmdTarget) + s.tcpCmdConn().Write(MSG_CMD_HELLO) + s.setTcpCmdConnAlived(true) + go s.handleTcpCmdConn(s.tcpCmdConn()) + } +} + +func (s *NatClient) handleTcpCmdConn(conn net.Conn) { + for { + header := make([]byte, 16) + _, err := io.ReadFull(conn, header) + if err != nil { + starlog.Infoln("tcp cmd server read fail:", err) + conn.Close() + s.setTcpCmdConnAlived(false) + return + } + if bytes.Equal(header, MSG_CMD_HELLO_REPLY) { + continue + } + if bytes.Equal(header, MSG_NEW_CONN_HELLO) { + go s.newRemoteTcpConn() + } + if bytes.Equal(header, MSG_HEARTBEAT) { + _, err = conn.Write(MSG_HEARTBEAT) + if err != nil { + conn.Close() + s.setTcpCmdConnAlived(false) + return + } + } + } +} + +func (s *NatClient) newRemoteTcpConn() { + log := starlog.Std.NewFlag() + starlog.Infoln("recv request,create new tcp conn") + nconn, err := net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout)) + if err != nil { + log.Errorf("dail server tcp conn %v fail:%v\n", s.CmdTarget, err) + return + } + _, err = nconn.Write(MSG_NEW_CONN_HELLO) + if err != nil { + nconn.Close() + log.Errorf("write new client hello to server %v fail:%v\n", s.CmdTarget, err) + return + } + cconn, err := net.DialTimeout("tcp", s.ServiceTarget, time.Millisecond*time.Duration(s.DialTimeout)) + if err != nil { + log.Errorf("dail remote tcp conn %v fail:%v\n", s.CmdTarget, err) + return + } + go io.Copy(cconn, nconn) + go io.Copy(nconn, cconn) +} diff --git a/net/natserver.go b/net/natserver.go index 35d4c6c..fe0c372 100644 --- a/net/natserver.go +++ b/net/natserver.go @@ -1,40 +1,46 @@ package net import ( + "b612.me/starlog" "bytes" "context" + "crypto/sha256" "encoding/hex" "fmt" "io" "net" + "strings" "sync" "sync/atomic" "time" ) // MSG_CMD_HELLO 控制链路主动链接参头 16byte -var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA1") -var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA2") +var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220") +var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014") // MSG_NEW_CONN_HELLO 交链路主动连接头 16byte -var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF") +var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612") + +var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008") type NatServer struct { sync.RWMutex cmdTCPConn net.Conn listenTcp net.Listener listenUDP *net.UDPConn - Addr string - Port int + ListenAddr string lastTCPHeart int64 lastUDPHeart int64 Passwd string - DialTimeout int64 + NetTimeout int64 UDPTimeout int64 running int32 tcpConnPool chan net.Conn stopCtx context.Context stopFn context.CancelFunc + enableTCP bool + enableUDP bool } func (n *NatServer) Run() error { @@ -42,48 +48,164 @@ func (n *NatServer) Run() error { return fmt.Errorf("Server Already Run") } n.stopCtx, n.stopFn = context.WithCancel(context.Background()) - return nil -} - -func (n *NatServer) cmdTcploop(conn net.Conn) error { - var header = make([]byte, 16) - for { - c, err := conn.Read(header) - if err != nil { - //todo - } - if c != 16 { + if n.NetTimeout == 0 { + n.NetTimeout = 10000 + } + if n.Passwd != "" { + MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16] + } - } + if n.enableTCP { + go n.runTcpListen() } + return nil } func (n *NatServer) runTcpListen() error { + var err error + n.tcpConnPool = make(chan net.Conn, 128) atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) - listener, err := net.Listen("tcp", n.Addr) + starlog.Infoln("nat server tcp listener start run") + n.listenTcp, err = net.Listen("tcp", n.ListenAddr) if err != nil { + starlog.Errorln("nat server tcp listener start failed:", err) return err } - n.listenTcp = listener + msgChan := make(chan []byte, 16) for { - conn, err := listener.Accept() + conn, err := n.listenTcp.Accept() if err != nil { continue } - headedr := make([]byte, 16) - conn.SetReadDeadline(time.Now().Add(time.Millisecond * 700)) - c, err := conn.Read(headedr) - if err == nil && c == 16 { - if bytes.Equal(headedr, MSG_CMD_HELLO) { - if n.cmdTCPConn != nil { - n.cmdTCPConn.Close() - } + var ok bool + if n.cmdTCPConn == nil { + if conn, ok = n.checkIsTcpControlConn(conn); ok { n.cmdTCPConn = conn conn.Write(MSG_CMD_HELLO_REPLY) - // + go n.handleTcpControlConn(conn, msgChan) + continue } } - io.ReadFull(conn, headedr) + if conn, ok = n.checkIsTcpNewConn(conn); ok { + starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String()) + n.tcpConnPool <- conn + continue + } + starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String()) + go func() { + msgChan <- MSG_NEW_CONN_HELLO + }() + go n.pairNewClientConn(conn) + } +} + +func (n *NatServer) pairNewClientConn(conn net.Conn) { + log := starlog.Std.NewFlag() + log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String()) + select { + case <-time.After(time.Millisecond * time.Duration(n.NetTimeout)): + log.Errorln("pair new conn fail,wait timeout,conn is:", conn) + conn.Close() + return + case nconn := <-n.tcpConnPool: + log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String()) + go io.Copy(nconn, conn) + go io.Copy(conn, nconn) + return + } +} + +func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) { + go func() { + for { + select { + case data := <-msg: + _, err := conn.Write(data) + if err != nil { + conn.Close() + n.cmdTCPConn = nil + return + } + case <-time.After(time.Minute): + _, err := conn.Write(MSG_HEARTBEAT) + if err != nil { + conn.Close() + n.cmdTCPConn = nil + return + } + } + } + }() + for { + header := make([]byte, 16) + _, err := io.ReadFull(conn, header) + if err != nil { + conn.Close() + n.cmdTCPConn = nil + return + } + if bytes.Equal(header, MSG_HEARTBEAT) { + n.lastTCPHeart = time.Now().Unix() + } + continue + } +} + +func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) { + log := starlog.Std.NewFlag() + log.Noticef("start check tcp cmd conn %v\n", conn.RemoteAddr().String()) + header := make([]byte, 16) + conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200)) + count, err := io.ReadFull(conn, header) + conn.SetReadDeadline(time.Time{}) + if err == nil { + if bytes.Equal(header, MSG_CMD_HELLO) { + log.Infof("check tcp cmd conn success:%v\n", conn.RemoteAddr().String()) + return conn, true + } + } + log.Infof("check tcp cmd conn fail:%v %v\n", conn.RemoteAddr().String(), err) + return NewCensorConn(header[:count], conn), false +} + +func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) { + if n.cmdTCPConn == nil { + return conn, false + } + remoteIp := strings.Split(n.cmdTCPConn.RemoteAddr().String(), ":")[0] + newConnIp := strings.Split(conn.RemoteAddr().String(), ":")[0] + if remoteIp != newConnIp { + return conn, false + } + header := make([]byte, 16) + conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200)) + read, err := io.ReadFull(conn, header) + conn.SetReadDeadline(time.Time{}) + if err == nil { + if bytes.Equal(header, MSG_NEW_CONN_HELLO) { + return conn, true + } + } + return NewCensorConn(header[:read], conn), false +} + +type censorConn struct { + reader io.Reader + conn net.Conn +} + +func NewCensorConn(header []byte, conn net.Conn) censorConn { + return censorConn{ + reader: io.MultiReader(bytes.NewReader(header), conn), + conn: conn, } } +func (c censorConn) Read(p []byte) (int, error) { return c.reader.Read(p) } +func (c censorConn) Write(p []byte) (int, error) { return c.conn.Write(p) } +func (c censorConn) Close() error { return c.conn.Close() } +func (c censorConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } +func (c censorConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +func (c censorConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } +func (c censorConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } +func (c censorConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }