You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
star/netforward/forward.go

358 lines
8.3 KiB
Go

package netforward
import (
"b612.me/stario"
"b612.me/starlog"
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// NetForward forwards TCP connections and/or UDP datagrams received on
// LocalAddr:LocalPort to RemoteURI. The exported fields are configuration that
// is read by Run and the listener loops; the stdin command loop started by Run
// can change RemoteURI, the delay settings and the two timeouts at runtime.
type NetForward struct {
// LocalAddr and LocalPort are formatted as "LocalAddr:LocalPort" to build the
// listen address used by both the TCP and the UDP listener.
LocalAddr string
LocalPort int
// RemoteURI is the address dialed for every forwarded connection or UDP
// session; it is handed unchanged to net.DialTimeout / net.Dial.
RemoteURI string
// EnableTCP and EnableUDP select which listener loops Run starts.
EnableTCP bool
EnableUDP bool
// DelayMilSec is an artificial delay in milliseconds; values <= 0 disable it.
DelayMilSec int
// DelayToward selects which direction of a TCP stream is delayed (see copy):
// 0 = both directions, 1 = client-to-remote only, 2 = remote-to-client only.
DelayToward int
// StdinMode makes Run start a goroutine that reads "set <option> <value>"
// commands from stdin.
StdinMode bool
// IgnoreEof stops copy from tearing the connection pair down when a Read
// returns io.EOF.
IgnoreEof bool
// DialTimeout bounds dialing RemoteURI over TCP. Run sets it to five seconds
// when it is zero.
DialTimeout time.Duration
// UDPTimeout is the idle period after which the UDP session sweeper drops a
// session; it is compared at whole-second granularity.
UDPTimeout time.Duration
// stopCtx and stopFn are created by Run. Close calls stopFn; the listener
// loops watch stopCtx to know when to shut down.
stopCtx context.Context
stopFn context.CancelFunc
// running counts the listener loops that are currently active; runTCP and
// runUDP update it with sync/atomic and Status reports it.
running int32
// The fields below are passed to SetTcpInfo for both the accepted and the
// dialed TCP connection. SetTcpInfo is not defined in this file, so their
// units and exact effect are not visible here.
// NOTE(review): presumably TCP keep-alive and user-timeout socket options —
// confirm against SetTcpInfo.
KeepAlivePeriod int
// KeepAliveIdel (sic — presumably "idle"); the name is part of the public API.
KeepAliveIdel int
KeepAliveCount int
UserTimeout int
UsingKeepAlive bool
}
// Close asks every goroutine started by Run to stop by cancelling the stop
// context. It is a no-op when Run has not been called yet (stopFn is still
// nil), and it may safely be called more than once because a context
// CancelFunc does nothing after its first call.
func (n *NetForward) Close() {
	if n.stopFn == nil {
		return
	}
	n.stopFn()
}
// Status reports the current value of the running counter, i.e. how many
// listener loops (TCP and/or UDP) are active right now. The counter is read
// atomically, so Status may be called from any goroutine.
func (n *NetForward) Status() int32 {
	active := atomic.LoadInt32(&n.running)
	return active
}
// Run starts the forwarder in the background and returns immediately.
//
// It returns an "already running" error while a listener loop from an earlier
// Run is still active. Otherwise it creates a fresh stop context (cancelled by
// Close), defaults DialTimeout to five seconds when it is zero, and launches —
// each on its own goroutine — the stdin command loop (StdinMode), the TCP
// listener (EnableTCP) and the UDP listener (EnableUDP).
//
// Errors hit by the listener loops are not returned from Run; use Status to
// see how many loops are running.
func (n *NetForward) Run() error {
	// running is written with sync/atomic by the listener loops, so it has to
	// be read the same way.
	if atomic.LoadInt32(&n.running) > 0 {
		starlog.Errorln("already running")
		return errors.New("already running")
	}
	n.stopCtx, n.stopFn = context.WithCancel(context.Background())
	if n.DialTimeout == 0 {
		n.DialTimeout = time.Second * 5
	}
	if n.StdinMode {
		go n.stdinLoop()
	}
	if n.EnableTCP {
		go n.runTCP()
	}
	if n.EnableUDP {
		go n.runUDP()
	}
	return nil
}

// stdinLoop reads runtime commands from stdin and applies them to n until
// "set stdin off" is received. Every command has the shape
// "set <option> <value>"; words may be separated by any amount of whitespace.
//
// Supported options:
//
//	remote       new RemoteURI
//	delaytoward  new DelayToward (integer)
//	delay        new DelayMilSec (integer, milliseconds)
//	dialtimeout  new DialTimeout (integer, milliseconds)
//	udptimeout   new UDPTimeout  (integer, milliseconds)
//	stdin        "off" turns StdinMode off and ends this loop
func (n *NetForward) stdinLoop() {
	// parseInt converts raw to an int, logging errMsg together with the
	// offending text when it is not a number.
	parseInt := func(raw, errMsg string) (int, bool) {
		v, err := strconv.Atoi(raw)
		if err != nil {
			starlog.Errorln(errMsg, raw)
			return 0, false
		}
		return v, true
	}
	for {
		// strings.Fields trims the line and splits it on runs of whitespace in
		// one step, so repeated separators can never stall the loop.
		fields := strings.Fields(stario.MessageBox("", "").MustString())
		cmd := strings.Join(fields, " ")
		starlog.Debugf("Recv Command %s\n", cmd)
		if len(fields) < 3 {
			starlog.Errorln("Invalid Command", cmd)
			continue
		}
		arg := fields[2]
		switch fields[0] + fields[1] {
		case "setremote":
			n.RemoteURI = arg
			starlog.Noticef("Remote URI Set to %s\n", n.RemoteURI)
		case "setdelaytoward":
			if v, ok := parseInt(arg, "Invalid Delay Toward Value"); ok {
				n.DelayToward = v
				starlog.Noticef("Delay Toward Set to %d\n", n.DelayToward)
			}
		case "setdelay":
			if v, ok := parseInt(arg, "Invalid Delay Value"); ok {
				n.DelayMilSec = v
				starlog.Noticef("Delay Set to %d\n", n.DelayMilSec)
			}
		case "setdialtimeout":
			if v, ok := parseInt(arg, "Invalid Dial Timeout Value"); ok {
				n.DialTimeout = time.Millisecond * time.Duration(v)
				starlog.Noticef("Dial Timeout Set to %d\n", n.DialTimeout)
			}
		case "setudptimeout":
			if v, ok := parseInt(arg, "Invalid UDP Timeout Value"); ok {
				n.UDPTimeout = time.Millisecond * time.Duration(v)
				starlog.Noticef("UDP Timeout Set to %d\n", n.UDPTimeout)
			}
		case "setstdin":
			if arg == "off" {
				n.StdinMode = false
				starlog.Noticef("Stdin Mode Off\n")
				return
			}
		default:
			// Report unknown commands instead of dropping them silently.
			starlog.Errorln("Invalid Command", cmd)
		}
	}
}
// runTCP listens on LocalAddr:LocalPort and, for every accepted connection,
// dials RemoteURI and relays bytes in both directions with copy. It counts
// itself in n.running while it is active and returns nil once the stop context
// is cancelled, or the listen error if the listener cannot be created.
//
// Run starts it with "go", so the returned error is not observed by anyone;
// that is why the listen failure is also logged here.
func (n *NetForward) runTCP() error {
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)

	// Capture the stop context once so that this loop keeps watching the
	// context it was started with even if Run is called again later.
	ctx := n.stopCtx
	localAddr := fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)
	listen, err := net.Listen("tcp", localAddr)
	if err != nil {
		starlog.Errorln("Listening On Tcp Failed:", err)
		return err
	}
	// Closing the listener is what unblocks Accept on shutdown.
	go func() {
		<-ctx.Done()
		listen.Close()
	}()
	starlog.Infof("Listening TCP on %v\n", localAddr)
	for {
		select {
		case <-ctx.Done():
			return nil
		default:
		}
		conn, err := listen.Accept()
		if err != nil {
			// Accept also fails once the listener has been closed; the check
			// at the top of the loop then ends it.
			continue
		}
		log := starlog.Std.NewFlag()
		log.Infof("Accept New TCP Conn from %v\n", conn.RemoteAddr().String())
		// NOTE: this sleep runs on the accept loop itself, so it also postpones
		// accepting further connections.
		if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) {
			log.Infof("Delay %d ms\n", n.DelayMilSec)
			time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
		}
		err = SetTcpInfo(conn.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout)
		if err != nil {
			log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err)
			conn.Close()
			continue
		}
		go func(conn net.Conn) {
			// The accepted connection is released on every exit path,
			// including a SetTcpInfo failure on the remote socket.
			defer conn.Close()
			rmt, err := net.DialTimeout("tcp", n.RemoteURI, n.DialTimeout)
			if err != nil {
				log.Errorf("TCP:Dial Remote %s Failed:%v\n", n.RemoteURI, err)
				return
			}
			defer rmt.Close()
			err = SetTcpInfo(rmt.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout)
			if err != nil {
				// This failure concerns the dialed socket, so report its address.
				log.Errorf("SetTcpInfo error for %s: %v\n", rmt.RemoteAddr().String(), err)
				return
			}
			log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String())
			n.copy(rmt, conn)
			log.Noticef("TCP Connection Closed %s <==> %s\n", conn.RemoteAddr().String(), n.RemoteURI)
		}(conn)
	}
}
// UDPConn is one forwarded UDP session: the socket dialed to the remote
// (embedded Conn) together with the shared listening socket and the client
// address that replies have to be sent back to.
type UDPConn struct {
	net.Conn
	listen     *net.UDPConn
	remoteAddr *net.UDPAddr
	// lastbeat points at the Unix time (seconds) of the session's last
	// activity, or 0 once the session is dead. It is a pointer, accessed with
	// sync/atomic, because UDPConn values are copied (map entries, value
	// receivers) and every copy must observe the same heartbeat.
	lastbeat *int64
}

// markBeat stores ts as the session's last-activity time. It does nothing on
// a UDPConn that was created without a heartbeat cell.
func (u UDPConn) markBeat(ts int64) {
	if u.lastbeat != nil {
		atomic.StoreInt64(u.lastbeat, ts)
	}
}

// beatTime returns the last-activity time recorded by markBeat, or 0 when the
// session is dead or has no heartbeat cell.
func (u UDPConn) beatTime() int64 {
	if u.lastbeat == nil {
		return 0
	}
	return atomic.LoadInt64(u.lastbeat)
}

// Write refreshes the heartbeat and sends p to the remote.
func (u UDPConn) Write(p []byte) (n int, err error) {
	u.markBeat(time.Now().Unix())
	return u.Conn.Write(p)
}

// Read refreshes the heartbeat and reads the next datagram from the remote.
func (u UDPConn) Read(p []byte) (n int, err error) {
	u.markBeat(time.Now().Unix())
	return u.Conn.Read(p)
}

// Work relays datagrams from the remote back to the client until a read or a
// write fails, sleeping delay milliseconds before each read when delay > 0.
// On exit it closes the remote socket and sets the heartbeat to 0 so that the
// session sweeper removes the session.
func (u UDPConn) Work(delay int) {
	defer func() {
		u.Close()
		u.markBeat(0)
	}()
	buf := make([]byte, 8192)
	for {
		if delay > 0 {
			time.Sleep(time.Millisecond * time.Duration(delay))
		}
		count, err := u.Read(buf)
		if err != nil {
			return
		}
		// The listening socket is unconnected, so the destination has to be
		// given explicitly; a plain Write has no peer to send to.
		if _, err = u.listen.WriteToUDP(buf[:count], u.remoteAddr); err != nil {
			return
		}
	}
}

// runUDP listens on LocalAddr:LocalPort and forwards every datagram to
// RemoteURI, keeping one UDPConn session per client address so that replies
// can be routed back. Sessions idle for longer than UDPTimeout are closed by a
// sweeper that runs every 60 seconds. It counts itself in n.running while it
// is active and returns nil once the stop context is cancelled, or the error
// that prevented the listener from being created.
//
// Run starts it with "go", so the returned error is not observed by anyone;
// that is why setup failures are also logged here.
func (n *NetForward) runUDP() error {
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)

	// Capture the stop context once so that this loop keeps watching the
	// context it was started with even if Run is called again later.
	ctx := n.stopCtx
	localAddr := fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)
	udpAddr, err := net.ResolveUDPAddr("udp", localAddr)
	if err != nil {
		starlog.Errorln("Resolving UDP Address Failed:", err)
		return err
	}
	listen, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		starlog.Errorln("Listening On Udp Failed:", err)
		return err
	}
	starlog.Infof("Listening UDP on %v\n", localAddr)
	// Closing the listener is what unblocks ReadFromUDP on shutdown.
	go func() {
		<-ctx.Done()
		listen.Close()
	}()

	var mu sync.Mutex // guards sessions
	sessions := make(map[string]UDPConn)
	// Release every remaining remote socket (and thereby its Work goroutine)
	// when the listener shuts down.
	defer func() {
		mu.Lock()
		for key, sess := range sessions {
			sess.Close()
			delete(sessions, key)
		}
		mu.Unlock()
	}()

	// Sweeper: periodically close and forget sessions that are idle or dead.
	go func() {
		ticker := time.NewTicker(time.Second * 60)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				now := time.Now().Unix()
				limit := int64(n.UDPTimeout.Seconds())
				mu.Lock()
				for key, sess := range sessions {
					if now > limit+sess.beatTime() {
						sess.Close()
						delete(sessions, key)
						starlog.Noticef("UDP Connection Closed %s <==> %s\n", sess.remoteAddr.String(), n.RemoteURI)
					}
				}
				mu.Unlock()
			}
		}
	}()

	buf := make([]byte, 8192)
	for {
		select {
		case <-ctx.Done():
			return nil
		default:
		}
		count, rmt, err := listen.ReadFromUDP(buf)
		if err != nil || rmt.String() == n.RemoteURI {
			continue
		}
		// buf is reused by the next ReadFromUDP, so the goroutine must get its
		// own copy of the payload.
		data := make([]byte, count)
		copy(data, buf[:count])
		go func(data []byte, rmt *net.UDPAddr) {
			log := starlog.Std.NewFlag()
			key := rmt.String()
			mu.Lock()
			if ctx.Err() != nil {
				// Shutting down: do not open new sessions after cleanup.
				mu.Unlock()
				return
			}
			sess, ok := sessions[key]
			if !ok {
				log.Infof("Accept New UDP Conn from %v\n", key)
				conn, err := net.Dial("udp", n.RemoteURI)
				if err != nil {
					log.Errorf("UDP:Dial Remote %s Failed:%v\n", n.RemoteURI, err)
					mu.Unlock()
					return
				}
				sess = UDPConn{
					Conn:       conn,
					remoteAddr: rmt,
					listen:     listen,
					lastbeat:   new(int64),
				}
				sess.markBeat(time.Now().Unix())
				sessions[key] = sess
				// Replies travel remote-to-client, which DelayToward 0 and 2
				// select (same convention as the TCP path in copy).
				replyDelay := 0
				if n.DelayToward == 0 || n.DelayToward == 2 {
					replyDelay = n.DelayMilSec
				}
				go sess.Work(replyDelay)
				log.Infof("UDP Connect %s <==> %s\n", key, n.RemoteURI)
			}
			mu.Unlock()
			// Client-to-remote delay: DelayToward 0 and 1, as for TCP.
			if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) {
				time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
			}
			if _, err := sess.Write(data); err != nil {
				mu.Lock()
				sess.Close()
				// Only forget the entry if it is still this session; a newer
				// one may already have replaced it under the same key.
				if cur, found := sessions[key]; found && cur.lastbeat == sess.lastbeat {
					delete(sessions, key)
				}
				mu.Unlock()
				log.Noticef("UDP Connection Closed %s <==> %s\n", key, n.RemoteURI)
			}
		}(data, rmt)
	}
}
// copy relays bytes between src (the accepted client connection) and dst (the
// dialed remote connection) in both directions and returns once both
// directions have finished.
//
// A direction ends when its Read or Write fails; it then closes both
// connections, which also ends the opposite direction. The one exception is
// io.EOF while IgnoreEof is set: that direction simply stops and leaves both
// connections open for the opposite direction.
func (n *NetForward) copy(dst, src net.Conn) {
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		n.copyOneWay(dst, src, 1) // client -> remote
	}()
	go func() {
		defer wg.Done()
		n.copyOneWay(src, dst, 2) // remote -> client
	}()
	wg.Wait()
}

// copyOneWay pumps bytes from from to to until an error occurs. toward names
// this direction in DelayToward terms (1 = client-to-remote, 2 =
// remote-to-client): after each forwarded chunk it sleeps DelayMilSec
// milliseconds when DelayToward is 0 (both directions) or equals toward.
func (n *NetForward) copyOneWay(to, from net.Conn, toward int) {
	closeBoth := func() {
		from.Close()
		to.Close()
	}
	buf := make([]byte, 32*1024)
	for {
		count, err := from.Read(buf)
		// Per the io.Reader contract, bytes returned alongside an error must
		// still be delivered before the error is handled.
		if count > 0 {
			if _, werr := to.Write(buf[:count]); werr != nil {
				closeBoth()
				return
			}
			if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == toward) {
				time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
			}
		}
		if err != nil {
			if n.IgnoreEof && errors.Is(err, io.EOF) {
				// EOF is permanent, so retrying Read would only spin. Stop this
				// direction and keep the pair open for the other one.
				return
			}
			closeBoth()
			return
		}
	}
}