From 8845d353392e95aa40e8ae4d699be0e7eb6b1980 Mon Sep 17 00:00:00 2001 From: starainrt Date: Fri, 8 Nov 2024 22:22:59 +0800 Subject: [PATCH] mget bug fix --- main.go | 2 +- mget/cmd.go | 22 +++++++++++++++++++--- mget/process.go | 25 +++++++++++++++---------- mget/redo.go | 33 +++++++++++++++++++++++++++++++++ mget/util.go | 2 +- mget/wget.go | 45 +++++++++++++++++++++++++++++++++++++++++---- 6 files changed, 110 insertions(+), 19 deletions(-) diff --git a/main.go b/main.go index c9fb03b..e4337cf 100644 --- a/main.go +++ b/main.go @@ -42,7 +42,7 @@ import ( var cmdRoot = &cobra.Command{ Use: "b612", - Version: "2.1.0.beta.12", + Version: "2.1.0.beta.13", } func init() { diff --git a/mget/cmd.go b/mget/cmd.go index 5baf592..045ea22 100644 --- a/mget/cmd.go +++ b/mget/cmd.go @@ -3,6 +3,7 @@ package mget import ( "b612.me/stario" "b612.me/starlog" + "b612.me/starnet" "fmt" "github.com/spf13/cobra" "os" @@ -26,7 +27,8 @@ var headers []string var ua string var proxy string var skipVerify bool -var speedcontrol string +var speedcontrol, user, pwd string +var dialTimeout, timeout int func init() { Cmd.Flags().StringVarP(&mg.Tareget, "output", "o", "", "输出文件名") @@ -34,10 +36,14 @@ func init() { Cmd.Flags().IntVarP(&mg.Thread, "thread", "t", 8, "线程数") Cmd.Flags().IntVarP(&mg.RedoRPO, "safe", "s", 1048576, "安全校验点") Cmd.Flags().StringSliceVarP(&headers, "header", "H", []string{}, "自定义请求头,格式: key=value") - Cmd.Flags().StringVarP(&proxy, "proxy", "p", "", "代理地址") + Cmd.Flags().StringVarP(&proxy, "proxy", "P", "", "代理地址") Cmd.Flags().StringVarP(&ua, "user-agent", "U", "", "自定义User-Agent") Cmd.Flags().BoolVarP(&skipVerify, "skip-verify", "k", false, "跳过SSL验证") Cmd.Flags().StringVarP(&speedcontrol, "speed", "S", "", "限速,如1M,意思是1MB/s") + Cmd.Flags().IntVarP(&dialTimeout, "dial-timeout", "d", 5, "连接网络超时时间,单位:秒") + Cmd.Flags().IntVarP(&timeout, "timeout", "T", 0, "下载超时时间,单位:秒") + Cmd.Flags().StringVarP(&user, "user", "u", "", "http basic认证用户") + Cmd.Flags().StringVarP(&pwd, "passwd", "p", "", "http basic认证密码") } func parseSpeedString(speedString string) (uint64, error) { @@ -83,6 +89,17 @@ func Run(cmd *cobra.Command, args []string) { starlog.Errorln("缺少URL参数") os.Exit(1) } + mg.Setting = *starnet.NewSimpleRequest(args[0], "GET") + mg.OriginUri = args[0] + if dialTimeout > 0 { + mg.Setting.SetDialTimeout(time.Duration(dialTimeout) * time.Second) + } + if timeout > 0 { + mg.Setting.SetTimeout(time.Duration(timeout) * time.Second) + } + if user != "" || pwd != "" { + mg.Setting.RequestOpts.SetBasicAuth(user, pwd) + } if speedcontrol != "" { speed, err := parseSpeedString(speedcontrol) if err != nil { @@ -109,7 +126,6 @@ func Run(cmd *cobra.Command, args []string) { if skipVerify { mg.Setting.SetSkipTLSVerify(true) } - mg.OriginUri = args[0] sig := make(chan os.Signal) signal.Notify(sig, os.Interrupt) select { diff --git a/mget/process.go b/mget/process.go index 0e19c4c..8c14ef8 100644 --- a/mget/process.go +++ b/mget/process.go @@ -13,12 +13,17 @@ func (m *Mget) processMiddleware(base mpb.BarFiller) mpb.BarFiller { fn := func(w io.Writer, st decor.Statistics) error { var res string count := 0 - _, err := fmt.Fprintf(w, "\nFinished:%s Total Write:%d Speed:%v\n\n", m.Redo.FormatPercent(), m.Redo.Total(), m.Redo.FormatSpeed("MB")) + fmt.Fprintf(w, "\nSpeed:%v AvgSpeed:%v\n", m.Redo.FormatSpeed("MB"), m.Redo.FormatAvgSpeed("MB")) + _, err := fmt.Fprintf(w, "Finished:%s Total Write:%d\n\n", m.Redo.FormatPercent(), m.Redo.Total()) for k := range m.threads { v := m.threads[len(m.threads)-1-k] if v != nil { count++ - res = fmt.Sprintf("Thread %v: %s %s\t", len(m.threads)-k, v.FormatSpeed("MB"), v.FormatPercent()) + res + percent := v.FormatPercent() + if m.Redo.Total() == m.Redo.ContentLength { + percent = "100.00%" + } + res = fmt.Sprintf("Thread %v: %s %s\t", len(m.threads)-k, v.FormatSpeed("MB"), percent) + res if count%3 == 0 { res = strings.TrimRight(res, "\t") fmt.Fprintf(w, "%s\n", res) @@ -60,16 +65,16 @@ func (w *Mget) Process() { decor.Counters(decor.SizeB1024(0), "% .2f / % .2f"), ), mpb.AppendDecorators( - decor.EwmaETA(decor.ET_STYLE_GO, 30), + decor.AverageETA(decor.ET_STYLE_GO), decor.Name(" ] "), - decor.EwmaSpeed(decor.SizeB1024(0), "% .2f ", 30), + decor.AverageSpeed(decor.SizeB1024(0), "% .2f "), ), ) defer p.Wait() + lastTime := time.Now() + bar.SetRefill(int64(w.Redo.Total())) + bar.DecoratorAverageAdjust(time.Now().Add(time.Millisecond * time.Duration(-w.TimeCost))) for { - last := w.Redo.Total() - lastTime := time.Now() - bar.SetCurrent(int64(w.Redo.Total())) select { case <-w.ctx.Done(): bar.SetCurrent(int64(w.Redo.Total())) @@ -88,9 +93,9 @@ func (w *Mget) Process() { return } now := w.Redo.Total() - bar.EwmaIncrInt64(int64(now-last), time.Since(lastTime)) - lastTime = time.Now() - last = now + date := time.Now() + bar.EwmaSetCurrent(int64(now), date.Sub(lastTime)) + lastTime = date if w.dynLength { bar.SetTotal(int64(w.Redo.ContentLength), false) } diff --git a/mget/redo.go b/mget/redo.go index dc60c92..b5b35dc 100644 --- a/mget/redo.go +++ b/mget/redo.go @@ -16,10 +16,14 @@ type Redo struct { Filename string `json:"filename"` ContentLength uint64 `json:"content_length"` Range []Range `json:"range"` + TimeCost uint64 `json:"time_cost"` rangeUpdated bool + startDate time.Time + startCount uint64 lastUpdate time.Time lastTotal uint64 speed float64 + avgSpeed float64 total uint64 isRedo bool sync.RWMutex @@ -40,6 +44,7 @@ func (r *Redo) Total() uint64 { r.RUnlock() if r.total > r.ContentLength && r.ContentLength > 0 { r.reform() + total = 0 continue } break @@ -54,6 +59,13 @@ func (r *Redo) Update(start, end int) error { r.Lock() defer r.Unlock() r.rangeUpdated = true + if r.lastUpdate.IsZero() { + r.startDate = time.Now() + for _, v := range r.Range { + r.startCount += v.Max - v.Min + 1 + } + time.Sleep(time.Millisecond) + } r.Range = append(r.Range, Range{uint64(start), uint64(end)}) now := time.Now() if now.Sub(r.lastUpdate) >= time.Millisecond*500 { @@ -63,6 +75,10 @@ func (r *Redo) Update(start, end int) error { } r.total = total r.speed = float64(total-r.lastTotal) / (float64(now.Sub(r.lastUpdate).Milliseconds()) / 1000.00) + if !r.lastUpdate.IsZero() { + r.TimeCost += uint64(now.Sub(r.lastUpdate).Milliseconds()) + } + r.avgSpeed = float64(total-r.startCount) / (float64(now.Sub(r.startDate).Milliseconds()) / 1000.00) r.lastTotal = total r.lastUpdate = now } @@ -90,10 +106,27 @@ func (r *Redo) FormatSpeed(unit string) string { } } +func (r *Redo) FormatAvgSpeed(unit string) string { + switch strings.ToLower(unit) { + case "kb": + return fmt.Sprintf("%.2f KB/s", r.avgSpeed/1024) + case "mb": + return fmt.Sprintf("%.2f MB/s", r.avgSpeed/1024/1024) + case "gb": + return fmt.Sprintf("%.2f GB/s", r.avgSpeed/1024/1024/1024) + default: + return fmt.Sprintf("%.2f B/s", r.avgSpeed) + } +} + func (r *Redo) Speed() float64 { return r.speed } +func (r *Redo) AverageSpeed() float64 { + return r.avgSpeed +} + func (r *Redo) Save() error { var err error err = r.reform() diff --git a/mget/util.go b/mget/util.go index dd378bf..7507968 100644 --- a/mget/util.go +++ b/mget/util.go @@ -75,7 +75,7 @@ func IOWriter(stopCtx context.Context, ch chan Buffer, state *uint32, di *downlo *start += int64(n) di.AddCurrent(int64(n)) } - if *start >= *end { + if *end != 0 && *start >= *end { return nil } if err != nil { diff --git a/mget/wget.go b/mget/wget.go index 3f12c12..e0f876c 100644 --- a/mget/wget.go +++ b/mget/wget.go @@ -53,6 +53,15 @@ func (w *Mget) Clone() *starnet.Request { req.SetCookies(CloneCookies(w.Setting.Cookies())) req.SetSkipTLSVerify(w.Setting.SkipTLSVerify()) req.SetProxy(w.Setting.Proxy()) + if w.Setting.DialTimeout() > 0 { + req.SetDialTimeout(w.Setting.DialTimeout()) + } + if w.Setting.Timeout() > 0 { + req.SetTimeout(w.Setting.Timeout()) + } + if u, p := w.Setting.BasicAuth(); u != "" || p != "" { + req.SetBasicAuth(u, p) + } return req } @@ -88,6 +97,10 @@ func (w *Mget) prepareRun(res *starnet.Response, is206 bool) error { fmt.Println("Will write to:", w.Tareget) fmt.Println("Size:", w.TargetSize) fmt.Println("Is206:", is206) + fmt.Println("IsDynLen:", w.dynLength) + if !is206 { + w.Thread = 1 + } w.Redo = Redo{ Filename: w.Tareget, ContentLength: uint64(w.TargetSize), @@ -136,7 +149,8 @@ func (w *Mget) Run() error { defer w.fn() w.threads = make([]*downloader, w.Thread) if w.Setting.Uri() == "" { - w.Setting = *starnet.NewSimpleRequest(w.OriginUri, "GET") + w.Setting.SetUri(w.OriginUri) + w.Setting.SetMethod("GET") } for { res, is206, err = w.IsUrl206() @@ -151,6 +165,11 @@ func (w *Mget) Run() error { return fmt.Errorf("Server return %d", res.StatusCode) } if !is206 { + go func() { + w.writeEnable = true + w.writeError = w.WriteServer() + w.writeEnable = false + }() var di = &downloader{ alive: true, downloadinfo: &downloadinfo{ @@ -159,14 +178,31 @@ func (w *Mget) Run() error { Size: w.TargetSize, }, } + if w.dynLength { + di.End = 0 + } + w.writeEnable = true w.threads[0] = di + w.Thread = 1 + go w.Process() state := uint32(0) err = IOWriter(w.ctx, w.ch, &state, di.downloadinfo, res.Body().Reader(), w.BufferSize, &di.Start, &di.End) di.alive = false if err == nil { + w.writeEnable = false + stario.WaitUntilTimeout(time.Second*2, + func(c chan struct{}) error { + for { + if w.processEnable { + time.Sleep(time.Millisecond * 50) + continue + } + return nil + } + }) return nil } - continue + return err } else { res.Body().Close() } @@ -187,9 +223,10 @@ func (w *Mget) Run() error { go w.Process() w.wg.Wait() time.Sleep(2 * time.Microsecond) + exitFn := sync.OnceFunc(w.fn) for { if w.writeEnable { - w.fn() + exitFn() time.Sleep(time.Millisecond * 50) continue } @@ -199,7 +236,7 @@ func (w *Mget) Run() error { } break } - w.fn() + exitFn() stario.WaitUntilTimeout(time.Second*2, func(c chan struct{}) error { for {