From d71eacdc91dc92805e01634eb9ad400a4eceb963 Mon Sep 17 00:00:00 2001 From: starainrt Date: Mon, 14 Mar 2022 15:43:56 +0800 Subject: [PATCH] optional function add --- curl.go | 220 ++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 183 insertions(+), 37 deletions(-) diff --git a/curl.go b/curl.go index 315cca6..a8b4d8b 100644 --- a/curl.go +++ b/curl.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/rand" + "crypto/tls" "errors" "fmt" "io" @@ -28,42 +29,184 @@ type RequestFile struct { UploadForm map[string]string UploadName string } + type Request struct { - TimeOut int - DialTimeOut int Url string Method string RecvData []byte RecvContentLength int64 - WriteRecvData bool RecvIo io.Writer - ReqHeader http.Header - ReqCookies []*http.Cookie RespHeader http.Header RespCookies []*http.Cookie + RespHttpCode int + CircleBuffer *stario.StarBuffer + respReader io.ReadCloser + RequestOpts +} + +type RequestOpts struct { RequestFile - RespHttpCode int - PostBuffer *bytes.Buffer - CircleBuffer *stario.StarBuffer - Proxy string - Process func(float64) - respReader io.ReadCloser + PostBuffer io.Reader + Process func(float64) + Proxy string + Timeout time.Duration + DialTimeout time.Duration + ReqHeader http.Header + ReqCookies []*http.Cookie + WriteRecvData bool + SkipTLSVerify bool + CustomTransport *http.Transport + Queries map[string]string +} + +type RequestOpt func(opt *RequestOpts) + +func WithDialTimeout(timeout time.Duration) RequestOpt { + return func(opt *RequestOpts) { + opt.DialTimeout = timeout + } +} + +func WithTimeout(timeout time.Duration) RequestOpt { + return func(opt *RequestOpts) { + opt.Timeout = timeout + } +} + +func WithHeader(key, val string) RequestOpt { + return func(opt *RequestOpts) { + opt.ReqHeader.Set(key, val) + } +} + +func WithHeaderMap(header map[string]string) RequestOpt { + return func(opt *RequestOpts) { + for key, val := range header { + opt.ReqHeader.Set(key, val) + } + } +} + +func WithHeaderAdd(key, val string) RequestOpt { + return func(opt *RequestOpts) { + opt.ReqHeader.Add(key, val) + } +} + +func WithReader(r io.Reader) RequestOpt { + return func(opt *RequestOpts) { + opt.PostBuffer = r + } +} + +func WithFetchRespBody(fetch bool) RequestOpt { + return func(opt *RequestOpts) { + opt.WriteRecvData = fetch + } +} + +func WithCookies(ck []*http.Cookie) RequestOpt { + return func(opt *RequestOpts) { + opt.ReqCookies = ck + } +} + +func WithCookie(key, val, path string) RequestOpt { + return func(opt *RequestOpts) { + opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path}) + } +} + +func WithCookieMap(header map[string]string, path string) RequestOpt { + return func(opt *RequestOpts) { + for key, val := range header { + opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path}) + } + } +} + +func WithQueries(queries map[string]string) RequestOpt { + return func(opt *RequestOpts) { + opt.Queries = queries + } +} + +func WithProxy(proxy string) RequestOpt { + return func(opt *RequestOpts) { + opt.Proxy = proxy + } +} + +func WithProcess(fn func(float64)) RequestOpt { + return func(opt *RequestOpts) { + opt.Process = fn + } } -func NewRequests(url string, postdata []byte, method string) Request { +func WithContentType(ct string) RequestOpt { + return func(opt *RequestOpts) { + opt.ReqHeader.Set("Content-Type", ct) + } +} + +func WithUserAgent(ua string) RequestOpt { + return func(opt *RequestOpts) { + opt.ReqHeader.Set("User-Agent", ua) + } +} + +func WithCustomTransport(hs *http.Transport) RequestOpt { + return func(opt *RequestOpts) { + opt.CustomTransport = hs + } +} + +func WithSkipTLSVerify(skip bool) RequestOpt { + return func(opt *RequestOpts) { + opt.SkipTLSVerify = skip + } +} + +func NewRequests(url string, rawdata []byte, method string, opts ...RequestOpt) Request { req := Request{ - TimeOut: 30, - DialTimeOut: 15, - Url: url, - PostBuffer: bytes.NewBuffer(postdata), - Method: method, - WriteRecvData: true, + RequestOpts: RequestOpts{ + Timeout: 30 * time.Second, + DialTimeout: 15 * time.Second, + WriteRecvData: true, + }, + Url: url, + Method: method, + } + if rawdata != nil { + req.PostBuffer = bytes.NewBuffer(rawdata) } req.ReqHeader = make(http.Header) if strings.ToUpper(method) == "POST" { req.ReqHeader.Set("Content-Type", HEADER_FORM_URLENCODE) } req.ReqHeader.Set("User-Agent", "B612 / 1.1.0") + for _, v := range opts { + v(&req.RequestOpts) + } + if req.CustomTransport == nil { + req.CustomTransport = &http.Transport{} + } + if req.SkipTLSVerify { + if req.CustomTransport.TLSClientConfig == nil { + req.CustomTransport.TLSClientConfig = &tls.Config{} + } + req.CustomTransport.TLSClientConfig.InsecureSkipVerify = true + } + req.CustomTransport.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { + c, err := net.DialTimeout(netw, addr, req.DialTimeout) + if err != nil { + return nil, err + } + if req.Timeout != 0 { + c.SetDeadline(time.Now().Add(req.Timeout)) + } + return c, nil + } return req } @@ -78,6 +221,9 @@ func (curl *Request) ResetReqCookies() { func (curl *Request) AddSimpleCookie(key, value string) { curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: "/"}) } +func (curl *Request) AddCookie(key, value, path string) { + curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: path}) +} func randomBoundary() string { var buf [30]byte @@ -146,9 +292,6 @@ func Curl(curl Request) (resps Request, err error) { if err != nil { return Request{}, err } - - curl.PostBuffer = nil - curl.CircleBuffer = nil curl.RespHttpCode = resp.StatusCode curl.RespHeader = resp.Header curl.RespCookies = resp.Cookies() @@ -211,13 +354,20 @@ func netcurl(curl Request) (*http.Response, error) { if curl.Method == "" { return nil, errors.New("Error Method Not Entered") } - if curl.PostBuffer != nil && curl.PostBuffer.Len() > 0 { + if curl.PostBuffer != nil { req, err = http.NewRequest(curl.Method, curl.Url, curl.PostBuffer) } else if curl.CircleBuffer != nil && curl.CircleBuffer.Len() > 0 { req, err = http.NewRequest(curl.Method, curl.Url, curl.CircleBuffer) } else { req, err = http.NewRequest(curl.Method, curl.Url, nil) } + if curl.Queries != nil { + sid := req.URL.Query() + for k, v := range curl.Queries { + sid.Add(k, v) + } + req.URL.RawQuery = sid.Encode() + } if err != nil { return nil, err } @@ -227,27 +377,15 @@ func netcurl(curl Request) (*http.Response, error) { req.AddCookie(v) } } - transport := &http.Transport{ - DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) { - c, err := net.DialTimeout(netw, addr, time.Second*time.Duration(curl.DialTimeOut)) - if err != nil { - return nil, err - } - if curl.TimeOut != 0 { - c.SetDeadline(time.Now().Add(time.Duration(curl.TimeOut) * time.Second)) - } - return c, nil - }, - } if curl.Proxy != "" { purl, err := url.Parse(curl.Proxy) if err != nil { return nil, err } - transport.Proxy = http.ProxyURL(purl) + curl.CustomTransport.Proxy = http.ProxyURL(purl) } client := &http.Client{ - Transport: transport, + Transport: curl.CustomTransport, } resp, err := client.Do(req) return resp, err @@ -266,10 +404,18 @@ func UrlDecode(str string) (string, error) { return url.QueryUnescape(str) } -func Build_Query(queryData map[string]string) string { +func BuildQuery(queryData map[string]string) string { query := url.Values{} for k, v := range queryData { query.Add(k, v) } return query.Encode() } + +func BuildPostForm(queryMap map[string]string) []byte { + query := url.Values{} + for k, v := range queryMap { + query.Add(k, v) + } + return []byte(query.Encode()) +}