package net

import (
	"context"
	"encoding/hex"
	"fmt"
	"net"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"

	"b612.me/stario"
	"b612.me/starlog"
)

// TcpClient dials a single TCP connection to RemoteAddr and reads from it
// until it is closed. Received data can be logged (ShowRecv/ShowAsHex) and
// both directions can be recorded to a file (SaveToFolder); with Interactive
// set, commands typed on the terminal can send data or close the connection.
type TcpClient struct {
	// DialTimeout is the dial timeout in seconds; 0 means no timeout
	// (net.Dialer semantics).
	DialTimeout int

	// LocalAddr, when non-empty, is resolved with net.ResolveTCPAddr and used
	// as the local end of the connection.
	LocalAddr string

	// RemoteAddr is the address dialed by Run.
	RemoteAddr string

	// UsingKeepAlive, KeepAlivePeriod, KeepAliveIdel and KeepAliveCount are
	// passed unchanged to SetTcpInfo (defined elsewhere in this package);
	// see that function for their units.
	UsingKeepAlive  bool
	KeepAlivePeriod int
	KeepAliveIdel   int
	KeepAliveCount  int

	// UseLinger, when >= 0, is passed to net.TCPConn.SetLinger; negative
	// values leave the OS default untouched.
	UseLinger int

	// Interactive starts a terminal command loop alongside the reader.
	Interactive bool

	// UserTimeout is passed unchanged to SetTcpInfo. Run warns that it only
	// takes effect on Linux.
	UserTimeout int

	// ShowRecv logs every received chunk; ShowAsHex logs it as hex instead
	// of as text.
	ShowRecv bool

	// ShowAsHex selects hex formatting for ShowRecv.
	ShowAsHex bool

	// SaveToFolder, when non-empty, is the directory in which a traffic file
	// named after the remote address is created.
	SaveToFolder string

	// Rmt is the active connection, or nil when there is none.
	Rmt *TcpConn

	// LogPath, when non-empty, redirects starlog's standard logger to this
	// file.
	LogPath string

	stopCtx context.Context
	stopFn  context.CancelFunc
}

// Close closes the active connection. It returns nil when there is no active
// connection (before Run has connected, or after the interactive "close"
// command dropped it) instead of dereferencing a nil Rmt.
func (s *TcpClient) Close() error {
	if s.Rmt == nil {
		return nil
	}
	return s.Rmt.Close()
}

// handleInteractive reads commands from the terminal and applies them to the
// active connection until the client's stop context is cancelled.
//
// Commands:
//
//	hex        later "send" payloads are hex-decoded before writing
//	text       later "send" payloads are written verbatim
//	send DATA  write DATA to the remote peer
//	startauto  keep writing a fixed pattern to the remote peer
//	closeauto  stop the pattern writer started by startauto
//	close      close the active connection
//
// The context is only checked between commands, so cancellation is noticed
// after the next line of input has been read.
func (s *TcpClient) handleInteractive() {
	if !s.Interactive {
		return
	}
	var currentCmd string
	// autoStops maps a remote address to the channel that stops its
	// auto-send goroutine. Only this goroutine touches the map; each
	// auto-sender is handed its own channel so it never reads the map.
	autoStops := make(map[string]chan struct{})
	starlog.Infoln("Interactive mode enabled")
	for {
		select {
		case <-s.stopCtx.Done():
			for _, stop := range autoStops {
				close(stop)
			}
			starlog.Infoln("Interactive mode stopped due to context done")
			return
		default:
		}
		cmd := strings.TrimSpace(stario.MessageBox("", "").MustString())
		fields := strings.Fields(cmd)
		if len(fields) == 0 {
			continue
		}
		switch fields[0] {
		case "hex":
			currentCmd = "hex"
			starlog.Infoln("Switch to hex mode,send hex to remote client")
		case "text":
			currentCmd = "text"
			starlog.Infoln("Switch to text mode,send text to remote client")
		case "close":
			if s.Rmt == nil || s.Rmt.TCPConn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			s.Rmt.TCPConn.Close()
			starlog.Infof("Client %s closed\n", s.Rmt.RemoteAddr().String())
			s.Rmt = nil
			currentCmd = ""
		case "startauto":
			if s.Rmt == nil {
				starlog.Errorln("No client selected")
				continue
			}
			key := s.Rmt.RemoteAddr().String()
			// Stop a previous auto-sender for this peer instead of leaking it.
			if old, ok := autoStops[key]; ok {
				close(old)
			}
			stop := make(chan struct{})
			autoStops[key] = stop
			go s.autoSend(s.Rmt, stop)
			starlog.Infoln("Auto send started")
		case "closeauto":
			if s.Rmt == nil {
				starlog.Errorln("No client selected")
				continue
			}
			key := s.Rmt.RemoteAddr().String()
			stop, ok := autoStops[key]
			if !ok {
				starlog.Errorln("Auto send not running")
				continue
			}
			close(stop)
			// Delete so a second closeauto cannot close the channel twice.
			delete(autoStops, key)
		case "send":
			if s.Rmt == nil {
				starlog.Errorln("No client selected")
				continue
			}
			payload := strings.TrimSpace(strings.TrimPrefix(cmd, "send"))
			s.sendPayload(s.Rmt, payload, currentCmd == "hex")
		default:
			starlog.Errorln("Unknown command:", fields[0])
		}
	}
}

// autoSend writes the "B612" pattern (1 KiB per write) to conn in a tight
// loop until stop is closed or a write fails.
func (s *TcpClient) autoSend(conn *TcpConn, stop <-chan struct{}) {
	payload := []byte(strings.Repeat("B612", 256))
	for {
		select {
		case <-stop:
			starlog.Infoln("Auto send stopped")
			return
		default:
		}
		if _, err := conn.Write(payload); err != nil {
			starlog.Errorln("Write error:", err)
			return
		}
	}
}

// sendPayload writes one "send" command's payload to conn, hex-decoding it
// first when asHex is true. Only after a successful write are the bytes
// actually sent recorded to the traffic file and success logged.
func (s *TcpClient) sendPayload(conn *TcpConn, payload string, asHex bool) {
	data := []byte(payload)
	if asHex {
		decoded, err := hex.DecodeString(payload)
		if err != nil {
			starlog.Errorln("Hex decode error:", err)
			return
		}
		data = decoded
	}
	if _, err := conn.Write(data); err != nil {
		starlog.Errorln("Write error:", err)
		return
	}
	s.record(conn.f, "send", data)
	starlog.Infof("Send to %s success\n", conn.RemoteAddr().String())
}

// record appends one traffic entry to f: a "<time> <direction>" header line,
// the raw bytes, then a newline. It is a no-op when f is nil. The entry is
// assembled first and written with a single Write call, so an entry from the
// reader goroutine is not split by one from the console goroutine. Write
// errors are ignored: recording is best-effort.
func (s *TcpClient) record(f *os.File, direction string, data []byte) {
	if f == nil {
		return
	}
	header := time.Now().String() + " " + direction + "\n"
	entry := make([]byte, 0, len(header)+len(data)+1)
	entry = append(entry, header...)
	entry = append(entry, data...)
	entry = append(entry, '\n')
	_, _ = f.Write(entry)
}

// Run dials RemoteAddr (from LocalAddr if set, with a DialTimeout-second
// timeout). If Interactive is set it also starts the terminal console. It
// then serves the connection on the calling goroutine until the connection
// is closed, a read fails, or Stop is called. It returns an error only when
// setup fails: log file, local address resolution, or dialing.
func (s *TcpClient) Run() error {
	s.stopCtx, s.stopFn = context.WithCancel(context.Background())
	// Cancel on every exit path so the console goroutine stops (at its next
	// input) and the context is released once this run is over.
	defer s.stopFn()

	if s.LogPath != "" {
		if err := starlog.SetLogFile(s.LogPath, starlog.Std, true); err != nil {
			starlog.Errorln("SetLogFile error:", err)
			return fmt.Errorf("SetLogFile error: %w", err)
		}
	}

	var localAddr *net.TCPAddr
	if s.LocalAddr != "" {
		var err error
		localAddr, err = net.ResolveTCPAddr("tcp", s.LocalAddr)
		if err != nil {
			starlog.Errorln("ResolveTCPAddr error:", err)
			return fmt.Errorf("ResolveTCPAddr error: %w", err)
		}
	}

	dialer := net.Dialer{
		LocalAddr: localAddr,
		Timeout:   time.Duration(s.DialTimeout) * time.Second,
		Control:   ControlSetReUseAddr,
	}
	tcpConn, err := dialer.Dial("tcp", s.RemoteAddr)
	if err != nil {
		starlog.Errorln("Dial TCP error:", err)
		return fmt.Errorf("Dial TCP error: %w", err)
	}
	// Dialing the "tcp" network yields a *net.TCPConn.
	conn := tcpConn.(*net.TCPConn)
	starlog.Infof("Connected to %s LocalAddr: %s\n", conn.RemoteAddr().String(), conn.LocalAddr().String())

	// Publish the connection before starting the console, so the console
	// never sees a nil Rmt just because it started first. Keep a local copy
	// so the reader is unaffected if the console later clears s.Rmt.
	rmt := s.getTcpConn(conn)
	s.Rmt = rmt
	if s.Interactive {
		go s.handleInteractive()
	}
	s.handleConn(rmt)
	return nil
}

// getTcpConn wraps conn in a TcpConn. When SaveToFolder is set it also
// creates (truncating any existing file) a traffic file in that folder. The
// file is named after the remote address with ':' replaced by '_'. If
// creation fails, the error is logged and the connection is returned without
// a traffic file.
func (s *TcpClient) getTcpConn(conn *net.TCPConn) *TcpConn {
	var f *os.File
	if s.SaveToFolder != "" {
		name := strings.ReplaceAll(conn.RemoteAddr().String(), ":", "_")
		var err error
		f, err = os.Create(filepath.Join(s.SaveToFolder, name))
		if err != nil {
			starlog.Errorf("Create file error for %s: %v\n", conn.RemoteAddr().String(), err)
		}
	}
	return &TcpConn{
		TCPConn: conn,
		f:       f,
	}
}

// handleConn applies the configured socket options to conn. It then reads
// from conn until the stop context is cancelled or a read fails, logging
// and/or recording each received chunk. On return the connection and its
// traffic file (if any) are closed.
func (s *TcpClient) handleConn(conn *TcpConn) {
	log := starlog.Std.NewFlag()
	remote := conn.RemoteAddr().String()
	defer func() {
		conn.Close()
		if conn.f != nil {
			// Best-effort: a second Close only returns an error, ignored here.
			conn.f.Close()
		}
	}()

	if err := SetTcpInfo(conn.TCPConn, s.UsingKeepAlive, s.KeepAliveIdel, s.KeepAlivePeriod, s.KeepAliveCount, s.UserTimeout); err != nil {
		log.Errorf("SetTcpInfo error for %s: %v\n", remote, err)
		return
	}
	if s.UseLinger >= 0 {
		if err := conn.SetLinger(s.UseLinger); err != nil {
			log.Errorf("SetLinger error for %s: %v\n", remote, err)
		}
	}
	log.Infof("SetKeepAlive success for %s\n", remote)
	log.Infof("KeepAlivePeriod: %d, KeepAliveIdel: %d, KeepAliveCount: %d, UserTimeout: %d\n", s.KeepAlivePeriod, s.KeepAliveIdel, s.KeepAliveCount, s.UserTimeout)
	if runtime.GOOS != "linux" {
		log.Warningln("keepAliveCount and userTimeout only work on linux")
	}

	// One buffer for the whole connection: each chunk is fully formatted or
	// copied (by record) before the next Read overwrites it.
	buf := make([]byte, 8192)
	for {
		select {
		case <-s.stopCtx.Done():
			log.Infof("Connection from %s closed due to context done\n", remote)
			return
		default:
		}
		n, err := conn.Read(buf)
		// Per the io.Reader contract, consume any bytes returned before
		// looking at the error.
		if n > 0 {
			if s.ShowRecv {
				if s.ShowAsHex {
					log.Printf("Recv from %s: %x\n", remote, buf[:n])
				} else {
					log.Printf("Recv from %s: %s\n", remote, string(buf[:n]))
				}
			}
			s.record(conn.f, "recv", buf[:n])
		}
		if err != nil {
			log.Errorf("Read error for %s: %v\n", remote, err)
			return
		}
	}
}

// Stop cancels the running client's context and closes the active
// connection, which unblocks the pending read in Run. Calling it before Run
// is harmless.
func (s *TcpClient) Stop() {
	if s.stopFn != nil {
		s.stopFn()
	}
	if s.Rmt != nil {
		s.Rmt.Close()
	}
}