master
兔子 9 months ago
parent 1276d3b6dd
commit 4074adfcd9

@ -13,3 +13,46 @@ var Cmd = &cobra.Command{
func init() {
Cmd.AddCommand(netforward.CmdNetforward)
}
var natc NatClient
var nats NatServer
func init() {
CmdNatClient.Flags().StringVarP(&natc.ServiceTarget, "target", "t", "", "forward server target address")
CmdNatClient.Flags().StringVarP(&natc.CmdTarget, "server", "s", "", "nat server command address")
CmdNatClient.Flags().StringVarP(&natc.Passwd, "passwd", "p", "", "password")
CmdNatClient.Flags().BoolVarP(&natc.enableTCP, "enable-tcp", "T", true, "enable tcp forward")
CmdNatClient.Flags().BoolVarP(&natc.enableUDP, "enable-udp", "U", true, "enable udp forward")
CmdNatClient.Flags().IntVarP(&natc.DialTimeout, "dial-timeout", "d", 10000, "dial timeout milliseconds")
CmdNatClient.Flags().IntVarP(&natc.UdpTimeout, "udp-timeout", "D", 60000, "udp connection timeout milliseconds")
Cmd.AddCommand(CmdNatClient)
CmdNatServer.Flags().StringVarP(&nats.ListenAddr, "listen", "l", "", "listen address")
CmdNatServer.Flags().StringVarP(&nats.Passwd, "passwd", "p", "", "password")
CmdNatServer.Flags().Int64VarP(&nats.UDPTimeout, "udp-timeout", "D", 60000, "udp connection timeout milliseconds")
CmdNatServer.Flags().Int64VarP(&nats.NetTimeout, "dial-timeout", "d", 10000, "dial timeout milliseconds")
CmdNatServer.Flags().BoolVarP(&nats.enableTCP, "enable-tcp", "T", true, "enable tcp forward")
CmdNatServer.Flags().BoolVarP(&nats.enableUDP, "enable-udp", "U", true, "enable udp forward")
Cmd.AddCommand(CmdNatServer)
}
var CmdNatClient = &cobra.Command{
Use: "natc",
Short: "nat client",
Run: func(cmd *cobra.Command, args []string) {
if natc.ServiceTarget == "" || natc.CmdTarget == "" {
cmd.Help()
return
}
natc.Run()
},
}
var CmdNatServer = &cobra.Command{
Use: "nats",
Short: "nat server",
Run: func(cmd *cobra.Command, args []string) {
nats.Run()
},
}

@ -9,11 +9,13 @@ func TestNat(t *testing.T) {
var s = NatServer{
ListenAddr: "0.0.0.0:10020",
enableTCP: true,
enableUDP: true,
}
var c = NatClient{
ServiceTarget: "139.199.163.65:80",
ServiceTarget: "dns.b612.me:521",
CmdTarget: "127.0.0.1:10020",
enableTCP: true,
enableUDP: true,
}
go s.Run()
go c.Run()

@ -14,14 +14,16 @@ import (
type NatClient struct {
mu sync.RWMutex
cmdTCPConn net.Conn
cmdUDPConn *net.UDPAddr
cmdUDPConn *net.UDPConn
ServiceTarget string
CmdTarget string
tcpAlived bool
DialTimeout int
UdpTimeout int
enableTCP bool
enableUDP bool
Passwd string
udpAlived bool
stopCtx context.Context
stopFn context.CancelFunc
}
@ -32,6 +34,12 @@ func (s *NatClient) tcpCmdConn() net.Conn {
return s.cmdTCPConn
}
func (s *NatClient) udpCmdConn() *net.UDPConn {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cmdUDPConn
}
func (s *NatClient) tcpCmdConnAlived() bool {
s.mu.RLock()
defer s.mu.RUnlock()
@ -44,7 +52,19 @@ func (s *NatClient) setTcpCmdConnAlived(v bool) {
s.tcpAlived = v
}
func (s *NatClient) Run() {
func (s *NatClient) udpCmdConnAlived() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.udpAlived
}
func (s *NatClient) setUdpCmdConnAlived(v bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.udpAlived = v
}
func (s *NatClient) Run() error {
s.stopCtx, s.stopFn = context.WithCancel(context.Background())
if s.DialTimeout == 0 {
s.DialTimeout = 10000
@ -52,9 +72,23 @@ func (s *NatClient) Run() {
if s.Passwd != "" {
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16]
}
var wg sync.WaitGroup
if s.enableUDP {
wg.Add(1)
go func() {
defer wg.Done()
s.runUdp()
}()
}
if s.enableTCP {
wg.Add(1)
go func() {
defer wg.Done()
s.runTcp()
}()
}
wg.Wait()
return nil
}
func (s *NatClient) runTcp() error {
@ -87,6 +121,70 @@ func (s *NatClient) runTcp() error {
}
}
func (s *NatClient) runUdp() error {
starlog.Noticeln("nat client udp module start run")
if s.UdpTimeout == 0 {
s.UdpTimeout = 600000
}
for {
select {
case <-s.stopCtx.Done():
if s.cmdTCPConn != nil {
s.setUdpCmdConnAlived(false)
s.cmdUDPConn.Close()
return nil
}
case <-time.After(time.Millisecond * 3000):
}
if s.cmdUDPConn != nil && s.udpCmdConnAlived() {
continue
}
rmt, err := net.ResolveUDPAddr("udp", s.CmdTarget)
if err != nil {
starlog.Errorf("dail remote udp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
time.Sleep(time.Second * 2)
continue
}
s.cmdUDPConn, err = net.DialUDP("udp", nil, rmt)
if err != nil {
starlog.Errorf("dail remote udp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
time.Sleep(time.Second * 2)
s.cmdTCPConn = nil
continue
}
starlog.Infoln("dail remote udp cmd server ok,remote:", s.CmdTarget)
s.udpCmdConn().Write(MSG_CMD_HELLO)
s.setUdpCmdConnAlived(true)
go s.handleUdpCmdConn(s.udpCmdConn())
}
}
func (s *NatClient) handleUdpCmdConn(conn *net.UDPConn) {
for {
header := make([]byte, 16)
_, err := io.ReadFull(conn, header)
if err != nil {
starlog.Infoln("udp cmd server read fail:", err)
conn.Close()
s.setUdpCmdConnAlived(false)
return
}
if bytes.Equal(header, MSG_CMD_HELLO_REPLY) {
continue
}
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
go s.newRemoteUdpConn()
}
if bytes.Equal(header, MSG_HEARTBEAT) {
_, err = conn.Write(MSG_HEARTBEAT)
if err != nil {
conn.Close()
s.setUdpCmdConnAlived(false)
return
}
}
}
}
func (s *NatClient) handleTcpCmdConn(conn net.Conn) {
for {
header := make([]byte, 16)
@ -125,14 +223,121 @@ func (s *NatClient) newRemoteTcpConn() {
_, err = nconn.Write(MSG_NEW_CONN_HELLO)
if err != nil {
nconn.Close()
log.Errorf("write new client hello to server %v fail:%v\n", s.CmdTarget, err)
log.Errorf("write new tcp client hello to server %v fail:%v\n", s.CmdTarget, err)
return
}
cconn, err := net.DialTimeout("tcp", s.ServiceTarget, time.Millisecond*time.Duration(s.DialTimeout))
if err != nil {
log.Errorf("dail remote tcp conn %v fail:%v\n", s.CmdTarget, err)
nconn.Close()
return
}
go io.Copy(cconn, nconn)
go io.Copy(nconn, cconn)
go func() {
for {
data := make([]byte, 8192)
nconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout)))
n, err := nconn.Read(data)
if err != nil {
starlog.Infoln("read from tcp server fail:", nconn.RemoteAddr(), err)
nconn.Close()
cconn.Close()
return
}
_, err = cconn.Write(data[:n])
//starlog.Debugln("write to udp client:", p, err, cconn.LocalAddr(), cconn.RemoteAddr())
if err != nil {
starlog.Infoln("write to tcp client fail:", cconn.RemoteAddr(), err)
nconn.Close()
cconn.Close()
return
}
}
}()
go func() {
for {
data := make([]byte, 8192)
cconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout)))
n, err := cconn.Read(data)
if err != nil {
starlog.Infoln("read from tcp server fail:", cconn.RemoteAddr(), err)
nconn.Close()
cconn.Close()
return
}
_, err = nconn.Write(data[:n])
if err != nil {
starlog.Infoln("write to tcp client fail:", nconn.RemoteAddr(), err)
nconn.Close()
cconn.Close()
return
}
}
}()
}
func (s *NatClient) newRemoteUdpConn() {
log := starlog.Std.NewFlag()
starlog.Infoln("recv request,create new udp conn")
rmt, err := net.ResolveUDPAddr("udp", s.CmdTarget)
if err != nil {
log.Errorf("dail server udp conn %v fail:%v\n", s.CmdTarget, err)
return
}
nconn, err := net.DialUDP("udp", nil, rmt)
if err != nil {
log.Errorf("dail server udp conn %v fail:%v\n", s.CmdTarget, err)
return
}
log.Infof("dail server udp conn %v ok\n", s.CmdTarget)
_, err = nconn.Write(MSG_NEW_CONN_HELLO)
if err != nil {
nconn.Close()
log.Errorf("write new udp client hello to server %v fail:%v\n", s.CmdTarget, err)
return
}
rmt, err = net.ResolveUDPAddr("udp", s.ServiceTarget)
if err != nil {
log.Errorf("dail server udp conn %v fail:%v\n", s.ServiceTarget, err)
return
}
cconn, err := net.DialUDP("udp", nil, rmt)
if err != nil {
log.Errorf("dail remote udp conn %v fail:%v\n", s.ServiceTarget, err)
return
}
log.Infof("dail remote udp conn %v ok\n", s.ServiceTarget)
go func() {
for {
data := make([]byte, 8192)
nconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout)))
n, err := nconn.Read(data)
if err != nil {
starlog.Infoln("read from udp server fail:", err)
return
}
_, err = cconn.Write(data[:n])
//starlog.Debugln("write to udp client:", p, err, cconn.LocalAddr(), cconn.RemoteAddr())
if err != nil {
starlog.Infoln("write to udp client fail:", err)
return
}
}
}()
go func() {
for {
data := make([]byte, 8192)
cconn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(s.UdpTimeout)))
n, err := cconn.Read(data)
if err != nil {
starlog.Infoln("read from udp server fail:", err)
return
}
_, err = nconn.Write(data[:n])
if err != nil {
starlog.Infoln("write to udp client fail:", err)
return
}
}
}()
}

@ -22,6 +22,7 @@ var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014"
// MSG_NEW_CONN_HELLO 交链路主动连接头 16byte
var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")
// MSG_HEARTBEAT 心跳报文 16byte
var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
type NatServer struct {
@ -29,6 +30,9 @@ type NatServer struct {
cmdTCPConn net.Conn
listenTcp net.Listener
listenUDP *net.UDPConn
udpConnMap sync.Map
udpPairMap sync.Map
udpCmdAddr *net.UDPAddr
ListenAddr string
lastTCPHeart int64
lastUDPHeart int64
@ -37,6 +41,7 @@ type NatServer struct {
UDPTimeout int64
running int32
tcpConnPool chan net.Conn
udpConnPool chan addionData
stopCtx context.Context
stopFn context.CancelFunc
enableTCP bool
@ -54,10 +59,22 @@ func (n *NatServer) Run() error {
if n.Passwd != "" {
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
}
var wg sync.WaitGroup
if n.enableUDP {
wg.Add(1)
go func() {
defer wg.Done()
n.runUdpListen()
}()
}
if n.enableTCP {
go n.runTcpListen()
wg.Add(1)
go func() {
defer wg.Done()
n.runTcpListen()
}()
}
wg.Wait()
return nil
}
@ -100,6 +117,151 @@ func (n *NatServer) runTcpListen() error {
}
}
func (n *NatServer) runUdpListen() error {
var err error
atomic.AddInt32(&n.running, 1)
defer atomic.AddInt32(&n.running, -1)
starlog.Infoln("nat server udp listener start run")
if n.UDPTimeout == 0 {
n.UDPTimeout = 120
}
n.udpConnPool = make(chan addionData, 128)
udpListenAddr, err := net.ResolveUDPAddr("udp", n.ListenAddr)
if err != nil {
starlog.Errorln("nat server udp listener start failed:", err)
return err
}
n.listenUDP, err = net.ListenUDP("udp", udpListenAddr)
if err != nil {
starlog.Errorln("nat server tcp listener start failed:", err)
return err
}
go func() {
for {
select {
case <-n.stopCtx.Done():
if n.listenUDP != nil {
n.listenUDP.Close()
}
case <-time.After(time.Second * 30):
if time.Now().Unix()-n.lastUDPHeart > n.UDPTimeout {
if n.udpCmdAddr != nil {
n.udpCmdAddr = nil
}
}
if n.udpCmdAddr != nil {
n.listenUDP.WriteToUDP(MSG_HEARTBEAT, n.udpCmdAddr)
}
n.udpConnMap.Range(func(key, value interface{}) bool {
if time.Now().Unix()-value.(addionData).lastHeartbeat > n.UDPTimeout {
if taregt, ok := n.udpPairMap.Load(key); ok {
n.udpConnMap.Delete(taregt)
n.udpPairMap.Delete(taregt)
}
n.udpConnMap.Delete(key)
n.udpPairMap.Delete(key)
}
return true
})
}
}
}()
for {
data := make([]byte, 8192)
c, udpAddr, err := n.listenUDP.ReadFromUDP(data)
if err != nil {
continue
}
n.handleUdpData(udpAddr, data[:c])
}
}
type addionData struct {
lastHeartbeat int64
Addr *net.UDPAddr
MsgFrom []byte
}
func (n *NatServer) handleUdpData(addr *net.UDPAddr, data []byte) {
starlog.Infoln("handle udp data from:", addr.String())
if addr.String() == n.udpCmdAddr.String() && len(data) >= 16 {
if bytes.Equal(data[:16], MSG_HEARTBEAT) {
starlog.Infoln("recv udp cmd heartbeat")
n.lastUDPHeart = time.Now().Unix()
}
return
}
if n.udpCmdAddr == nil {
if len(data) >= 16 && bytes.Equal(data[:16], MSG_CMD_HELLO) {
starlog.Infof("recv udp cmd hello from %v\n", addr.String())
n.udpCmdAddr = addr
n.lastUDPHeart = time.Now().Unix()
n.listenUDP.WriteToUDP(MSG_CMD_HELLO_REPLY, addr)
return
}
}
if _, ok := n.udpConnMap.Load(addr.IP.String()); ok {
if target, ok := n.udpPairMap.Load(addr.IP.String()); ok {
starlog.Infof("found udp pair data %v <=====> %v\n", addr.String(), target.(*net.UDPAddr).String())
rmt := target.(*net.UDPAddr)
if _, ok := n.udpConnMap.Load(rmt.IP.String()); !ok {
n.udpConnMap.Delete(addr.IP.String())
n.udpPairMap.Delete(addr.IP.String())
n.udpPairMap.Delete(rmt.IP.String())
starlog.Errorf("udp pair data %v <=====> %v fail,remote not found\n", addr.String(), rmt.String())
return
}
tmp, _ := n.udpConnMap.Load(addr.IP.String())
current := tmp.(addionData)
current.lastHeartbeat = time.Now().Unix()
n.udpConnMap.Store(addr.IP.String(), current)
return
}
}
if len(data) >= 16 {
if bytes.Equal(data[:16], MSG_NEW_CONN_HELLO) {
starlog.Infof("recv new udp conn hello from %v\n", addr.String())
if len(data) < 16 {
data = data[16:]
} else {
data = []byte{}
}
n.udpConnMap.Store(addr.IP.String(), addionData{
lastHeartbeat: time.Now().Unix(),
Addr: addr,
})
n.udpConnPool <- addionData{
lastHeartbeat: time.Now().Unix(),
Addr: addr,
MsgFrom: data,
}
return
}
}
starlog.Infof("wait pair udp conn %v\n", addr.String())
if n.udpCmdAddr == nil {
starlog.Infof("wait pair udp conn %v fail,cmd addr is nil\n", addr.String())
return
} else {
n.listenUDP.WriteToUDP(MSG_NEW_CONN_HELLO, n.udpCmdAddr)
}
go func() {
pairAddr := <-n.udpConnPool
n.udpConnMap.Store(addr.String(), addionData{
lastHeartbeat: time.Now().Unix(),
Addr: addr,
})
n.udpPairMap.Store(addr.IP.String(), pairAddr.Addr)
n.udpPairMap.Store(pairAddr.Addr.String(), addr.IP)
starlog.Infof("pair udp conn %v <=====> %v\n", addr.String(), pairAddr.Addr.String())
if len(pairAddr.MsgFrom) > 0 {
n.listenUDP.WriteToUDP(pairAddr.MsgFrom, addr)
}
n.listenUDP.WriteToUDP(data, pairAddr.Addr)
}()
}
func (n *NatServer) pairNewClientConn(conn net.Conn) {
log := starlog.Std.NewFlag()
log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())
@ -110,8 +272,16 @@ func (n *NatServer) pairNewClientConn(conn net.Conn) {
return
case nconn := <-n.tcpConnPool:
log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String())
go io.Copy(nconn, conn)
go io.Copy(conn, nconn)
go func() {
defer nconn.Close()
defer conn.Close()
io.Copy(nconn, conn)
}()
go func() {
defer nconn.Close()
defer conn.Close()
io.Copy(conn, nconn)
}()
return
}
}

Loading…
Cancel
Save