nat function improve
This commit is contained in:
parent
d2feccf3b3
commit
1276d3b6dd
@ -26,7 +26,7 @@ func (h *ReverseConfig) Run() error {
|
||||
}
|
||||
for key, proxy := range h.proxy {
|
||||
h.httpmux.HandleFunc(key, func(writer http.ResponseWriter, request *http.Request) {
|
||||
starlog.Infof("<%s> Req Path:%s Addr:%s UA:%s\n", h.Name, request.URL.Path, request.RemoteAddr, request.Header.Get("User-Agent"))
|
||||
starlog.Infof("<%s> Req Path:%s ListenAddr:%s UA:%s\n", h.Name, request.URL.Path, request.RemoteAddr, request.Header.Get("User-Agent"))
|
||||
|
||||
if !h.BasicAuth(writer, request) {
|
||||
h.SetResponseHeader(writer)
|
||||
|
23
net/nat_test.go
Normal file
23
net/nat_test.go
Normal file
@ -0,0 +1,23 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNat(t *testing.T) {
|
||||
var s = NatServer{
|
||||
ListenAddr: "0.0.0.0:10020",
|
||||
enableTCP: true,
|
||||
}
|
||||
var c = NatClient{
|
||||
ServiceTarget: "139.199.163.65:80",
|
||||
CmdTarget: "127.0.0.1:10020",
|
||||
enableTCP: true,
|
||||
}
|
||||
go s.Run()
|
||||
go c.Run()
|
||||
for {
|
||||
time.Sleep(time.Second * 20)
|
||||
}
|
||||
}
|
117
net/natclient.go
117
net/natclient.go
@ -1,27 +1,138 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"b612.me/starlog"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SimpleNatClient struct {
|
||||
type NatClient struct {
|
||||
mu sync.RWMutex
|
||||
cmdTCPConn net.Conn
|
||||
cmdUDPConn *net.UDPAddr
|
||||
ServiceTarget string
|
||||
CmdTarget string
|
||||
tcpAlived bool
|
||||
DialTimeout int
|
||||
enableTCP bool
|
||||
enableUDP bool
|
||||
Passwd string
|
||||
stopCtx context.Context
|
||||
stopFn context.CancelFunc
|
||||
}
|
||||
|
||||
func (s *SimpleNatClient) tcpCmdConn() net.Conn {
|
||||
func (s *NatClient) tcpCmdConn() net.Conn {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cmdTCPConn
|
||||
}
|
||||
|
||||
func (s *SimpleNatClient) tcpCmdConnAlived() bool {
|
||||
func (s *NatClient) tcpCmdConnAlived() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.tcpAlived
|
||||
}
|
||||
|
||||
func (s *NatClient) setTcpCmdConnAlived(v bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tcpAlived = v
|
||||
}
|
||||
|
||||
func (s *NatClient) Run() {
|
||||
s.stopCtx, s.stopFn = context.WithCancel(context.Background())
|
||||
if s.DialTimeout == 0 {
|
||||
s.DialTimeout = 10000
|
||||
}
|
||||
if s.Passwd != "" {
|
||||
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16]
|
||||
}
|
||||
if s.enableTCP {
|
||||
s.runTcp()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NatClient) runTcp() error {
|
||||
var err error
|
||||
starlog.Noticeln("nat client tcp module start run")
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCtx.Done():
|
||||
if s.cmdTCPConn != nil {
|
||||
s.setTcpCmdConnAlived(false)
|
||||
s.cmdTCPConn.Close()
|
||||
return nil
|
||||
}
|
||||
case <-time.After(time.Millisecond * 1500):
|
||||
}
|
||||
if s.cmdTCPConn != nil && s.tcpCmdConnAlived() {
|
||||
continue
|
||||
}
|
||||
s.cmdTCPConn, err = net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout))
|
||||
if err != nil {
|
||||
starlog.Errorf("dail remote tcp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
|
||||
time.Sleep(time.Second * 2)
|
||||
s.cmdTCPConn = nil
|
||||
continue
|
||||
}
|
||||
starlog.Infoln("dail remote tcp cmd server ok,remote:", s.CmdTarget)
|
||||
s.tcpCmdConn().Write(MSG_CMD_HELLO)
|
||||
s.setTcpCmdConnAlived(true)
|
||||
go s.handleTcpCmdConn(s.tcpCmdConn())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NatClient) handleTcpCmdConn(conn net.Conn) {
|
||||
for {
|
||||
header := make([]byte, 16)
|
||||
_, err := io.ReadFull(conn, header)
|
||||
if err != nil {
|
||||
starlog.Infoln("tcp cmd server read fail:", err)
|
||||
conn.Close()
|
||||
s.setTcpCmdConnAlived(false)
|
||||
return
|
||||
}
|
||||
if bytes.Equal(header, MSG_CMD_HELLO_REPLY) {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
|
||||
go s.newRemoteTcpConn()
|
||||
}
|
||||
if bytes.Equal(header, MSG_HEARTBEAT) {
|
||||
_, err = conn.Write(MSG_HEARTBEAT)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.setTcpCmdConnAlived(false)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NatClient) newRemoteTcpConn() {
|
||||
log := starlog.Std.NewFlag()
|
||||
starlog.Infoln("recv request,create new tcp conn")
|
||||
nconn, err := net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("dail server tcp conn %v fail:%v\n", s.CmdTarget, err)
|
||||
return
|
||||
}
|
||||
_, err = nconn.Write(MSG_NEW_CONN_HELLO)
|
||||
if err != nil {
|
||||
nconn.Close()
|
||||
log.Errorf("write new client hello to server %v fail:%v\n", s.CmdTarget, err)
|
||||
return
|
||||
}
|
||||
cconn, err := net.DialTimeout("tcp", s.ServiceTarget, time.Millisecond*time.Duration(s.DialTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("dail remote tcp conn %v fail:%v\n", s.CmdTarget, err)
|
||||
return
|
||||
}
|
||||
go io.Copy(cconn, nconn)
|
||||
go io.Copy(nconn, cconn)
|
||||
}
|
||||
|
186
net/natserver.go
186
net/natserver.go
@ -1,40 +1,46 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"b612.me/starlog"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MSG_CMD_HELLO 控制链路主动链接参头 16byte
|
||||
var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA1")
|
||||
var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA2")
|
||||
var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220")
|
||||
var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014")
|
||||
|
||||
// MSG_NEW_CONN_HELLO 交链路主动连接头 16byte
|
||||
var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF")
|
||||
var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")
|
||||
|
||||
var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
|
||||
|
||||
type NatServer struct {
|
||||
sync.RWMutex
|
||||
cmdTCPConn net.Conn
|
||||
listenTcp net.Listener
|
||||
listenUDP *net.UDPConn
|
||||
Addr string
|
||||
Port int
|
||||
ListenAddr string
|
||||
lastTCPHeart int64
|
||||
lastUDPHeart int64
|
||||
Passwd string
|
||||
DialTimeout int64
|
||||
NetTimeout int64
|
||||
UDPTimeout int64
|
||||
running int32
|
||||
tcpConnPool chan net.Conn
|
||||
stopCtx context.Context
|
||||
stopFn context.CancelFunc
|
||||
enableTCP bool
|
||||
enableUDP bool
|
||||
}
|
||||
|
||||
func (n *NatServer) Run() error {
|
||||
@ -42,48 +48,164 @@ func (n *NatServer) Run() error {
|
||||
return fmt.Errorf("Server Already Run")
|
||||
}
|
||||
n.stopCtx, n.stopFn = context.WithCancel(context.Background())
|
||||
if n.NetTimeout == 0 {
|
||||
n.NetTimeout = 10000
|
||||
}
|
||||
if n.Passwd != "" {
|
||||
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
|
||||
}
|
||||
|
||||
if n.enableTCP {
|
||||
go n.runTcpListen()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NatServer) cmdTcploop(conn net.Conn) error {
|
||||
var header = make([]byte, 16)
|
||||
for {
|
||||
c, err := conn.Read(header)
|
||||
if err != nil {
|
||||
//todo
|
||||
}
|
||||
if c != 16 {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NatServer) runTcpListen() error {
|
||||
var err error
|
||||
n.tcpConnPool = make(chan net.Conn, 128)
|
||||
atomic.AddInt32(&n.running, 1)
|
||||
defer atomic.AddInt32(&n.running, -1)
|
||||
listener, err := net.Listen("tcp", n.Addr)
|
||||
starlog.Infoln("nat server tcp listener start run")
|
||||
n.listenTcp, err = net.Listen("tcp", n.ListenAddr)
|
||||
if err != nil {
|
||||
starlog.Errorln("nat server tcp listener start failed:", err)
|
||||
return err
|
||||
}
|
||||
n.listenTcp = listener
|
||||
msgChan := make(chan []byte, 16)
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
conn, err := n.listenTcp.Accept()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
headedr := make([]byte, 16)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 700))
|
||||
c, err := conn.Read(headedr)
|
||||
if err == nil && c == 16 {
|
||||
if bytes.Equal(headedr, MSG_CMD_HELLO) {
|
||||
if n.cmdTCPConn != nil {
|
||||
n.cmdTCPConn.Close()
|
||||
}
|
||||
var ok bool
|
||||
if n.cmdTCPConn == nil {
|
||||
if conn, ok = n.checkIsTcpControlConn(conn); ok {
|
||||
n.cmdTCPConn = conn
|
||||
conn.Write(MSG_CMD_HELLO_REPLY)
|
||||
//
|
||||
go n.handleTcpControlConn(conn, msgChan)
|
||||
continue
|
||||
}
|
||||
}
|
||||
io.ReadFull(conn, headedr)
|
||||
if conn, ok = n.checkIsTcpNewConn(conn); ok {
|
||||
starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String())
|
||||
n.tcpConnPool <- conn
|
||||
continue
|
||||
}
|
||||
starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String())
|
||||
go func() {
|
||||
msgChan <- MSG_NEW_CONN_HELLO
|
||||
}()
|
||||
go n.pairNewClientConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NatServer) pairNewClientConn(conn net.Conn) {
|
||||
log := starlog.Std.NewFlag()
|
||||
log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())
|
||||
select {
|
||||
case <-time.After(time.Millisecond * time.Duration(n.NetTimeout)):
|
||||
log.Errorln("pair new conn fail,wait timeout,conn is:", conn)
|
||||
conn.Close()
|
||||
return
|
||||
case nconn := <-n.tcpConnPool:
|
||||
log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String())
|
||||
go io.Copy(nconn, conn)
|
||||
go io.Copy(conn, nconn)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case data := <-msg:
|
||||
_, err := conn.Write(data)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
n.cmdTCPConn = nil
|
||||
return
|
||||
}
|
||||
case <-time.After(time.Minute):
|
||||
_, err := conn.Write(MSG_HEARTBEAT)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
n.cmdTCPConn = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
for {
|
||||
header := make([]byte, 16)
|
||||
_, err := io.ReadFull(conn, header)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
n.cmdTCPConn = nil
|
||||
return
|
||||
}
|
||||
if bytes.Equal(header, MSG_HEARTBEAT) {
|
||||
n.lastTCPHeart = time.Now().Unix()
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) {
|
||||
log := starlog.Std.NewFlag()
|
||||
log.Noticef("start check tcp cmd conn %v\n", conn.RemoteAddr().String())
|
||||
header := make([]byte, 16)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
|
||||
count, err := io.ReadFull(conn, header)
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
if err == nil {
|
||||
if bytes.Equal(header, MSG_CMD_HELLO) {
|
||||
log.Infof("check tcp cmd conn success:%v\n", conn.RemoteAddr().String())
|
||||
return conn, true
|
||||
}
|
||||
}
|
||||
log.Infof("check tcp cmd conn fail:%v %v\n", conn.RemoteAddr().String(), err)
|
||||
return NewCensorConn(header[:count], conn), false
|
||||
}
|
||||
|
||||
func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) {
|
||||
if n.cmdTCPConn == nil {
|
||||
return conn, false
|
||||
}
|
||||
remoteIp := strings.Split(n.cmdTCPConn.RemoteAddr().String(), ":")[0]
|
||||
newConnIp := strings.Split(conn.RemoteAddr().String(), ":")[0]
|
||||
if remoteIp != newConnIp {
|
||||
return conn, false
|
||||
}
|
||||
header := make([]byte, 16)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
|
||||
read, err := io.ReadFull(conn, header)
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
if err == nil {
|
||||
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
|
||||
return conn, true
|
||||
}
|
||||
}
|
||||
return NewCensorConn(header[:read], conn), false
|
||||
}
|
||||
|
||||
type censorConn struct {
|
||||
reader io.Reader
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func NewCensorConn(header []byte, conn net.Conn) censorConn {
|
||||
return censorConn{
|
||||
reader: io.MultiReader(bytes.NewReader(header), conn),
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
func (c censorConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
||||
func (c censorConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
|
||||
func (c censorConn) Close() error { return c.conn.Close() }
|
||||
func (c censorConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||
func (c censorConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
||||
func (c censorConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||||
func (c censorConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||||
func (c censorConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||||
|
Loading…
x
Reference in New Issue
Block a user