star/mget/cmd.go
2024-09-15 15:27:50 +08:00

131 lines
3.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mget
import (
"b612.me/stario"
"b612.me/starlog"
"fmt"
"github.com/spf13/cobra"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"time"
)
var (
	// mg is the download job that the flags registered in init configure and
	// that Run executes.
	mg Mget

	// Cmd is the cobra command exposing the multi-threaded downloader.
	Cmd = &cobra.Command{
		Use:   "mget",
		Short: "多线程下载工具",
		Long:  `多线程下载工具`,
		Run:   Run,
	}

	// Flag values that are not bound directly onto mg; Run reads them and
	// applies them to mg before starting the download.
	headers      []string // -H/--header, each expected as "key=value"
	ua           string   // -U/--user-agent
	proxy        string   // -p/--proxy
	skipVerify   bool     // -k/--skip-verify
	speedcontrol string   // -S/--speed, e.g. "1M"; parsed by parseSpeedString
)
// init registers the command-line flags of Cmd. Options describing the
// download job itself are bound straight onto fields of mg; request/transport
// options are bound to package-level variables that Run applies later.
func init() {
	flags := Cmd.Flags()

	// Download job parameters, stored directly on mg.
	flags.StringVarP(&mg.Tareget, "output", "o", "", "输出文件名")
	flags.IntVarP(&mg.BufferSize, "buffer", "b", 8192, "缓冲区大小")
	flags.IntVarP(&mg.Thread, "thread", "t", 8, "线程数")
	flags.IntVarP(&mg.RedoRPO, "safe", "s", 1<<20, "安全校验点") // default 1 MiB

	// Request/transport options, applied to mg by Run.
	flags.StringSliceVarP(&headers, "header", "H", []string{}, "自定义请求头,格式: key=value")
	flags.StringVarP(&proxy, "proxy", "p", "", "代理地址")
	flags.StringVarP(&ua, "user-agent", "U", "", "自定义User-Agent")
	flags.BoolVarP(&skipVerify, "skip-verify", "k", false, "跳过SSL验证")
	flags.StringVarP(&speedcontrol, "speed", "S", "", "限速如1M意思是1MB/s")
}
// speedStringPattern matches a transfer-rate string such as "512", "1.5M",
// "10 KiB/s" or "2gb/s": a decimal number, an optional binary unit and an
// optional "/s" suffix, with surrounding whitespace allowed. It is compiled
// once here instead of on every parseSpeedString call.
var speedStringPattern = regexp.MustCompile(`(?i)^\s*([\d.]+)\s*(b|k|m|g|t|kb|mb|gb|tb|kib|mib|gib|tib)?\s*/?\s*s?\s*$`)

// speedUnitBytes maps every lower-case unit accepted by speedStringPattern
// (including the empty unit) to its size in bytes. All units are binary
// (powers of 1024). The values are uint64 so the terabyte entries stay
// representable on 32-bit platforms, where int is only 32 bits wide.
var speedUnitBytes = map[string]uint64{
	"": 1, "b": 1,
	"k": 1 << 10, "kb": 1 << 10, "kib": 1 << 10,
	"m": 1 << 20, "mb": 1 << 20, "mib": 1 << 20,
	"g": 1 << 30, "gb": 1 << 30, "gib": 1 << 30,
	"t": 1 << 40, "tb": 1 << 40, "tib": 1 << 40,
}

// parseSpeedString converts a human-readable rate such as "1M", "1.5 KiB/s"
// or "512" into bytes per second. Units are case-insensitive and binary
// (1k = 1024 bytes); a missing unit means bytes. Fractional values are
// truncated toward zero after scaling.
//
// It returns an error when the string does not match the expected format,
// when the numeric part is not a valid number (e.g. "1.2.3"), or when the
// resulting rate does not fit in an int64.
func parseSpeedString(speedString string) (uint64, error) {
	matches := speedStringPattern.FindStringSubmatch(strings.ToLower(speedString))
	if matches == nil {
		return 0, fmt.Errorf("invalid speed string format")
	}
	// The pattern admits strings like "1.2.3" or ".", which only fail here.
	value, err := strconv.ParseFloat(matches[1], 64)
	if err != nil {
		return 0, fmt.Errorf("invalid numeric value")
	}
	// Unreachable while the pattern and the table agree; kept as a guard in
	// case one of them is edited without the other.
	multiplier, ok := speedUnitBytes[matches[2]]
	if !ok {
		return 0, fmt.Errorf("invalid unit in speed string")
	}
	bytesPerSec := value * float64(multiplier)
	// Converting an out-of-range float64 to an integer is implementation-
	// defined in Go, and the limit is ultimately narrowed to int64, so refuse
	// anything that would not fit instead of returning a garbage value.
	if bytesPerSec >= 1<<63 {
		return 0, fmt.Errorf("speed value out of range")
	}
	return uint64(bytesPerSec), nil
}
// Run is the cobra entry point for the mget command. The first positional
// argument is the URL to download. Run applies the flag values registered in
// init to mg, starts the download and waits for it to finish or for the user
// to interrupt it.
//
// Exit codes: 1 for a missing URL or an invalid --speed value, 2 when the
// download reports an error, 3 after a user interrupt (os.Interrupt).
func Run(cmd *cobra.Command, args []string) {
	if len(args) == 0 {
		starlog.Errorln("缺少URL参数")
		os.Exit(1)
	}
	if speedcontrol != "" {
		speed, err := parseSpeedString(speedcontrol)
		if err != nil {
			starlog.Criticalln("Speed Limit Error:", err)
			os.Exit(1)
		}
		mg.speedlimit = int64(speed)
		fmt.Printf("Max Speed Limit:(user in):\t%v\n", speedcontrol)
		fmt.Printf("Max Speed Limit (bytes/s):\t%v bytes/sec\n", speed)
	}
	// Each -H value is "key=value" (split on the first '='); entries without
	// an '=' are skipped.
	for _, v := range headers {
		kv := strings.SplitN(v, "=", 2)
		if len(kv) != 2 {
			continue
		}
		mg.Setting.AddHeader(strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1]))
	}
	if ua != "" {
		mg.Setting.SetUserAgent(ua)
	}
	if proxy != "" {
		mg.Setting.SetProxy(proxy)
	}
	if skipVerify {
		mg.Setting.SetSkipTLSVerify(true)
	}
	mg.OriginUri = args[0]

	// signal.Notify never blocks when delivering a signal, so with an
	// unbuffered channel an interrupt arriving while nobody is receiving is
	// silently dropped. One slot suffices: only the first interrupt matters.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt)
	defer signal.Stop(sig)

	select {
	// NOTE(review): presumably WaitUntilFinished runs mg.Run in the background
	// and delivers its result on the returned channel — confirm in stario.
	case err := <-stario.WaitUntilFinished(mg.Run):
		if err != nil {
			starlog.Errorln(err)
			os.Exit(2)
		}
		time.Sleep(time.Second)
		return
	case <-sig:
		starlog.Infoln("User Interrupted")
		// NOTE(review): mg.fn appears to stop the in-flight download, and the
		// one-second sleep presumably lets workers settle before
		// mg.Redo.Save() persists the resume state — confirm in Mget.
		mg.fn()
		time.Sleep(time.Second)
		mg.Redo.Save()
		os.Exit(3)
	}
}