nat function improve

master
兔子 9 months ago
parent d2feccf3b3
commit 1276d3b6dd

@ -26,7 +26,7 @@ func (h *ReverseConfig) Run() error {
}
for key, proxy := range h.proxy {
h.httpmux.HandleFunc(key, func(writer http.ResponseWriter, request *http.Request) {
starlog.Infof("<%s> Req Path:%s Addr:%s UA:%s\n", h.Name, request.URL.Path, request.RemoteAddr, request.Header.Get("User-Agent"))
starlog.Infof("<%s> Req Path:%s ListenAddr:%s UA:%s\n", h.Name, request.URL.Path, request.RemoteAddr, request.Header.Get("User-Agent"))
if !h.BasicAuth(writer, request) {
h.SetResponseHeader(writer)

@ -0,0 +1,23 @@
package net
import (
"testing"
"time"
)
func TestNat(t *testing.T) {
var s = NatServer{
ListenAddr: "0.0.0.0:10020",
enableTCP: true,
}
var c = NatClient{
ServiceTarget: "139.199.163.65:80",
CmdTarget: "127.0.0.1:10020",
enableTCP: true,
}
go s.Run()
go c.Run()
for {
time.Sleep(time.Second * 20)
}
}

@ -1,27 +1,138 @@
package net
import (
"b612.me/starlog"
"bytes"
"context"
"crypto/sha256"
"io"
"net"
"sync"
"time"
)
type SimpleNatClient struct {
type NatClient struct {
mu sync.RWMutex
cmdTCPConn net.Conn
cmdUDPConn *net.UDPAddr
ServiceTarget string
CmdTarget string
tcpAlived bool
DialTimeout int
enableTCP bool
enableUDP bool
Passwd string
stopCtx context.Context
stopFn context.CancelFunc
}
func (s *SimpleNatClient) tcpCmdConn() net.Conn {
func (s *NatClient) tcpCmdConn() net.Conn {
s.mu.RLock()
defer s.mu.RUnlock()
return s.cmdTCPConn
}
func (s *SimpleNatClient) tcpCmdConnAlived() bool {
func (s *NatClient) tcpCmdConnAlived() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.tcpAlived
}
func (s *NatClient) setTcpCmdConnAlived(v bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.tcpAlived = v
}
func (s *NatClient) Run() {
s.stopCtx, s.stopFn = context.WithCancel(context.Background())
if s.DialTimeout == 0 {
s.DialTimeout = 10000
}
if s.Passwd != "" {
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(s.Passwd)...))[:16]
}
if s.enableTCP {
s.runTcp()
}
}
func (s *NatClient) runTcp() error {
var err error
starlog.Noticeln("nat client tcp module start run")
for {
select {
case <-s.stopCtx.Done():
if s.cmdTCPConn != nil {
s.setTcpCmdConnAlived(false)
s.cmdTCPConn.Close()
return nil
}
case <-time.After(time.Millisecond * 1500):
}
if s.cmdTCPConn != nil && s.tcpCmdConnAlived() {
continue
}
s.cmdTCPConn, err = net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout))
if err != nil {
starlog.Errorf("dail remote tcp cmd server %v fail:%v;will retry\n", s.CmdTarget, err)
time.Sleep(time.Second * 2)
s.cmdTCPConn = nil
continue
}
starlog.Infoln("dail remote tcp cmd server ok,remote:", s.CmdTarget)
s.tcpCmdConn().Write(MSG_CMD_HELLO)
s.setTcpCmdConnAlived(true)
go s.handleTcpCmdConn(s.tcpCmdConn())
}
}
func (s *NatClient) handleTcpCmdConn(conn net.Conn) {
for {
header := make([]byte, 16)
_, err := io.ReadFull(conn, header)
if err != nil {
starlog.Infoln("tcp cmd server read fail:", err)
conn.Close()
s.setTcpCmdConnAlived(false)
return
}
if bytes.Equal(header, MSG_CMD_HELLO_REPLY) {
continue
}
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
go s.newRemoteTcpConn()
}
if bytes.Equal(header, MSG_HEARTBEAT) {
_, err = conn.Write(MSG_HEARTBEAT)
if err != nil {
conn.Close()
s.setTcpCmdConnAlived(false)
return
}
}
}
}
func (s *NatClient) newRemoteTcpConn() {
log := starlog.Std.NewFlag()
starlog.Infoln("recv request,create new tcp conn")
nconn, err := net.DialTimeout("tcp", s.CmdTarget, time.Millisecond*time.Duration(s.DialTimeout))
if err != nil {
log.Errorf("dail server tcp conn %v fail:%v\n", s.CmdTarget, err)
return
}
_, err = nconn.Write(MSG_NEW_CONN_HELLO)
if err != nil {
nconn.Close()
log.Errorf("write new client hello to server %v fail:%v\n", s.CmdTarget, err)
return
}
cconn, err := net.DialTimeout("tcp", s.ServiceTarget, time.Millisecond*time.Duration(s.DialTimeout))
if err != nil {
log.Errorf("dail remote tcp conn %v fail:%v\n", s.CmdTarget, err)
return
}
go io.Copy(cconn, nconn)
go io.Copy(nconn, cconn)
}

@ -1,40 +1,46 @@
package net
import (
"b612.me/starlog"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
)
// MSG_CMD_HELLO 控制链路主动链接参头 16byte
var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA1")
var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA2")
var MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220")
var MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014")
// MSG_NEW_CONN_HELLO 交链路主动连接头 16byte
var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF")
var MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")
var MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
type NatServer struct {
sync.RWMutex
cmdTCPConn net.Conn
listenTcp net.Listener
listenUDP *net.UDPConn
Addr string
Port int
ListenAddr string
lastTCPHeart int64
lastUDPHeart int64
Passwd string
DialTimeout int64
NetTimeout int64
UDPTimeout int64
running int32
tcpConnPool chan net.Conn
stopCtx context.Context
stopFn context.CancelFunc
enableTCP bool
enableUDP bool
}
func (n *NatServer) Run() error {
@ -42,48 +48,164 @@ func (n *NatServer) Run() error {
return fmt.Errorf("Server Already Run")
}
n.stopCtx, n.stopFn = context.WithCancel(context.Background())
return nil
}
func (n *NatServer) cmdTcploop(conn net.Conn) error {
var header = make([]byte, 16)
for {
c, err := conn.Read(header)
if err != nil {
//todo
}
if c != 16 {
if n.NetTimeout == 0 {
n.NetTimeout = 10000
}
if n.Passwd != "" {
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
}
}
if n.enableTCP {
go n.runTcpListen()
}
return nil
}
func (n *NatServer) runTcpListen() error {
var err error
n.tcpConnPool = make(chan net.Conn, 128)
atomic.AddInt32(&n.running, 1)
defer atomic.AddInt32(&n.running, -1)
listener, err := net.Listen("tcp", n.Addr)
starlog.Infoln("nat server tcp listener start run")
n.listenTcp, err = net.Listen("tcp", n.ListenAddr)
if err != nil {
starlog.Errorln("nat server tcp listener start failed:", err)
return err
}
n.listenTcp = listener
msgChan := make(chan []byte, 16)
for {
conn, err := listener.Accept()
conn, err := n.listenTcp.Accept()
if err != nil {
continue
}
headedr := make([]byte, 16)
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 700))
c, err := conn.Read(headedr)
if err == nil && c == 16 {
if bytes.Equal(headedr, MSG_CMD_HELLO) {
if n.cmdTCPConn != nil {
n.cmdTCPConn.Close()
}
var ok bool
if n.cmdTCPConn == nil {
if conn, ok = n.checkIsTcpControlConn(conn); ok {
n.cmdTCPConn = conn
conn.Write(MSG_CMD_HELLO_REPLY)
//
go n.handleTcpControlConn(conn, msgChan)
continue
}
}
io.ReadFull(conn, headedr)
if conn, ok = n.checkIsTcpNewConn(conn); ok {
starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String())
n.tcpConnPool <- conn
continue
}
starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String())
go func() {
msgChan <- MSG_NEW_CONN_HELLO
}()
go n.pairNewClientConn(conn)
}
}
func (n *NatServer) pairNewClientConn(conn net.Conn) {
log := starlog.Std.NewFlag()
log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())
select {
case <-time.After(time.Millisecond * time.Duration(n.NetTimeout)):
log.Errorln("pair new conn fail,wait timeout,conn is:", conn)
conn.Close()
return
case nconn := <-n.tcpConnPool:
log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String())
go io.Copy(nconn, conn)
go io.Copy(conn, nconn)
return
}
}
func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) {
go func() {
for {
select {
case data := <-msg:
_, err := conn.Write(data)
if err != nil {
conn.Close()
n.cmdTCPConn = nil
return
}
case <-time.After(time.Minute):
_, err := conn.Write(MSG_HEARTBEAT)
if err != nil {
conn.Close()
n.cmdTCPConn = nil
return
}
}
}
}()
for {
header := make([]byte, 16)
_, err := io.ReadFull(conn, header)
if err != nil {
conn.Close()
n.cmdTCPConn = nil
return
}
if bytes.Equal(header, MSG_HEARTBEAT) {
n.lastTCPHeart = time.Now().Unix()
}
continue
}
}
func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) {
log := starlog.Std.NewFlag()
log.Noticef("start check tcp cmd conn %v\n", conn.RemoteAddr().String())
header := make([]byte, 16)
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
count, err := io.ReadFull(conn, header)
conn.SetReadDeadline(time.Time{})
if err == nil {
if bytes.Equal(header, MSG_CMD_HELLO) {
log.Infof("check tcp cmd conn success:%v\n", conn.RemoteAddr().String())
return conn, true
}
}
log.Infof("check tcp cmd conn fail:%v %v\n", conn.RemoteAddr().String(), err)
return NewCensorConn(header[:count], conn), false
}
func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) {
if n.cmdTCPConn == nil {
return conn, false
}
remoteIp := strings.Split(n.cmdTCPConn.RemoteAddr().String(), ":")[0]
newConnIp := strings.Split(conn.RemoteAddr().String(), ":")[0]
if remoteIp != newConnIp {
return conn, false
}
header := make([]byte, 16)
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
read, err := io.ReadFull(conn, header)
conn.SetReadDeadline(time.Time{})
if err == nil {
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
return conn, true
}
}
return NewCensorConn(header[:read], conn), false
}
type censorConn struct {
reader io.Reader
conn net.Conn
}
func NewCensorConn(header []byte, conn net.Conn) censorConn {
return censorConn{
reader: io.MultiReader(bytes.NewReader(header), conn),
conn: conn,
}
}
func (c censorConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c censorConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
func (c censorConn) Close() error { return c.conn.Close() }
func (c censorConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c censorConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c censorConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c censorConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c censorConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }

Loading…
Cancel
Save