package netforward import ( "b612.me/stario" "b612.me/starlog" "context" "errors" "fmt" "io" "net" "strconv" "strings" "sync" "sync/atomic" "time" ) type NetForward struct { LocalAddr string LocalPort int RemoteURI string EnableTCP bool EnableUDP bool DelayMilSec int DelayToward int StdinMode bool IgnoreEof bool DialTimeout time.Duration UDPTimeout time.Duration stopCtx context.Context stopFn context.CancelFunc running int32 KeepAlivePeriod int KeepAliveIdel int KeepAliveCount int UserTimeout int UsingKeepAlive bool } func (n *NetForward) Close() { n.stopFn() } func (n *NetForward) Status() int32 { return atomic.LoadInt32(&n.running) } func (n *NetForward) Run() error { if n.running > 0 { starlog.Errorln("already running") return errors.New("already running") } n.stopCtx, n.stopFn = context.WithCancel(context.Background()) if n.DialTimeout == 0 { n.DialTimeout = time.Second * 5 } if n.StdinMode { go func() { for { cmd := strings.TrimSpace(stario.MessageBox("", "").MustString()) for strings.Contains(cmd, " ") { cmd = strings.Replace(cmd, " ", " ", -1) } starlog.Debugf("Recv Command %s\n", cmd) cmds := strings.Split(cmd, " ") if len(cmds) < 3 { starlog.Errorln("Invalid Command", cmd) continue } switch cmds[0] + cmds[1] { case "setremote": n.RemoteURI = cmds[2] starlog.Noticef("Remote URI Set to %s\n", n.RemoteURI) case "setdelaytoward": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Delay Toward Value", cmds[2]) continue } n.DelayToward = tmp starlog.Noticef("Delay Toward Set to %d\n", n.DelayToward) case "setdelay": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Delay Value", cmds[2]) continue } n.DelayMilSec = tmp starlog.Noticef("Delay Set to %d\n", n.DelayMilSec) case "setdialtimeout": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Dial Timeout Value", cmds[2]) continue } n.DialTimeout = time.Millisecond * time.Duration(tmp) starlog.Noticef("Dial Timeout Set to %d\n", n.DialTimeout) case "setudptimeout": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid UDP Timeout Value", cmds[2]) continue } n.UDPTimeout = time.Millisecond * time.Duration(tmp) starlog.Noticef("UDP Timeout Set to %d\n", n.UDPTimeout) case "setstdin": if cmds[2] == "off" { n.StdinMode = false starlog.Noticef("Stdin Mode Off\n") return } } } }() } if n.EnableTCP { go n.runTCP() } if n.EnableUDP { go n.runUDP() } return nil } func (n *NetForward) runTCP() error { atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) listen, err := net.Listen("tcp", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) if err != nil { starlog.Errorln("Listening On Tcp Failed:", err) return err } go func() { <-n.stopCtx.Done() listen.Close() }() starlog.Infof("Listening TCP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) for { select { case <-n.stopCtx.Done(): return nil default: } conn, err := listen.Accept() if err != nil { continue } log := starlog.Std.NewFlag() log.Infof("Accept New TCP Conn from %v\n", conn.RemoteAddr().String()) if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) { log.Infof("Delay %d ms\n", n.DelayMilSec) time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } err = SetTcpInfo(conn.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout) if err != nil { log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err) conn.Close() continue } go func(conn net.Conn) { rmt, err := net.DialTimeout("tcp", n.RemoteURI, n.DialTimeout) if err != nil { log.Errorf("TCP:Dial Remote %s Failed:%v\n", n.RemoteURI, err) conn.Close() return } err = SetTcpInfo(rmt.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout) if err != nil { log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err) rmt.Close() return } log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String()) n.copy(rmt, conn) log.Noticef("TCP Connection Closed %s <==> %s\n", conn.RemoteAddr().String(), n.RemoteURI) conn.Close() rmt.Close() }(conn) } } type UDPConn struct { net.Conn listen *net.UDPConn remoteAddr *net.UDPAddr lastbeat int64 } func (u UDPConn) Write(p []byte) (n int, err error) { u.lastbeat = time.Now().Unix() return u.Conn.Write(p) } func (u UDPConn) Read(p []byte) (n int, err error) { u.lastbeat = time.Now().Unix() return u.Conn.Read(p) } func (u UDPConn) Work(delay int) { buf := make([]byte, 8192) for { if delay > 0 { time.Sleep(time.Millisecond * time.Duration(delay)) } count, err := u.Read(buf) if err != nil { u.Close() u.lastbeat = 0 return } _, err = u.listen.Write(buf[0:count]) if err != nil { u.lastbeat = 0 return } } } func (n *NetForward) runUDP() error { var mu sync.RWMutex atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%v", n.LocalAddr, n.LocalPort)) if err != nil { return err } listen, err := net.ListenUDP("udp", udpAddr) if err != nil { return err } starlog.Infof("Listening UDP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) go func() { <-n.stopCtx.Done() listen.Close() }() udpMap := make(map[string]UDPConn) go func() { for { select { case <-n.stopCtx.Done(): return case <-time.After(time.Second * 60): mu.Lock() for k, v := range udpMap { if time.Now().Unix() > int64(n.UDPTimeout.Seconds())+v.lastbeat { delete(udpMap, k) starlog.Noticef("UDP Connection Closed %s <==> %s\n", v.remoteAddr.String(), n.RemoteURI) } } mu.Unlock() } } }() buf := make([]byte, 8192) for { select { case <-n.stopCtx.Done(): return nil default: } count, rmt, err := listen.ReadFromUDP(buf) if err != nil || rmt.String() == n.RemoteURI { continue } go func(data []byte, rmt *net.UDPAddr) { log := starlog.Std.NewFlag() mu.Lock() addr, ok := udpMap[rmt.String()] if !ok { log.Infof("Accept New UDP Conn from %v\n", rmt.String()) conn, err := net.Dial("udp", n.RemoteURI) if err != nil { log.Errorf("UDP:Dial Remote %s Failed:%v\n", n.RemoteURI, err) mu.Unlock() return } addr = UDPConn{ Conn: conn, remoteAddr: rmt, listen: listen, lastbeat: time.Now().Unix(), } udpMap[rmt.String()] = addr go addr.Work(n.DelayMilSec) log.Infof("UDP Connect %s <==> %s\n", rmt.String(), n.RemoteURI) } mu.Unlock() if n.DelayMilSec > 0 || (n.DelayToward == 0 || n.DelayToward == 1) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } _, err := addr.Write(data) if err != nil { mu.Lock() addr.Close() delete(udpMap, addr.remoteAddr.String()) mu.Unlock() log.Noticef("UDP Connection Closed %s <==> %s\n", rmt.String(), n.RemoteURI) } }(buf[0:count], rmt) } } func (n *NetForward) copy(dst, src net.Conn) { var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() bufsize := make([]byte, 32*1024) for { count, err := src.Read(bufsize) if err != nil { if n.IgnoreEof && err == io.EOF { continue } dst.Close() src.Close() return } _, err = dst.Write(bufsize[:count]) if err != nil { src.Close() dst.Close() return } if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } } }() go func() { defer wg.Done() bufsize := make([]byte, 32*1024) for { count, err := dst.Read(bufsize) if err != nil { if n.IgnoreEof && err == io.EOF { continue } src.Close() dst.Close() return } _, err = src.Write(bufsize[:count]) if err != nil { src.Close() dst.Close() return } if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 2) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } } }() wg.Wait() }