You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
564 lines
12 KiB
Go
564 lines
12 KiB
Go
package mget
|
|
|
|
import (
|
|
"b612.me/stario"
|
|
"b612.me/starnet"
|
|
"b612.me/staros"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Mget downloads one URL into a local file, optionally splitting the transfer
// into byte ranges fetched by several concurrent workers. Progress is tracked
// in the embedded Redo; prepareRun reads it back from Tareget + ".bgrd" to
// resume an interrupted download.
type Mget struct {
	// Setting is the template request; every worker request is derived from
	// it via Clone.
	Setting starnet.Request
	// Redo is the resumable download state (ranges written, origin URI,
	// content length).
	Redo
	// Local file path.
	Tareget string
	// Local file size (from Content-Length; 0 when the length is unknown).
	TargetSize int64
	// Maximum amount of data that may be lost from the redo file: WriteServer
	// saves the redo state once at least this many bytes have been written
	// since the last save.
	RedoRPO int
	// Size of a single buffer.
	BufferSize int
	// dynLength is set when the server sent no Content-Length, so the final
	// size is only known once the stream ends.
	dynLength bool
	// Number of concurrent download threads.
	Thread int `json:"thread"`
	// tf is the target file, opened by WriteServer.
	tf *os.File
	// ch carries downloaded chunks from the workers to WriteServer.
	ch chan Buffer
	// ctx/fn cancel the current Run.
	ctx context.Context
	fn  context.CancelFunc
	// wg waits for the dispatch workers.
	wg sync.WaitGroup
	// threads holds the active downloader of each worker, indexed by worker.
	threads []*downloader
	// lastUndoInfo lists ranges still missing from a recovered redo file;
	// dispatch consumes it from the front.
	lastUndoInfo []Range
	// writeError is the error returned by WriteServer.
	writeError error
	// writeEnable is true while the WriteServer goroutine is running
	// (plain bool, not synchronized).
	writeEnable bool
	// processEnable is polled by Run before returning; presumably set by
	// Process (defined elsewhere) while it is running — confirm.
	processEnable bool
	// speedlimit caps the write rate in bytes per second; 0 means unlimited.
	speedlimit int64
}
|
|
|
|
// Buffer is a chunk of downloaded data destined for the target file.
type Buffer struct {
	// Data is the payload to write.
	Data []byte
	// Start is the byte offset in the target file where Data is written.
	Start uint64
}
|
|
|
|
func (w *Mget) Clone() *starnet.Request {
|
|
req := starnet.NewSimpleRequest(w.Setting.Uri(), w.Setting.Method())
|
|
req.SetHeaders(CloneHeader(w.Setting.Headers()))
|
|
req.SetCookies(CloneCookies(w.Setting.Cookies()))
|
|
req.SetSkipTLSVerify(w.Setting.SkipTLSVerify())
|
|
req.SetProxy(w.Setting.Proxy())
|
|
if w.Setting.DialTimeout() > 0 {
|
|
req.SetDialTimeout(w.Setting.DialTimeout())
|
|
}
|
|
if w.Setting.Timeout() > 0 {
|
|
req.SetTimeout(w.Setting.Timeout())
|
|
}
|
|
if u, p := w.Setting.BasicAuth(); u != "" || p != "" {
|
|
req.SetBasicAuth(u, p)
|
|
}
|
|
return req
|
|
}
|
|
|
|
func (w *Mget) IsUrl206() (*starnet.Response, bool, error) {
|
|
req := w.Clone()
|
|
req.SetHeader("Range", "bytes=0-")
|
|
res, err := req.Do()
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
if res.StatusCode == 206 {
|
|
return res, true, nil
|
|
}
|
|
return res, false, nil
|
|
}
|
|
|
|
// prepareRun initialises the download state from the probe response.
//
// It performs these steps in order:
//   - TargetSize comes from Content-Length. A missing header marks the
//     download as dynamic-length and disables ranged mode.
//   - The target name comes from the response (via GetFileName) when none
//     was configured.
//   - A non-ranged download is forced to one thread.
//   - w.Redo is reset.
//   - If Tareget + ".bgrd" exists and matches the current content length and
//     origin URI, its state replaces w.Redo and the still-missing ranges are
//     queued in w.lastUndoInfo. A mismatching redo file is ignored.
func (w *Mget) prepareRun(res *starnet.Response, is206 bool) error {
	length := res.Header.Get("Content-Length")
	if length == "" {
		w.dynLength = true
		is206 = false
		length = "0"
	}
	size, err := strconv.ParseInt(length, 10, 64)
	w.TargetSize = size
	if err != nil {
		return fmt.Errorf("parse content length error: %w", err)
	}

	if w.Tareget == "" {
		w.Tareget = GetFileName(res.Response)
	}
	fmt.Println("Will write to:", w.Tareget)
	fmt.Println("Size:", w.TargetSize)
	fmt.Println("Is206:", is206)
	fmt.Println("IsDynLen:", w.dynLength)
	if !is206 {
		w.Thread = 1
	}
	w.Redo = Redo{
		Filename:      w.Tareget,
		ContentLength: uint64(w.TargetSize),
		OriginUri:     w.Setting.Uri(),
		Date:          time.Now(),
		Is206:         is206,
	}
	fmt.Println("Threads:", w.Thread)

	redoPath := w.Tareget + ".bgrd"
	if !staros.Exists(redoPath) {
		return nil
	}
	fmt.Println("Found redo file, try to recover...")
	data, err := os.ReadFile(redoPath)
	if err != nil {
		return fmt.Errorf("read redo file error: %w", err)
	}
	var saved Redo
	if err = json.Unmarshal(data, &saved); err != nil {
		return fmt.Errorf("unmarshal redo file error: %w", err)
	}
	saved.reform()

	// A redo file for a different resource (size or URI) cannot be resumed.
	switch {
	case saved.ContentLength != w.Redo.ContentLength:
		fmt.Println("Content length not match, redo file may be invalid, ignore it")
		return nil
	case saved.OriginUri != w.Redo.OriginUri:
		fmt.Println("Origin uri not match, redo file may be invalid, ignore it")
		return nil
	}

	w.Redo = saved
	w.Redo.isRedo = true
	if w.lastUndoInfo, err = w.Redo.ReverseRange(); err != nil {
		return fmt.Errorf("reverse redo range error: %w", err)
	}
	fmt.Println("Recover redo file success,process:", w.Redo.FormatPercent())
	return nil
}
|
|
|
|
// Run downloads the configured URI into w.Tareget.
//
// When w.Setting has no URI, the URI recorded in the redo state is used with
// GET. Run then probes the server with "Range: bytes=0-" (IsUrl206).
//
// Which path runs depends on the probe:
//   - Single stream: taken when the server does not answer 206, or when the
//     length is unknown. The probe body itself is streamed to disk by one
//     worker.
//   - Ranged: taken for a 206 with a known length. The file is split across
//     w.Thread dispatch workers. Afterwards the redo file is removed when no
//     range is outstanding, or saved otherwise.
//
// The probe response body is always released.
func (w *Mget) Run() error {
	var err error
	var res *starnet.Response
	var is206 bool

	// A non-positive thread count would make w.threads empty (indexing it
	// below panics) and leave nothing to download; treat it as one thread.
	if w.Thread < 1 {
		w.Thread = 1
	}
	w.ctx, w.fn = context.WithCancel(context.Background())
	w.ch = make(chan Buffer)
	defer w.fn()
	w.threads = make([]*downloader, w.Thread)
	if w.Setting.Uri() == "" {
		w.Setting.SetUri(w.OriginUri)
		w.Setting.SetMethod("GET")
	}

	// waitProcessIdle gives the progress reporter up to 2s to finish.
	waitProcessIdle := func() {
		stario.WaitUntilTimeout(time.Second*2,
			func(c chan struct{}) error {
				for w.processEnable {
					time.Sleep(time.Millisecond * 50)
				}
				return nil
			})
	}

	res, is206, err = w.IsUrl206()
	if err != nil {
		return fmt.Errorf("check 206 error: %w", err)
	}
	err = w.prepareRun(res, is206)
	if err != nil {
		res.Body().Close()
		return fmt.Errorf("prepare run error: %w", err)
	}
	if res.StatusCode != 206 && res.StatusCode != 200 {
		res.Body().Close()
		return fmt.Errorf("Server return %d", res.StatusCode)
	}
	// prepareRun treats a response without Content-Length as non-ranged
	// (dynamic length) and forces one thread; follow the same decision here.
	// Otherwise a 206 of unknown length would take the ranged path with
	// TargetSize == 0 and return without downloading anything.
	if w.dynLength {
		is206 = false
	}

	if !is206 {
		// Single-stream path: write the probe response body directly.
		go func() {
			w.writeEnable = true
			w.writeError = w.WriteServer()
			w.writeEnable = false
		}()
		di := &downloader{
			alive: true,
			downloadinfo: &downloadinfo{
				Start: 0,
				End:   w.TargetSize - 1,
				Size:  w.TargetSize,
			},
		}
		if w.dynLength {
			di.End = 0
		}
		w.writeEnable = true
		w.threads[0] = di
		w.Thread = 1
		go w.Process()
		state := uint32(0)
		err = IOWriter(w.ctx, w.ch, &state, di.downloadinfo, res.Body().Reader(), w.BufferSize, &di.Start, &di.End)
		di.alive = false
		res.Body().Close()
		if err != nil {
			return err
		}
		w.writeEnable = false
		waitProcessIdle()
		return nil
	}
	// Ranged path: the probe body is not needed; workers issue their own
	// range requests.
	res.Body().Close()

	go func() {
		w.writeEnable = true
		w.writeError = w.WriteServer()
		w.writeEnable = false
	}()
	if w.TargetSize == 0 {
		return nil
	}
	for i := 0; i < w.Thread; i++ {
		w.wg.Add(1)
		go w.dispatch(i)
	}
	go w.Process()
	w.wg.Wait()
	time.Sleep(2 * time.Microsecond)

	// All workers are done: cancel the context so the writer stops, then
	// wait for it to report.
	var once sync.Once
	for w.writeEnable {
		once.Do(w.fn)
		time.Sleep(time.Millisecond * 50)
	}
	if w.writeError != nil {
		err = w.Redo.Save()
		return fmt.Errorf("write error: %w %v", w.writeError, err)
	}
	once.Do(w.fn)
	waitProcessIdle()

	r, err := w.ReverseRange()
	if err != nil {
		return err
	}
	if len(r) == 0 {
		// Nothing left to download: the redo file is no longer needed.
		return os.Remove(w.Tareget + ".bgrd")
	}
	return w.Redo.Save()
}
|
|
|
|
// dispatch is the body of download worker idx. Run starts one per thread and
// waits for all of them through w.wg.
//
// The worker first takes an initial range:
//   - on a fresh download, an equal 1/w.Thread slice of the file, with the
//     last slice absorbing the remainder;
//   - when resuming, the next outstanding range from w.lastUndoInfo.
//
// Afterwards it keeps taking work until none is left: further outstanding
// ranges first, then halves split off the busiest worker via RequestNewTask.
// A range whose download fails is retried from the downloader's recorded
// progress while bytes remain. It always returns nil.
func (w *Mget) dispatch(idx int) error {
	defer w.wg.Done()

	// fetch downloads [start, end]. It publishes the active downloader in
	// w.threads[idx] so RequestNewTask can split it, and retries from the
	// downloader's progress after a failure.
	fetch := func(start, end int64) {
		for {
			req := w.Clone()
			req.SetCookies(CloneCookies(w.Setting.Cookies()))
			d := &downloader{
				Request:    req,
				ch:         w.ch,
				ctx:        w.ctx,
				bufferSize: w.BufferSize,
				downloadinfo: &downloadinfo{
					Start: start,
					End:   end,
				},
			}
			w.threads[idx] = d
			err := d.Run()
			if err == nil {
				return
			}
			fmt.Printf("thread %d error: %v\n", idx, err)
			if d.Start >= d.End {
				return
			}
			start, end = d.Start, d.End
		}
	}

	var start, end int64
	hasInitial := true
	if len(w.lastUndoInfo) == 0 {
		count := w.TargetSize / int64(w.Thread)
		start = count * int64(idx)
		end = count*int64(idx+1) - 1
		if idx == w.Thread-1 {
			end = w.TargetSize - 1
		}
	} else {
		w.Lock()
		if len(w.lastUndoInfo) == 0 {
			hasInitial = false
		} else {
			start = int64(w.lastUndoInfo[0].Min)
			end = int64(w.lastUndoInfo[0].Max)
			w.lastUndoInfo = w.lastUndoInfo[1:]
		}
		w.Unlock()
	}

	// An empty slice (TargetSize < w.Thread) would send an invalid
	// "bytes=N-(N-1)" Range header. Without an initial range, the worker
	// instead publishes a zero-length placeholder with a non-nil
	// downloadinfo, so the code below and RequestNewTask can read and set its
	// Start/End.
	if hasInitial && start <= end {
		fetch(start, end)
	} else {
		w.threads[idx] = &downloader{downloadinfo: &downloadinfo{}}
	}

	for {
		w.Lock()
		cur := w.threads[idx]
		if len(w.lastUndoInfo) > 0 {
			// Always take the head of the queue, then drop it.
			cur.Start = int64(w.lastUndoInfo[0].Min)
			cur.End = int64(w.lastUndoInfo[0].Max)
			w.lastUndoInfo = w.lastUndoInfo[1:]
			w.Unlock()
		} else {
			w.Unlock()
			if !w.RequestNewTask(cur) {
				break
			}
		}
		fetch(cur.Start, cur.End)
	}
	return nil
}
|
|
|
|
func (w *Mget) getSleepTime() time.Duration {
|
|
if w.speedlimit == 0 {
|
|
return 0
|
|
}
|
|
return time.Nanosecond * time.Duration(16384*1000*1000*1000/w.speedlimit) / 2
|
|
|
|
}
|
|
// WriteServer is the single writer of the download. It receives Buffers from
// w.ch and writes each one at its Start offset in w.Tareget.
//
// The target file is opened as follows:
//   - on a fresh download, it is created via createFileWithSize, sized to
//     TargetSize;
//   - when resuming from a redo file, it is opened read-write so the bytes
//     already on disk are kept.
//
// Every written span is recorded with w.Update. Once at least RedoRPO bytes
// have been written since the last checkpoint, the file is synced and the
// redo state is saved in the background. When speedlimit is non-zero, writes
// are throttled to roughly speedlimit bytes per second.
//
// It returns nil when w.ctx is cancelled, or the first open, write or update
// error. On return it syncs and closes the target file and cancels w.ctx,
// which stops the downloaders.
func (w *Mget) WriteServer() error {
	var err error
	defer w.fn()
	if !w.isRedo {
		w.tf, err = createFileWithSize(w.Tareget, w.TargetSize)
	} else {
		w.tf, err = os.OpenFile(w.Tareget, os.O_RDWR, 0666)
	}
	if err != nil {
		return err
	}
	// Release the descriptor once writing is over. Nothing else writes to
	// the target file, so it is safe to close here.
	tf := w.tf
	defer func() {
		_ = tf.Sync()
		_ = tf.Close()
	}()

	lastUpdateRange := 0
	currentRange := 0

	currentCount := int64(0)
	lastDate := time.Now()
	lastCount := int64(0)
	// speedControl keeps the bytes written per one-second window at or below
	// w.speedlimit. It sleeps while the window's budget is exhausted, and
	// starts a new window once a second has elapsed.
	speedControl := func(count int) {
		if w.speedlimit == 0 {
			return
		}
		currentCount += int64(count)
		for {
			if time.Since(lastDate) >= time.Second {
				lastDate = time.Now()
				lastCount = currentCount
				return
			}
			if currentCount-lastCount <= w.speedlimit {
				return
			}
			time.Sleep(w.getSleepTime())
		}
	}

	for {
		select {
		case <-w.ctx.Done():
			return nil
		case b := <-w.ch:
			n, err := tf.WriteAt(b.Data, int64(b.Start))
			if err != nil {
				fmt.Println("write error:", err)
				return err
			}
			speedControl(n)
			if w.dynLength {
				// Unknown final size: grow the recorded length as data arrives.
				w.ContentLength += uint64(n)
			}
			currentRange += n
			end := b.Start + uint64(n) - 1
			if err = w.Update(int(b.Start), int(end)); err != nil {
				return err
			}
			// Checkpoint: bound how much progress a crash can lose.
			if currentRange-lastUpdateRange >= w.RedoRPO {
				tf.Sync()
				go w.Redo.Save()
				lastUpdateRange = currentRange
			}
		}
	}
}
|
|
|
|
// downloader fetches one byte range of the target and streams it to the
// writer.
type downloader struct {
	// Request is the HTTP request for this range; Run sets its Range header
	// and sends it.
	*starnet.Request
	// alive is true while the downloader is transferring.
	alive bool
	// ch receives the downloaded chunks (consumed by WriteServer).
	ch chan Buffer
	// ctx is handed to IOWriter; cancelling it presumably aborts the
	// transfer — confirm against IOWriter.
	ctx context.Context
	// state is set to 1 by RequestNewTask while it inspects and splits the
	// ranges, and reset to 0 afterwards. It is passed to IOWriter, presumably
	// as a pause flag — confirm.
	state uint32
	// bufferSize is the read buffer size passed to IOWriter.
	bufferSize int
	// downloadinfo holds the [Start, End] range and transfer progress.
	*downloadinfo
}
|
|
|
|
// Run fetches the inclusive byte range [d.Start, d.End] with a Range request.
// It then streams the response into d.ch via IOWriter, passing &d.Start and
// &d.End so progress is recorded on the downloader. If the server answers
// with a shorter range, d.End is narrowed to match.
//
// Run fails in these cases:
//   - the request fails;
//   - the response has no Content-Range header, or the header cannot be
//     parsed;
//   - the server's range does not begin at d.Start.
//
// d.alive is true only while Run executes. The response body is always
// closed before returning.
func (d *downloader) Run() error {
	d.alive = true
	defer func() {
		d.alive = false
	}()
	d.SetHeader("Range", fmt.Sprintf("bytes=%d-%d", d.Start, d.End))
	res, err := d.Do()
	if err != nil {
		return err
	}
	defer res.Body().Close()

	contentRange := res.Header.Get("Content-Range")
	if contentRange == "" {
		return fmt.Errorf("server not support range")
	}
	start, end, _, err := parseContentRange(contentRange)
	if err != nil {
		return fmt.Errorf("parse content range %q: %w", contentRange, err)
	}
	if d.Start != start {
		return fmt.Errorf("server not support range")
	}
	d.End = end
	d.downloadinfo = &downloadinfo{
		Start: d.Start,
		End:   d.End,
		Size:  d.End - d.Start + 1,
	}
	return IOWriter(d.ctx, d.ch, &d.state, d.downloadinfo, res.Body().Reader(), d.bufferSize, &d.Start, &d.End)
}
|
|
|
|
// RequestNewTask gives task new work. It finds the worker in w.threads with
// the largest remaining range and splits that range in two. The owner keeps
// the lower half and task receives the upper half.
//
// Everything runs under w's lock. Every worker's state is set to 1 while the
// ranges are inspected and reset to 0 on return.
//
// It reports false, leaving task untouched, when no range exists or the
// widest range's half is smaller than 2*BufferSize or 100 KiB.
func (w *Mget) RequestNewTask(task *downloader) bool {
	// Stop the world first.
	w.Lock()
	defer w.Unlock()
	flagAll := func(state uint32) {
		for _, t := range w.threads {
			if t != nil {
				atomic.StoreUint32(&t.state, state)
			}
		}
	}
	defer flagAll(0)
	flagAll(1)
	time.Sleep(time.Microsecond * 2)

	var widest *downloader
	for _, t := range w.threads {
		if t != nil && (widest == nil || t.End-t.Start > widest.End-widest.Start) {
			widest = t
		}
	}
	if widest == nil || widest.End <= widest.Start {
		return false
	}
	half := (widest.End - widest.Start) / 2
	if half < int64(w.BufferSize*2) || half < 100*1024 {
		return false
	}

	task.End = widest.End
	widest.End = widest.Start + (widest.End-widest.Start)/2
	task.Start = widest.End + 1
	return true
}
|
|
|
|
type downloadinfo struct {
|
|
Start int64
|
|
End int64
|
|
Size int64
|
|
current int64
|
|
lastCurrent int64
|
|
lastTime time.Time
|
|
speed float64
|
|
}
|
|
|
|
func (d *downloadinfo) Current() int64 {
|
|
return d.current
|
|
}
|
|
|
|
func (d *downloadinfo) Percent() float64 {
|
|
return float64(d.current) / float64(d.Size)
|
|
}
|
|
|
|
func (d *downloadinfo) FormatPercent() string {
|
|
return fmt.Sprintf("%.2f%%", d.Percent()*100)
|
|
}
|
|
|
|
func (d *downloadinfo) SetCurrent(info int64) {
|
|
d.current = info
|
|
now := time.Now()
|
|
if now.Sub(d.lastTime) >= time.Millisecond*500 {
|
|
d.speed = float64(d.current-d.lastCurrent) / (float64(now.Sub(d.lastTime).Milliseconds()) / 1000.00)
|
|
d.lastCurrent = d.current
|
|
d.lastTime = time.Now()
|
|
}
|
|
}
|
|
|
|
func (d *downloadinfo) AddCurrent(info int64) {
|
|
d.current += info
|
|
now := time.Now()
|
|
if now.Sub(d.lastTime) >= time.Millisecond*500 {
|
|
d.speed = float64(d.current-d.lastCurrent) / (float64(now.Sub(d.lastTime).Milliseconds()) / 1000.00)
|
|
d.lastCurrent = d.current
|
|
d.lastTime = time.Now()
|
|
}
|
|
}
|
|
|
|
func (d *downloadinfo) FormatSpeed(unit string) string {
|
|
switch strings.ToLower(unit) {
|
|
case "kb":
|
|
return fmt.Sprintf("%.2f KB/s", d.speed/1024)
|
|
case "mb":
|
|
return fmt.Sprintf("%.2f MB/s", d.speed/1024/1024)
|
|
case "gb":
|
|
return fmt.Sprintf("%.2f GB/s", d.speed/1024/1024/1024)
|
|
default:
|
|
return fmt.Sprintf("%.2f B/s", d.speed)
|
|
}
|
|
}
|
|
|
|
func (d *downloadinfo) Speed() float64 {
|
|
return d.speed
|
|
}
|