// Package netforward implements a small TCP/UDP port forwarder that can
// inject an artificial delay into either direction of the forwarded traffic.
package netforward

import (
	"b612.me/stario"
	"b612.me/starlog"
	"context"
	"errors"
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// NetForward forwards traffic received on LocalAddr:LocalPort to RemoteURI.
type NetForward struct {
	// LocalAddr and LocalPort form the "addr:port" the forwarder listens on.
	LocalAddr string
	LocalPort int
	// RemoteURI is the "host:port" every accepted connection is relayed to.
	RemoteURI string
	// EnableTCP and EnableUDP select which listeners Run starts.
	EnableTCP bool
	EnableUDP bool
	// DelayMilSec is the artificial delay in milliseconds; 0 disables it.
	DelayMilSec int
	// DelayToward selects the delayed direction: 0 delays both directions,
	// 1 delays client -> remote only, 2 delays remote -> client only.
	DelayToward int
	// StdinMode makes Run start a goroutine that reads "set ..." commands
	// through stario and updates the fields above at runtime.
	StdinMode bool
	// DialTimeout bounds the TCP dial to RemoteURI; Run defaults it to 5s.
	DialTimeout time.Duration
	// UDPTimeout is the idle time after which a UDP session is dropped by
	// the sweep that runs every 60 seconds.
	UDPTimeout time.Duration

	stopCtx context.Context
	stopFn  context.CancelFunc
	running int32 // number of live listener goroutines; accessed atomically
}

// Close cancels the context shared by the listeners started by Run.
// It is a no-op when Run has not been called yet.
func (n *NetForward) Close() {
	if n.stopFn != nil {
		n.stopFn()
	}
}

// Status returns the number of listener goroutines currently running.
func (n *NetForward) Status() int32 {
	return atomic.LoadInt32(&n.running)
}

// Run starts the enabled listeners (and the stdin command loop when
// StdinMode is set) in background goroutines and returns immediately.
//
// It returns an error only when a listener is already running. Listener
// start-up failures happen inside the goroutines and are reported through
// starlog, not through the return value.
func (n *NetForward) Run() error {
	// The counter is bumped by the listener goroutines themselves, so this
	// check is best-effort against two Run calls racing each other.
	if atomic.LoadInt32(&n.running) > 0 {
		starlog.Errorln("already running")
		return errors.New("already running")
	}
	n.stopCtx, n.stopFn = context.WithCancel(context.Background())
	if n.DialTimeout == 0 {
		n.DialTimeout = time.Second * 5
	}
	if n.StdinMode {
		go n.stdinLoop(n.stopCtx)
	}
	if n.EnableTCP {
		go n.runTCP()
	}
	if n.EnableUDP {
		go n.runUDP()
	}
	return nil
}

// stdinLoop reads commands until "set stdin off" is received or ctx is
// cancelled. Cancellation is only observed after a pending read returns.
func (n *NetForward) stdinLoop(ctx context.Context) {
	for {
		raw := stario.MessageBox("", "").MustString()
		if ctx.Err() != nil {
			return
		}
		// Fields trims and collapses any run of whitespace in one pass.
		cmds := strings.Fields(raw)
		cmd := strings.Join(cmds, " ")
		starlog.Debugf("Recv Command %s\n", cmd)
		if len(cmds) < 3 {
			starlog.Errorln("Invalid Command", cmd)
			continue
		}
		if n.applyCommand(cmds) {
			return
		}
	}
}

// applyCommand executes one parsed command (at least three words) and
// reports whether the stdin loop should stop.
func (n *NetForward) applyCommand(cmds []string) (stop bool) {
	arg := cmds[2]
	switch cmds[0] + cmds[1] {
	case "setremote":
		n.RemoteURI = arg
		starlog.Noticef("Remote URI Set to %s\n", n.RemoteURI)
	case "setdelaytoward":
		tmp, err := strconv.Atoi(arg)
		if err != nil {
			starlog.Errorln("Invalid Delay Toward Value", arg)
			return false
		}
		n.DelayToward = tmp
		starlog.Noticef("Delay Toward Set to %d\n", n.DelayToward)
	case "setdelay":
		tmp, err := strconv.Atoi(arg)
		if err != nil {
			starlog.Errorln("Invalid Delay Value", arg)
			return false
		}
		n.DelayMilSec = tmp
		starlog.Noticef("Delay Set to %d\n", n.DelayMilSec)
	case "setdialtimeout":
		tmp, err := strconv.Atoi(arg)
		if err != nil {
			starlog.Errorln("Invalid Dial Timeout Value", arg)
			return false
		}
		n.DialTimeout = time.Millisecond * time.Duration(tmp)
		starlog.Noticef("Dial Timeout Set to %d\n", n.DialTimeout)
	case "setudptimeout":
		tmp, err := strconv.Atoi(arg)
		if err != nil {
			starlog.Errorln("Invalid UDP Timeout Value", arg)
			return false
		}
		n.UDPTimeout = time.Millisecond * time.Duration(tmp)
		starlog.Noticef("UDP Timeout Set to %d\n", n.UDPTimeout)
	case "setstdin":
		if arg == "off" {
			n.StdinMode = false
			starlog.Noticef("Stdin Mode Off\n")
			return true
		}
	}
	return false
}

// runTCP listens on the local TCP address and relays every accepted
// connection to RemoteURI until the stop context is cancelled.
func (n *NetForward) runTCP() error {
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)

	// Capture the context once so a later Run cannot swap it under us.
	ctx := n.stopCtx
	local := fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)
	listen, err := net.Listen("tcp", local)
	if err != nil {
		starlog.Errorln("Listening On Tcp Failed:", err)
		return err
	}
	defer listen.Close()
	// Closing the listener is what unblocks Accept on shutdown.
	go func() {
		<-ctx.Done()
		listen.Close()
	}()
	starlog.Infof("Listening TCP on %v\n", local)

	for {
		select {
		case <-ctx.Done():
			return nil
		default:
		}
		conn, err := listen.Accept()
		if err != nil {
			continue
		}
		log := starlog.Std.NewFlag()
		log.Infof("Accept New TCP Conn from %v\n", conn.RemoteAddr().String())
		go func(conn net.Conn) {
			rmt, err := net.DialTimeout("tcp", n.RemoteURI, n.DialTimeout)
			if err != nil {
				log.Errorf("TCP:Dial Remote %s Failed:%v\n", n.RemoteURI, err)
				conn.Close()
				return
			}
			log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String())
			n.copy(rmt, conn)
			log.Noticef("TCP Connection Closed %s <==> %s\n", conn.RemoteAddr().String(), n.RemoteURI)
		}(conn)
	}
}

// UDPConn is one forwarded UDP session: the embedded Conn is the dialed
// socket to the remote, listen is the shared local socket, and remoteAddr
// is the client the session belongs to.
//
// Methods use pointer receivers so that lastbeat updates are visible to the
// idle sweep; a UDPConn must be shared by pointer, never copied.
type UDPConn struct {
	net.Conn
	listen     *net.UDPConn
	remoteAddr *net.UDPAddr
	lastbeat   int64 // unix seconds of last activity, 0 = dead; atomic
}

// Write records activity and sends p to the remote.
func (u *UDPConn) Write(p []byte) (n int, err error) {
	atomic.StoreInt64(&u.lastbeat, time.Now().Unix())
	return u.Conn.Write(p)
}

// Read records activity and reads one datagram from the remote.
func (u *UDPConn) Read(p []byte) (n int, err error) {
	atomic.StoreInt64(&u.lastbeat, time.Now().Unix())
	return u.Conn.Read(p)
}

// Work relays datagrams from the remote back to the client until either
// side fails, sleeping delay milliseconds before each read when delay > 0.
// On failure it closes the session and zeroes lastbeat so the idle sweep
// removes it.
func (u *UDPConn) Work(delay int) {
	buf := make([]byte, 8192)
	for {
		if delay > 0 {
			time.Sleep(time.Millisecond * time.Duration(delay))
		}
		count, err := u.Read(buf)
		if err == nil {
			// The listening socket is unconnected, so the destination must
			// be given explicitly; a plain Write has no peer to send to.
			_, err = u.listen.WriteToUDP(buf[:count], u.remoteAddr)
		}
		if err != nil {
			u.Close()
			atomic.StoreInt64(&u.lastbeat, 0)
			return
		}
	}
}

// runUDP listens on the local UDP address and keeps one UDPConn per client
// address, relaying datagrams to RemoteURI until the stop context is
// cancelled. Idle sessions are swept every 60 seconds.
func (n *NetForward) runUDP() error {
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)

	ctx := n.stopCtx
	local := fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)
	udpAddr, err := net.ResolveUDPAddr("udp", local)
	if err != nil {
		starlog.Errorln("Listening On Udp Failed:", err)
		return err
	}
	listen, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		starlog.Errorln("Listening On Udp Failed:", err)
		return err
	}
	defer listen.Close()
	starlog.Infof("Listening UDP on %v\n", local)
	// Closing the socket is what unblocks ReadFromUDP on shutdown.
	go func() {
		<-ctx.Done()
		listen.Close()
	}()

	var mu sync.Mutex
	sessions := make(map[string]*UDPConn)
	// Close every remaining session on exit so their Work goroutines end.
	defer func() {
		mu.Lock()
		for key, sess := range sessions {
			sess.Close()
			delete(sessions, key)
		}
		mu.Unlock()
	}()

	go func() {
		ticker := time.NewTicker(time.Second * 60)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				now := time.Now().Unix()
				mu.Lock()
				for key, sess := range sessions {
					if now > int64(n.UDPTimeout.Seconds())+atomic.LoadInt64(&sess.lastbeat) {
						// Closing also unblocks the session's Work goroutine.
						sess.Close()
						delete(sessions, key)
						starlog.Noticef("UDP Connection Closed %s <==> %s\n", sess.remoteAddr.String(), n.RemoteURI)
					}
				}
				mu.Unlock()
			}
		}
	}()

	buf := make([]byte, 8192)
	for {
		select {
		case <-ctx.Done():
			return nil
		default:
		}
		count, rmt, err := listen.ReadFromUDP(buf)
		if err != nil || rmt.String() == n.RemoteURI {
			continue
		}
		// buf is reused by the next read, so the goroutine gets its own copy.
		data := make([]byte, count)
		copy(data, buf[:count])

		go func(data []byte, rmt *net.UDPAddr) {
			log := starlog.Std.NewFlag()
			key := rmt.String()

			mu.Lock()
			sess, ok := sessions[key]
			if !ok {
				if ctx.Err() != nil {
					// Shutting down: do not open sessions nobody will close.
					mu.Unlock()
					return
				}
				log.Infof("Accept New UDP Conn from %v\n", key)
				conn, err := net.Dial("udp", n.RemoteURI)
				if err != nil {
					log.Errorf("UDP:Dial Remote %s Failed:%v\n", n.RemoteURI, err)
					mu.Unlock()
					return
				}
				sess = &UDPConn{
					Conn:       conn,
					remoteAddr: rmt,
					listen:     listen,
					lastbeat:   time.Now().Unix(),
				}
				sessions[key] = sess
				// The return path is direction 2, mirroring the TCP copy.
				backDelay := 0
				if n.DelayToward == 0 || n.DelayToward == 2 {
					backDelay = n.DelayMilSec
				}
				go sess.Work(backDelay)
				log.Infof("UDP Connect %s <==> %s\n", key, n.RemoteURI)
			}
			mu.Unlock()

			if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) {
				time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
			}
			if _, err := sess.Write(data); err != nil {
				mu.Lock()
				sess.Close()
				// Only drop the entry if it is still ours; the client may
				// already have been given a fresh session.
				if sessions[key] == sess {
					delete(sessions, key)
				}
				mu.Unlock()
				log.Noticef("UDP Connection Closed %s <==> %s\n", key, n.RemoteURI)
			}
		}(data, rmt)
	}
}

// copy relays data in both directions between src (the accepted client
// connection) and dst (the dialed remote) and returns once both directions
// have finished. Either direction failing closes both connections.
func (n *NetForward) copy(dst, src net.Conn) {
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		n.pipe(dst, src, 1)
	}()
	go func() {
		defer wg.Done()
		n.pipe(src, dst, 2)
	}()
	wg.Wait()
}

// pipe copies from -> to until a read or write fails, then closes both
// ends. toward is the DelayToward value (1 or 2) naming this direction.
func (n *NetForward) pipe(to, from net.Conn, toward int) {
	defer to.Close()
	defer from.Close()
	buf := make([]byte, 32*1024)
	for {
		if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == toward) {
			time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
		}
		count, rerr := from.Read(buf)
		// A Read may return data together with an error; forward the data
		// before acting on the error.
		if count > 0 {
			if _, werr := to.Write(buf[:count]); werr != nil {
				return
			}
		}
		if rerr != nil {
			return
		}
	}
}