You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
382 lines
10 KiB
Go
382 lines
10 KiB
Go
package net
|
|
|
|
import (
|
|
"b612.me/starlog"
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// 16-byte protocol headers shared by the server and the client. The hex
// literals are constant and well-formed, so the decode errors are ignored.
var (
	// MSG_CMD_HELLO is the header a client sends to open the control link.
	// Run may replace it when Passwd is set.
	MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220")

	// MSG_CMD_HELLO_REPLY is the server's answer to MSG_CMD_HELLO.
	MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014")

	// MSG_NEW_CONN_HELLO is sent over the control link to request a new data
	// connection, and is expected as the first bytes of that data connection.
	MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")

	// MSG_HEARTBEAT is the keep-alive message exchanged on the control link.
	MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
)
|
|
|
|
type NatServer struct {
|
|
sync.RWMutex
|
|
cmdTCPConn net.Conn
|
|
listenTcp net.Listener
|
|
listenUDP *net.UDPConn
|
|
udpConnMap sync.Map
|
|
udpPairMap sync.Map
|
|
udpCmdAddr *net.UDPAddr
|
|
ListenAddr string
|
|
lastTCPHeart int64
|
|
lastUDPHeart int64
|
|
Passwd string
|
|
NetTimeout int64
|
|
UDPTimeout int64
|
|
running int32
|
|
tcpConnPool chan net.Conn
|
|
udpConnPool chan addionData
|
|
stopCtx context.Context
|
|
stopFn context.CancelFunc
|
|
enableTCP bool
|
|
enableUDP bool
|
|
}
|
|
|
|
func (n *NatServer) Run() error {
|
|
if n.running != 0 {
|
|
return fmt.Errorf("Server Already Run")
|
|
}
|
|
n.stopCtx, n.stopFn = context.WithCancel(context.Background())
|
|
if n.NetTimeout == 0 {
|
|
n.NetTimeout = 10000
|
|
}
|
|
if n.Passwd != "" {
|
|
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
|
|
}
|
|
var wg sync.WaitGroup
|
|
if n.enableUDP {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
n.runUdpListen()
|
|
}()
|
|
}
|
|
if n.enableTCP {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
n.runTcpListen()
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
return nil
|
|
}
|
|
|
|
func (n *NatServer) runTcpListen() error {
|
|
var err error
|
|
n.tcpConnPool = make(chan net.Conn, 128)
|
|
atomic.AddInt32(&n.running, 1)
|
|
defer atomic.AddInt32(&n.running, -1)
|
|
starlog.Infoln("nat server tcp listener start run")
|
|
n.listenTcp, err = net.Listen("tcp", n.ListenAddr)
|
|
if err != nil {
|
|
starlog.Errorln("nat server tcp listener start failed:", err)
|
|
return err
|
|
}
|
|
msgChan := make(chan []byte, 16)
|
|
for {
|
|
conn, err := n.listenTcp.Accept()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
var ok bool
|
|
if n.cmdTCPConn == nil {
|
|
if conn, ok = n.checkIsTcpControlConn(conn); ok {
|
|
n.cmdTCPConn = conn
|
|
conn.Write(MSG_CMD_HELLO_REPLY)
|
|
go n.handleTcpControlConn(conn, msgChan)
|
|
continue
|
|
}
|
|
}
|
|
if conn, ok = n.checkIsTcpNewConn(conn); ok {
|
|
starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String())
|
|
n.tcpConnPool <- conn
|
|
continue
|
|
}
|
|
starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String())
|
|
go func() {
|
|
msgChan <- MSG_NEW_CONN_HELLO
|
|
}()
|
|
go n.pairNewClientConn(conn)
|
|
}
|
|
}
|
|
|
|
func (n *NatServer) runUdpListen() error {
|
|
var err error
|
|
atomic.AddInt32(&n.running, 1)
|
|
defer atomic.AddInt32(&n.running, -1)
|
|
starlog.Infoln("nat server udp listener start run")
|
|
if n.UDPTimeout == 0 {
|
|
n.UDPTimeout = 120
|
|
}
|
|
n.udpConnPool = make(chan addionData, 128)
|
|
udpListenAddr, err := net.ResolveUDPAddr("udp", n.ListenAddr)
|
|
if err != nil {
|
|
starlog.Errorln("nat server udp listener start failed:", err)
|
|
return err
|
|
}
|
|
n.listenUDP, err = net.ListenUDP("udp", udpListenAddr)
|
|
if err != nil {
|
|
starlog.Errorln("nat server tcp listener start failed:", err)
|
|
return err
|
|
}
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-n.stopCtx.Done():
|
|
if n.listenUDP != nil {
|
|
n.listenUDP.Close()
|
|
}
|
|
case <-time.After(time.Second * 30):
|
|
if time.Now().Unix()-n.lastUDPHeart > n.UDPTimeout {
|
|
if n.udpCmdAddr != nil {
|
|
n.udpCmdAddr = nil
|
|
}
|
|
}
|
|
if n.udpCmdAddr != nil {
|
|
n.listenUDP.WriteToUDP(MSG_HEARTBEAT, n.udpCmdAddr)
|
|
}
|
|
n.udpConnMap.Range(func(key, value interface{}) bool {
|
|
if time.Now().Unix()-value.(addionData).lastHeartbeat > n.UDPTimeout {
|
|
if taregt, ok := n.udpPairMap.Load(key); ok {
|
|
n.udpConnMap.Delete(taregt)
|
|
n.udpPairMap.Delete(taregt)
|
|
}
|
|
n.udpConnMap.Delete(key)
|
|
n.udpPairMap.Delete(key)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
}()
|
|
for {
|
|
data := make([]byte, 8192)
|
|
c, udpAddr, err := n.listenUDP.ReadFromUDP(data)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
n.handleUdpData(udpAddr, data[:c])
|
|
}
|
|
}
|
|
|
|
// addionData describes one UDP flow known to the server.
type addionData struct {
	// lastHeartbeat is the Unix time (seconds) at which the flow was last
	// refreshed; flows idle longer than NatServer.UDPTimeout are expired.
	lastHeartbeat int64
	// Addr is the remote address of the flow.
	Addr *net.UDPAddr
	// MsgFrom holds payload that arrived with the flow's opening datagram and
	// still has to be delivered once the flow is paired (may be empty).
	MsgFrom []byte
}
|
|
|
|
func (n *NatServer) handleUdpData(addr *net.UDPAddr, data []byte) {
|
|
starlog.Infoln("handle udp data from:", addr.String())
|
|
if addr.String() == n.udpCmdAddr.String() && len(data) >= 16 {
|
|
if bytes.Equal(data[:16], MSG_HEARTBEAT) {
|
|
starlog.Infoln("recv udp cmd heartbeat")
|
|
n.lastUDPHeart = time.Now().Unix()
|
|
}
|
|
return
|
|
}
|
|
if n.udpCmdAddr == nil {
|
|
if len(data) >= 16 && bytes.Equal(data[:16], MSG_CMD_HELLO) {
|
|
starlog.Infof("recv udp cmd hello from %v\n", addr.String())
|
|
n.udpCmdAddr = addr
|
|
n.lastUDPHeart = time.Now().Unix()
|
|
n.listenUDP.WriteToUDP(MSG_CMD_HELLO_REPLY, addr)
|
|
return
|
|
}
|
|
}
|
|
if _, ok := n.udpConnMap.Load(addr.IP.String()); ok {
|
|
if target, ok := n.udpPairMap.Load(addr.IP.String()); ok {
|
|
starlog.Infof("found udp pair data %v <=====> %v\n", addr.String(), target.(*net.UDPAddr).String())
|
|
rmt := target.(*net.UDPAddr)
|
|
if _, ok := n.udpConnMap.Load(rmt.IP.String()); !ok {
|
|
n.udpConnMap.Delete(addr.IP.String())
|
|
n.udpPairMap.Delete(addr.IP.String())
|
|
n.udpPairMap.Delete(rmt.IP.String())
|
|
starlog.Errorf("udp pair data %v <=====> %v fail,remote not found\n", addr.String(), rmt.String())
|
|
return
|
|
}
|
|
tmp, _ := n.udpConnMap.Load(addr.IP.String())
|
|
current := tmp.(addionData)
|
|
current.lastHeartbeat = time.Now().Unix()
|
|
n.udpConnMap.Store(addr.IP.String(), current)
|
|
return
|
|
}
|
|
}
|
|
if len(data) >= 16 {
|
|
if bytes.Equal(data[:16], MSG_NEW_CONN_HELLO) {
|
|
starlog.Infof("recv new udp conn hello from %v\n", addr.String())
|
|
if len(data) < 16 {
|
|
data = data[16:]
|
|
} else {
|
|
data = []byte{}
|
|
}
|
|
n.udpConnMap.Store(addr.IP.String(), addionData{
|
|
lastHeartbeat: time.Now().Unix(),
|
|
Addr: addr,
|
|
})
|
|
n.udpConnPool <- addionData{
|
|
lastHeartbeat: time.Now().Unix(),
|
|
Addr: addr,
|
|
MsgFrom: data,
|
|
}
|
|
return
|
|
}
|
|
}
|
|
starlog.Infof("wait pair udp conn %v\n", addr.String())
|
|
if n.udpCmdAddr == nil {
|
|
starlog.Infof("wait pair udp conn %v fail,cmd addr is nil\n", addr.String())
|
|
return
|
|
} else {
|
|
n.listenUDP.WriteToUDP(MSG_NEW_CONN_HELLO, n.udpCmdAddr)
|
|
}
|
|
go func() {
|
|
pairAddr := <-n.udpConnPool
|
|
n.udpConnMap.Store(addr.String(), addionData{
|
|
lastHeartbeat: time.Now().Unix(),
|
|
Addr: addr,
|
|
})
|
|
n.udpPairMap.Store(addr.IP.String(), pairAddr.Addr)
|
|
n.udpPairMap.Store(pairAddr.Addr.String(), addr.IP)
|
|
starlog.Infof("pair udp conn %v <=====> %v\n", addr.String(), pairAddr.Addr.String())
|
|
if len(pairAddr.MsgFrom) > 0 {
|
|
n.listenUDP.WriteToUDP(pairAddr.MsgFrom, addr)
|
|
}
|
|
n.listenUDP.WriteToUDP(data, pairAddr.Addr)
|
|
}()
|
|
|
|
}
|
|
|
|
func (n *NatServer) pairNewClientConn(conn net.Conn) {
|
|
log := starlog.Std.NewFlag()
|
|
log.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())
|
|
select {
|
|
case <-time.After(time.Millisecond * time.Duration(n.NetTimeout)):
|
|
log.Errorln("pair new conn fail,wait timeout,conn is:", conn)
|
|
conn.Close()
|
|
return
|
|
case nconn := <-n.tcpConnPool:
|
|
log.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), nconn.RemoteAddr().String())
|
|
go func() {
|
|
defer nconn.Close()
|
|
defer conn.Close()
|
|
io.Copy(nconn, conn)
|
|
}()
|
|
go func() {
|
|
defer nconn.Close()
|
|
defer conn.Close()
|
|
io.Copy(conn, nconn)
|
|
}()
|
|
return
|
|
}
|
|
}
|
|
|
|
func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) {
|
|
go func() {
|
|
for {
|
|
select {
|
|
case data := <-msg:
|
|
_, err := conn.Write(data)
|
|
if err != nil {
|
|
conn.Close()
|
|
n.cmdTCPConn = nil
|
|
return
|
|
}
|
|
case <-time.After(time.Minute):
|
|
_, err := conn.Write(MSG_HEARTBEAT)
|
|
if err != nil {
|
|
conn.Close()
|
|
n.cmdTCPConn = nil
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
for {
|
|
header := make([]byte, 16)
|
|
_, err := io.ReadFull(conn, header)
|
|
if err != nil {
|
|
conn.Close()
|
|
n.cmdTCPConn = nil
|
|
return
|
|
}
|
|
if bytes.Equal(header, MSG_HEARTBEAT) {
|
|
n.lastTCPHeart = time.Now().Unix()
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) {
|
|
log := starlog.Std.NewFlag()
|
|
log.Noticef("start check tcp cmd conn %v\n", conn.RemoteAddr().String())
|
|
header := make([]byte, 16)
|
|
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
|
|
count, err := io.ReadFull(conn, header)
|
|
conn.SetReadDeadline(time.Time{})
|
|
if err == nil {
|
|
if bytes.Equal(header, MSG_CMD_HELLO) {
|
|
log.Infof("check tcp cmd conn success:%v\n", conn.RemoteAddr().String())
|
|
return conn, true
|
|
}
|
|
}
|
|
log.Infof("check tcp cmd conn fail:%v %v\n", conn.RemoteAddr().String(), err)
|
|
return NewCensorConn(header[:count], conn), false
|
|
}
|
|
|
|
func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) {
|
|
if n.cmdTCPConn == nil {
|
|
return conn, false
|
|
}
|
|
remoteIp := strings.Split(n.cmdTCPConn.RemoteAddr().String(), ":")[0]
|
|
newConnIp := strings.Split(conn.RemoteAddr().String(), ":")[0]
|
|
if remoteIp != newConnIp {
|
|
return conn, false
|
|
}
|
|
header := make([]byte, 16)
|
|
conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
|
|
read, err := io.ReadFull(conn, header)
|
|
conn.SetReadDeadline(time.Time{})
|
|
if err == nil {
|
|
if bytes.Equal(header, MSG_NEW_CONN_HELLO) {
|
|
return conn, true
|
|
}
|
|
}
|
|
return NewCensorConn(header[:read], conn), false
|
|
}
|
|
|
|
// censorConn is a net.Conn whose reads first replay bytes already consumed
// from the wrapped connection (for example while sniffing a protocol header)
// and then continue with the connection itself. All other methods delegate
// to the wrapped connection.
type censorConn struct {
	reader io.Reader // replayed header followed by conn
	conn   net.Conn  // the wrapped connection
}

// Compile-time check that censorConn satisfies net.Conn.
var _ net.Conn = censorConn{}

// NewCensorConn wraps conn so that Read returns header before any further
// data from conn.
func NewCensorConn(header []byte, conn net.Conn) censorConn {
	replay := bytes.NewReader(header)
	var wrapped censorConn
	wrapped.conn = conn
	wrapped.reader = io.MultiReader(replay, conn)
	return wrapped
}

// Read drains the replayed header first, then reads from the connection.
func (c censorConn) Read(p []byte) (int, error) {
	return c.reader.Read(p)
}

// Write writes directly to the wrapped connection.
func (c censorConn) Write(p []byte) (int, error) {
	return c.conn.Write(p)
}

// Close closes the wrapped connection.
func (c censorConn) Close() error {
	return c.conn.Close()
}

// LocalAddr returns the wrapped connection's local address.
func (c censorConn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}

// RemoteAddr returns the wrapped connection's remote address.
func (c censorConn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}

// SetDeadline sets the read and write deadlines on the wrapped connection.
func (c censorConn) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}

// SetReadDeadline sets the read deadline on the wrapped connection.
func (c censorConn) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline on the wrapped connection.
func (c censorConn) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}
|