You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
star/net/natserver.go

382 lines
10 KiB
Go

package net
import (
"b612.me/starlog"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
)
// Fixed 16-byte protocol headers exchanged between the NAT server and its
// client. The decode errors are deliberately discarded: every literal below is
// a constant, well-formed 32-digit hex string, so decoding cannot fail.
var (
	// MSG_CMD_HELLO is the 16-byte header a peer must send first to be accepted
	// as the control link. Run may replace it when NatServer.Passwd is set.
	MSG_CMD_HELLO, _ = hex.DecodeString("B6121127AF7ECDA11965122519670220")
	// MSG_CMD_HELLO_REPLY is the server's answer to an accepted MSG_CMD_HELLO.
	MSG_CMD_HELLO_REPLY, _ = hex.DecodeString("B6121127AF7ECDA22002200820112014")
	// MSG_NEW_CONN_HELLO is the 16-byte data-link header. The server sends it on
	// the control link to request a new data connection, and the client sends it
	// as the first bytes of each data connection it opens.
	MSG_NEW_CONN_HELLO, _ = hex.DecodeString("B6121127AF7ECDFF201820202022B612")
	// MSG_HEARTBEAT is the 16-byte keep-alive message sent on control links.
	MSG_HEARTBEAT, _ = hex.DecodeString("B612112704011008B612112704011008")
)
// NatServer is the publicly reachable side of a NAT relay. It listens on
// ListenAddr over TCP and/or UDP, accepts one control link from a client, and
// pairs outside connections with data connections the client opens on request.
type NatServer struct {
// Embedded lock. NOTE(review): no code in this file locks it; the fields
// below are read and written from several goroutines without it.
sync.RWMutex
// cmdTCPConn is the current TCP control link, or nil when none is registered.
cmdTCPConn net.Conn
// listenTcp is the TCP listener created by runTcpListen.
listenTcp net.Listener
// listenUDP is the UDP socket created by runUdpListen.
listenUDP *net.UDPConn
// udpConnMap holds an addionData value per known UDP peer, keyed by an
// address string (the code in this file uses both IP-only and ip:port keys).
udpConnMap sync.Map
// udpPairMap maps a UDP peer's key to the address of the peer it is paired with.
udpPairMap sync.Map
// udpCmdAddr is the address of the UDP control peer, or nil when none is registered.
udpCmdAddr *net.UDPAddr
// ListenAddr is the address passed to net.Listen("tcp", ...) and net.ResolveUDPAddr("udp", ...).
ListenAddr string
// lastTCPHeart is the Unix time (seconds) at which MSG_HEARTBEAT last arrived on the TCP control link.
lastTCPHeart int64
// lastUDPHeart is the Unix time (seconds) of the last UDP control hello or
// heartbeat; the UDP janitor clears udpCmdAddr when it is older than UDPTimeout.
lastUDPHeart int64
// Passwd, when non-empty, is used by Run when (re)computing MSG_CMD_HELLO.
Passwd string
// NetTimeout is, in milliseconds, how long an outside TCP connection waits
// for a client data connection. Run defaults it to 10000.
NetTimeout int64
// UDPTimeout is the idle limit, in seconds, for the UDP control peer and UDP
// map entries. runUdpListen defaults it to 120.
UDPTimeout int64
// running counts active listener goroutines. runTcpListen and runUdpListen
// update it atomically; Run refuses to start while it is non-zero.
running int32
// tcpConnPool queues client TCP data connections waiting to be paired (buffer 128).
tcpConnPool chan net.Conn
// udpConnPool queues client UDP data sockets waiting to be paired (buffer 128).
udpConnPool chan addionData
// stopCtx and stopFn are created by Run. Cancelling stopCtx makes the UDP
// janitor close listenUDP. stopFn is not called anywhere in this file.
stopCtx context.Context
stopFn context.CancelFunc
// enableTCP and enableUDP select which listeners Run starts. They are not
// set in this file — presumably configured elsewhere in the package.
enableTCP bool
enableUDP bool
}
func (n *NatServer) Run() error {
if n.running != 0 {
return fmt.Errorf("Server Already Run")
}
n.stopCtx, n.stopFn = context.WithCancel(context.Background())
if n.NetTimeout == 0 {
n.NetTimeout = 10000
}
if n.Passwd != "" {
MSG_CMD_HELLO = sha256.New().Sum(append(MSG_CMD_HELLO, []byte(n.Passwd)...))[:16]
}
var wg sync.WaitGroup
if n.enableUDP {
wg.Add(1)
go func() {
defer wg.Done()
n.runUdpListen()
}()
}
if n.enableTCP {
wg.Add(1)
go func() {
defer wg.Done()
n.runTcpListen()
}()
}
wg.Wait()
return nil
}
// runTcpListen accepts TCP connections on ListenAddr and classifies each one:
//
//   - while no control link is registered, a connection that opens with
//     MSG_CMD_HELLO becomes the control link and is answered with
//     MSG_CMD_HELLO_REPLY;
//   - a connection from the control link's host that opens with
//     MSG_NEW_CONN_HELLO is a client data connection and is queued in
//     tcpConnPool;
//   - anything else is an outside connection: the client is asked, via the
//     control link, for a data connection, and pairNewClientConn joins the two.
//
// It returns the listen error if the listener cannot be created, and nil once
// stopCtx is cancelled. It counts itself in running for its whole lifetime.
func (n *NatServer) runTcpListen() error {
	var err error
	n.tcpConnPool = make(chan net.Conn, 128)
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)
	starlog.Infoln("nat server tcp listener start run")
	n.listenTcp, err = net.Listen("tcp", n.ListenAddr)
	if err != nil {
		starlog.Errorln("nat server tcp listener start failed:", err)
		return err
	}
	ln := n.listenTcp

	// Closing the listener on shutdown is what unblocks Accept below.
	go func() {
		<-n.stopCtx.Done()
		ln.Close()
	}()

	// Requests for new data connections, drained by handleTcpControlConn.
	msgChan := make(chan []byte, 16)
	for {
		conn, err := ln.Accept()
		if err != nil {
			select {
			case <-n.stopCtx.Done():
				return nil
			default:
			}
			// Back off rather than spinning on a persistent error
			// (e.g. file-descriptor exhaustion).
			starlog.Errorln("nat server tcp accept failed:", err)
			time.Sleep(50 * time.Millisecond)
			continue
		}

		var ok bool
		if n.cmdTCPConn == nil {
			if conn, ok = n.checkIsTcpControlConn(conn); ok {
				if _, err := conn.Write(MSG_CMD_HELLO_REPLY); err != nil {
					starlog.Errorln("nat server tcp cmd hello reply failed:", err)
					conn.Close()
					continue
				}
				n.cmdTCPConn = conn
				go n.handleTcpControlConn(conn, msgChan)
				continue
			}
		}
		if conn, ok = n.checkIsTcpNewConn(conn); ok {
			starlog.Noticef("new tcp cmd conn is client conn %v\n", conn.RemoteAddr().String())
			n.tcpConnPool <- conn
			continue
		}

		starlog.Noticef("new tcp cmd conn is not client conn %v\n", conn.RemoteAddr().String())
		go func() {
			// Bound the request by the same deadline pairNewClientConn uses.
			// Without it, this goroutine blocks forever whenever no control
			// link is draining msgChan.
			wait := time.NewTimer(time.Millisecond * time.Duration(n.NetTimeout))
			defer wait.Stop()
			select {
			case msgChan <- MSG_NEW_CONN_HELLO:
			case <-wait.C:
			case <-n.stopCtx.Done():
			}
		}()
		go n.pairNewClientConn(conn)
	}
}
// runUdpListen binds the UDP socket on ListenAddr and dispatches every
// received datagram to handleUdpData until stopCtx is cancelled.
//
// A background janitor runs every 30 seconds. It:
//   - forgets the UDP control peer once its last hello/heartbeat is older
//     than UDPTimeout seconds;
//   - sends MSG_HEARTBEAT to the control peer while one is registered;
//   - drops udpConnMap entries (and their pairings) idle for more than
//     UDPTimeout seconds.
//
// On cancellation the janitor closes the socket, which ends the read loop.
// UDPTimeout defaults to 120 seconds. The function returns the resolve/listen
// error on startup failure and nil after shutdown.
func (n *NatServer) runUdpListen() error {
	var err error
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)
	starlog.Infoln("nat server udp listener start run")
	if n.UDPTimeout == 0 {
		n.UDPTimeout = 120
	}
	n.udpConnPool = make(chan addionData, 128)
	udpListenAddr, err := net.ResolveUDPAddr("udp", n.ListenAddr)
	if err != nil {
		starlog.Errorln("nat server udp listener start failed:", err)
		return err
	}
	n.listenUDP, err = net.ListenUDP("udp", udpListenAddr)
	if err != nil {
		starlog.Errorln("nat server udp listener start failed:", err)
		return err
	}
	listener := n.listenUDP

	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-n.stopCtx.Done():
				// Return after closing: Done stays readable forever, so
				// looping again would spin.
				listener.Close()
				return
			case <-ticker.C:
				if time.Now().Unix()-n.lastUDPHeart > n.UDPTimeout {
					n.udpCmdAddr = nil
				}
				if cmd := n.udpCmdAddr; cmd != nil {
					listener.WriteToUDP(MSG_HEARTBEAT, cmd)
				}
				n.udpConnMap.Range(func(key, value interface{}) bool {
					if time.Now().Unix()-value.(addionData).lastHeartbeat <= n.UDPTimeout {
						return true
					}
					// Drop the partner's entries as well. The pair value is an
					// address, not a key, so delete by its string form; using the
					// raw value as a key never matched and could panic on an
					// unhashable value.
					if target, ok := n.udpPairMap.Load(key); ok {
						if rmt, ok := target.(*net.UDPAddr); ok && rmt != nil {
							n.udpConnMap.Delete(rmt.String())
							n.udpPairMap.Delete(rmt.String())
						}
					}
					n.udpConnMap.Delete(key)
					n.udpPairMap.Delete(key)
					return true
				})
			}
		}
	}()

	for {
		// A fresh buffer per datagram is required: handleUdpData may hand the
		// slice to a goroutine or queue it in udpConnPool.
		data := make([]byte, 8192)
		c, udpAddr, err := listener.ReadFromUDP(data)
		if err != nil {
			select {
			case <-n.stopCtx.Done():
				return nil
			default:
			}
			// Other UDP read errors (e.g. ICMP port-unreachable reports) are
			// transient; keep serving.
			continue
		}
		n.handleUdpData(udpAddr, data[:c])
	}
}
// addionData describes one UDP socket known to the server. It is used both as
// the value stored in NatServer.udpConnMap and as the item queued on
// NatServer.udpConnPool.
type addionData struct {
// lastHeartbeat is the Unix time (seconds) the entry was last refreshed; the
// UDP janitor drops entries older than NatServer.UDPTimeout.
lastHeartbeat int64
// Addr is the socket's remote address.
Addr *net.UDPAddr
// MsgFrom, when non-empty, is written to the outside peer right after it is
// paired with this socket. It is filled from the MSG_NEW_CONN_HELLO datagram.
// NOTE(review): verify that the slicing in handleUdpData actually preserves
// the payload.
MsgFrom []byte
}
// handleUdpData dispatches one datagram received by the UDP listener.
//
// The first matching rule handles the datagram:
//  1. From the registered control address and at least 16 bytes long:
//     MSG_HEARTBEAT refreshes lastUDPHeart; anything else is ignored.
//  2. MSG_CMD_HELLO while no control address is registered: the sender
//     becomes the control address and is answered with MSG_CMD_HELLO_REPLY.
//  3. From an address that already has a partner: the datagram is forwarded
//     to the partner. If the partner has expired, the pair is torn down.
//  4. MSG_NEW_CONN_HELLO: the client offers a data socket. Bytes after the
//     16-byte header are kept as the first payload for the outside peer.
//  5. Anything else comes from an outside peer. The client is asked, via the
//     control address, for a data socket, and the two are paired
//     asynchronously (giving up after NetTimeout milliseconds).
//
// udpConnMap and udpPairMap are keyed by the full "ip:port" string, and
// udpPairMap values are always *net.UDPAddr. Keying by IP alone let the
// client's data sockets (which share one IP) overwrite each other, and mixing
// both key forms meant established pairs were never found again.
func (n *NatServer) handleUdpData(addr *net.UDPAddr, data []byte) {
	starlog.Infoln("handle udp data from:", addr.String())
	key := addr.String()
	hasHeader := func(header []byte) bool {
		return len(data) >= 16 && bytes.Equal(data[:16], header)
	}

	// Rule 1. (*net.UDPAddr)(nil).String() is "<nil>", so a nil
	// udpCmdAddr never matches here.
	if key == n.udpCmdAddr.String() && len(data) >= 16 {
		if hasHeader(MSG_HEARTBEAT) {
			starlog.Infoln("recv udp cmd heartbeat")
			n.lastUDPHeart = time.Now().Unix()
		}
		return
	}

	// Rule 2.
	if n.udpCmdAddr == nil && hasHeader(MSG_CMD_HELLO) {
		starlog.Infof("recv udp cmd hello from %v\n", addr.String())
		n.udpCmdAddr = addr
		n.lastUDPHeart = time.Now().Unix()
		n.listenUDP.WriteToUDP(MSG_CMD_HELLO_REPLY, addr)
		return
	}

	// Rule 3.
	if v, ok := n.udpConnMap.Load(key); ok {
		if target, ok := n.udpPairMap.Load(key); ok {
			rmt, _ := target.(*net.UDPAddr) // nil for an unexpected value; its String() is "<nil>"
			starlog.Infof("found udp pair data %v <=====> %v\n", addr.String(), rmt.String())
			if _, alive := n.udpConnMap.Load(rmt.String()); !alive {
				n.udpConnMap.Delete(key)
				n.udpPairMap.Delete(key)
				n.udpPairMap.Delete(rmt.String())
				starlog.Errorf("udp pair data %v <=====> %v fail,remote not found\n", addr.String(), rmt.String())
				return
			}
			current := v.(addionData)
			current.lastHeartbeat = time.Now().Unix()
			n.udpConnMap.Store(key, current)
			n.listenUDP.WriteToUDP(data, rmt)
			return
		}
	}

	// Rule 4.
	if hasHeader(MSG_NEW_CONN_HELLO) {
		starlog.Infof("recv new udp conn hello from %v\n", addr.String())
		now := time.Now().Unix()
		n.udpConnMap.Store(key, addionData{
			lastHeartbeat: now,
			Addr:          addr,
		})
		n.udpConnPool <- addionData{
			lastHeartbeat: now,
			Addr:          addr,
			MsgFrom:       data[16:], // payload following the header, possibly empty
		}
		return
	}

	// Rule 5.
	starlog.Infof("wait pair udp conn %v\n", addr.String())
	cmdAddr := n.udpCmdAddr
	if cmdAddr == nil {
		starlog.Infof("wait pair udp conn %v fail,cmd addr is nil\n", addr.String())
		return
	}
	n.listenUDP.WriteToUDP(MSG_NEW_CONN_HELLO, cmdAddr)
	go func() {
		// Bound the wait so a client that never answers does not leak this goroutine.
		wait := time.NewTimer(time.Millisecond * time.Duration(n.NetTimeout))
		defer wait.Stop()
		var pairAddr addionData
		select {
		case pairAddr = <-n.udpConnPool:
		case <-wait.C:
			starlog.Errorf("pair udp conn %v fail,wait timeout\n", addr.String())
			return
		case <-n.stopCtx.Done():
			return
		}
		peerKey := pairAddr.Addr.String()
		now := time.Now().Unix()
		n.udpConnMap.Store(key, addionData{lastHeartbeat: now, Addr: addr})
		// Refresh the data socket too: it may have waited in the pool long
		// enough for the janitor to expire its entry.
		n.udpConnMap.Store(peerKey, addionData{lastHeartbeat: now, Addr: pairAddr.Addr})
		n.udpPairMap.Store(key, pairAddr.Addr)
		n.udpPairMap.Store(peerKey, addr)
		starlog.Infof("pair udp conn %v <=====> %v\n", addr.String(), pairAddr.Addr.String())
		if len(pairAddr.MsgFrom) > 0 {
			n.listenUDP.WriteToUDP(pairAddr.MsgFrom, addr)
		}
		n.listenUDP.WriteToUDP(data, pairAddr.Addr)
	}()
}
// pairNewClientConn waits up to NetTimeout milliseconds for a client data
// connection to show up in tcpConnPool. If one arrives, it relays bytes
// between that connection and conn in both directions on two new goroutines.
// Otherwise it closes conn.
func (n *NatServer) pairNewClientConn(conn net.Conn) {
	logger := starlog.Std.NewFlag()
	logger.Noticef("start pair tcp cmd conn %v\n", conn.RemoteAddr().String())

	deadline := time.NewTimer(time.Duration(n.NetTimeout) * time.Millisecond)
	defer deadline.Stop()

	var peer net.Conn
	select {
	case <-deadline.C:
		logger.Errorln("pair new conn fail,wait timeout,conn is:", conn)
		conn.Close()
		return
	case peer = <-n.tcpConnPool:
	}

	logger.Infof("pair %v <======> %v ok\n", conn.RemoteAddr().String(), peer.RemoteAddr().String())
	// Whichever direction finishes first closes both ends, which also
	// unblocks the copy running in the opposite direction.
	relay := func(dst, src net.Conn) {
		defer peer.Close()
		defer conn.Close()
		io.Copy(dst, src)
	}
	go relay(peer, conn)
	go relay(conn, peer)
}
// handleTcpControlConn services the TCP control link conn until it fails.
//
// A writer goroutine forwards every message queued on msg to the client. If
// no message arrives for a minute, it sends MSG_HEARTBEAT instead. The calling
// goroutine reads 16-byte headers from the client and records each
// MSG_HEARTBEAT in lastTCPHeart; other headers are ignored.
//
// On failure the connection is closed. cmdTCPConn is cleared only if it still
// refers to conn. The writer is stopped when the reader exits. Previously it
// lived on, could swallow requests queued for a newer control link, and could
// reset cmdTCPConn to nil after that newer link had been registered.
func (n *NatServer) handleTcpControlConn(conn net.Conn, msg chan []byte) {
	// release closes conn and unregisters it if it is still the current link.
	release := func() {
		conn.Close()
		if n.cmdTCPConn == conn {
			n.cmdTCPConn = nil
		}
	}

	done := make(chan struct{})
	go func() {
		for {
			var data []byte
			select {
			case <-done:
				return
			case data = <-msg:
			case <-time.After(time.Minute):
				data = MSG_HEARTBEAT
			}
			if _, err := conn.Write(data); err != nil {
				release()
				return
			}
		}
	}()

	header := make([]byte, 16)
	for {
		if _, err := io.ReadFull(conn, header); err != nil {
			// Stop the writer first so it does not pick up another request
			// for a connection that is about to be closed.
			close(done)
			release()
			return
		}
		if bytes.Equal(header, MSG_HEARTBEAT) {
			n.lastTCPHeart = time.Now().Unix()
		}
	}
}
// checkIsTcpControlConn reads up to 16 bytes from conn, waiting at most 1.2
// seconds, and reports whether they equal MSG_CMD_HELLO.
//
// On a match it returns conn itself, with the header consumed. Otherwise it
// returns conn wrapped by NewCensorConn, which replays whatever bytes were
// read, so the stream can still be relayed intact.
func (n *NatServer) checkIsTcpControlConn(conn net.Conn) (net.Conn, bool) {
	logger := starlog.Std.NewFlag()
	peer := conn.RemoteAddr().String()
	logger.Noticef("start check tcp cmd conn %v\n", peer)

	buf := make([]byte, 16)
	conn.SetReadDeadline(time.Now().Add(1200 * time.Millisecond))
	got, readErr := io.ReadFull(conn, buf)
	conn.SetReadDeadline(time.Time{}) // clear the deadline for the caller

	matched := readErr == nil && bytes.Equal(buf, MSG_CMD_HELLO)
	if !matched {
		logger.Infof("check tcp cmd conn fail:%v %v\n", peer, readErr)
		return NewCensorConn(buf[:got], conn), false
	}
	logger.Infof("check tcp cmd conn success:%v\n", peer)
	return conn, true
}
// checkIsTcpNewConn reports whether conn is a data connection opened by the
// client in answer to MSG_NEW_CONN_HELLO.
//
// A data connection must come from the same host as the current control link
// and begin with the 16-byte MSG_NEW_CONN_HELLO header. The header is read
// with a 1.2 second deadline and consumed on a match. Any other connection is
// returned wrapped by NewCensorConn, which replays the bytes already read.
//
// Hosts are extracted with net.SplitHostPort. The previous
// strings.Split(addr, ":")[0] reduced every IPv6 address ("[::1]:80") to "[",
// so every IPv6 peer appeared to share the client's host.
func (n *NatServer) checkIsTcpNewConn(conn net.Conn) (net.Conn, bool) {
	// Snapshot: another goroutine may clear cmdTCPConn at any moment, and
	// calling a method on the nil interface would panic.
	ctrl := n.cmdTCPConn
	if ctrl == nil {
		return conn, false
	}
	host := func(a net.Addr) string {
		h, _, err := net.SplitHostPort(a.String())
		if err != nil {
			return a.String()
		}
		return h
	}
	// Textual IPv6 addresses are case-insensitive.
	if !strings.EqualFold(host(ctrl.RemoteAddr()), host(conn.RemoteAddr())) {
		return conn, false
	}

	header := make([]byte, 16)
	conn.SetReadDeadline(time.Now().Add(time.Millisecond * 1200))
	read, err := io.ReadFull(conn, header)
	conn.SetReadDeadline(time.Time{})
	if err == nil && bytes.Equal(header, MSG_NEW_CONN_HELLO) {
		return conn, true
	}
	return NewCensorConn(header[:read], conn), false
}
// censorConn is a net.Conn wrapper whose Read first replays bytes that were
// already consumed from the underlying connection (for example while sniffing
// a protocol header), then continues with the live stream. Every other method
// delegates to the wrapped connection unchanged.
type censorConn struct {
	reader io.Reader // replayed header bytes followed by conn
	conn   net.Conn  // underlying connection for writes, close, addresses and deadlines
}

// Compile-time check: censorConn is handed out wherever a net.Conn is expected.
var _ net.Conn = censorConn{}

// NewCensorConn wraps conn so that reads return header first and then data
// from conn. header is read in place (bytes.NewReader does not copy it), so
// the caller must not modify it afterwards.
func NewCensorConn(header []byte, conn net.Conn) censorConn {
	return censorConn{
		reader: io.MultiReader(bytes.NewReader(header), conn),
		conn:   conn,
	}
}

// Read returns the replayed header bytes first, then data from the connection.
func (c censorConn) Read(p []byte) (int, error) { return c.reader.Read(p) }

// Write writes directly to the underlying connection.
func (c censorConn) Write(p []byte) (int, error) { return c.conn.Write(p) }

// Close closes the underlying connection.
func (c censorConn) Close() error { return c.conn.Close() }

// LocalAddr returns the underlying connection's local address.
func (c censorConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }

// RemoteAddr returns the underlying connection's remote address.
func (c censorConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }

// SetDeadline sets both deadlines on the underlying connection.
func (c censorConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }

// SetReadDeadline sets the read deadline on the underlying connection.
// It does not affect the replay of already-buffered header bytes.
func (c censorConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }

// SetWriteDeadline sets the write deadline on the underlying connection.
func (c censorConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }