Compare commits
9 Commits
v2.1.0.bet
...
master
Author | SHA1 | Date |
---|---|---|
兔子 | 4fd9c37944 | 2 weeks ago |
兔子 | 8b0a3483a0 | 3 months ago |
兔子 | 7d85b14d60 | 3 months ago |
兔子 | 60a73eb0d8 | 3 months ago |
兔子 | 23e4443cf4 | 3 months ago |
兔子 | c7d94d8f95 | 3 months ago |
兔子 | 1a04780474 | 3 months ago |
兔子 | a227ee9e6c | 5 months ago |
兔子 | 9fc353211f | 5 months ago |
@ -0,0 +1,160 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"b612.me/starcrypto"
|
||||
"b612.me/starlog"
|
||||
"fmt"
|
||||
"github.com/go-acme/lego/v4/certcrypto"
|
||||
"github.com/go-acme/lego/v4/challenge/http01"
|
||||
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/go-acme/lego/v4/providers/dns/acmedns"
|
||||
"github.com/go-acme/lego/v4/providers/dns/alidns"
|
||||
"github.com/go-acme/lego/v4/providers/dns/azuredns"
|
||||
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
|
||||
"github.com/go-acme/lego/v4/providers/dns/tencentcloud"
|
||||
"os"
|
||||
)
|
||||
|
||||
func run(a Acme) error {
|
||||
|
||||
// Create a user. New accounts need an email and private key to start.
|
||||
if a.KeyPath != "" {
|
||||
data, err := os.ReadFile(a.KeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read key file error:%w", err)
|
||||
}
|
||||
a.key, err = starcrypto.DecodePrivateKey(data, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode key error:%w", err)
|
||||
}
|
||||
}
|
||||
for _, req := range a.CertReqs {
|
||||
starlog.Info("request cert for %v", req.Domains)
|
||||
config := lego.NewConfig(&a)
|
||||
// This CA URL is configured for a local dev instance of Boulder running in Docker in a VM.
|
||||
config.CADirURL = "https://acme-v02.api.letsencrypt.org/directory"
|
||||
switch req.KeyType {
|
||||
case "rsa2048":
|
||||
config.Certificate.KeyType = certcrypto.RSA2048
|
||||
case "rsa4096":
|
||||
config.Certificate.KeyType = certcrypto.RSA4096
|
||||
case "rsa8192":
|
||||
config.Certificate.KeyType = certcrypto.RSA8192
|
||||
case "ec256":
|
||||
config.Certificate.KeyType = certcrypto.EC256
|
||||
case "ec384":
|
||||
config.Certificate.KeyType = certcrypto.EC384
|
||||
default:
|
||||
config.Certificate.KeyType = certcrypto.EC384
|
||||
}
|
||||
|
||||
// A client facilitates communication with the CA server.
|
||||
client, err := lego.NewClient(config)
|
||||
if err != nil {
|
||||
starlog.Errorf("new client error:%v", err)
|
||||
return fmt.Errorf("new client error:%w", err)
|
||||
}
|
||||
p := a.DnsPrivders[req.PrivderName]
|
||||
switch p.Type {
|
||||
case "http":
|
||||
err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", p.KeyID))
|
||||
if err != nil {
|
||||
starlog.Errorf("set http provider error:%v", err)
|
||||
return fmt.Errorf("set http provider error:%w", err)
|
||||
}
|
||||
err = client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", p.KeySecret))
|
||||
if err != nil {
|
||||
starlog.Errorf("set tlsalpn provider error:%v", err)
|
||||
return fmt.Errorf("set tlsalpn provider error:%w", err)
|
||||
}
|
||||
case "tencent":
|
||||
cfg := tencentcloud.NewDefaultConfig()
|
||||
cfg.SecretID = p.KeyID
|
||||
cfg.SecretKey = p.KeySecret
|
||||
dnsSet, err := tencentcloud.NewDNSProviderConfig(cfg)
|
||||
if err != nil {
|
||||
starlog.Errorf("new dns provider error:%v", err)
|
||||
return fmt.Errorf("new dns provider error:%w", err)
|
||||
}
|
||||
err = client.Challenge.SetDNS01Provider(dnsSet)
|
||||
if err != nil {
|
||||
starlog.Errorf("set dns provider error:%v", err)
|
||||
return fmt.Errorf("set dns provider error:%w", err)
|
||||
}
|
||||
case "cloudflare":
|
||||
cfg := cloudflare.NewDefaultConfig()
|
||||
cfg.AuthKey = p.KeySecret
|
||||
cfg.AuthEmail = p.KeyID
|
||||
dnsSet, err := cloudflare.NewDNSProviderConfig(cfg)
|
||||
if err != nil {
|
||||
starlog.Errorf("new dns provider error:%v", err)
|
||||
return fmt.Errorf("new dns provider error:%w", err)
|
||||
}
|
||||
err = client.Challenge.SetDNS01Provider(dnsSet)
|
||||
if err != nil {
|
||||
starlog.Errorf("set dns provider error:%v", err)
|
||||
return fmt.Errorf("set dns provider error:%w", err)
|
||||
}
|
||||
case "alidns":
|
||||
cfg := alidns.NewDefaultConfig()
|
||||
cfg.APIKey = p.KeyID
|
||||
cfg.SecretKey = p.KeySecret
|
||||
dnsSet, err := alidns.NewDNSProviderConfig(cfg)
|
||||
if err != nil {
|
||||
starlog.Errorf("new dns provider error:%v", err)
|
||||
return fmt.Errorf("new dns provider error:%w", err)
|
||||
}
|
||||
err = client.Challenge.SetDNS01Provider(dnsSet)
|
||||
if err != nil {
|
||||
starlog.Errorf("set dns provider error:%v", err)
|
||||
return fmt.Errorf("set dns provider error:%w", err)
|
||||
}
|
||||
case "azure":
|
||||
cfg := azuredns.NewDefaultConfig()
|
||||
cfg.ClientID = p.KeyID
|
||||
cfg.ClientSecret = p.KeySecret
|
||||
dnsSet, err := azuredns.NewDNSProviderConfig(cfg)
|
||||
if err != nil {
|
||||
starlog.Errorf("new dns provider error:%v", err)
|
||||
return fmt.Errorf("new dns provider error:%w", err)
|
||||
}
|
||||
err = client.Challenge.SetDNS01Provider(dnsSet)
|
||||
if err != nil {
|
||||
starlog.Errorf("set dns provider error:%v", err)
|
||||
return fmt.Errorf("set dns provider error:%w", err)
|
||||
}
|
||||
default:
|
||||
cfg, _ := acmedns.NewDNSProvider()
|
||||
err = client.Challenge.SetDNS01Provider(cfg)
|
||||
if err != nil {
|
||||
starlog.Errorf("set dns provider error:%v", err)
|
||||
return fmt.Errorf("set dns provider error:%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
// New users will need to register
|
||||
reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
a.Registration = reg
|
||||
|
||||
request := certificate.ObtainRequest{
|
||||
Domains: []string{"mydomain.com"},
|
||||
Bundle: true,
|
||||
}
|
||||
certificates, err := client.Certificate.Obtain(request)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Each certificate comes back with the cert bytes, the bytes of the client's
|
||||
// private key, and a certificate URL. SAVE THESE TO DISK.
|
||||
fmt.Printf("%#v\n", certificates)
|
||||
|
||||
*/
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
)
|
||||
|
||||
type DnsProvider struct {
|
||||
Name string
|
||||
Type string
|
||||
KeyID string
|
||||
KeySecret string
|
||||
}
|
||||
|
||||
type CertReq struct {
|
||||
Domains []string
|
||||
Type string
|
||||
PrivderName string
|
||||
KeyType string
|
||||
SaveFolder string
|
||||
SaveName string
|
||||
}
|
||||
|
||||
// You'll need a user or account type that implements acme.User
|
||||
type Acme struct {
|
||||
Email string
|
||||
KeyPath string
|
||||
SaveFolder string
|
||||
CertReqs map[string]CertReq
|
||||
DnsPrivders map[string]DnsProvider
|
||||
Registration *registration.Resource
|
||||
key crypto.PrivateKey
|
||||
}
|
||||
|
||||
func (u *Acme) GetEmail() string {
|
||||
return u.Email
|
||||
}
|
||||
func (u Acme) GetRegistration() *registration.Resource {
|
||||
return u.Registration
|
||||
}
|
||||
func (u *Acme) GetPrivateKey() crypto.PrivateKey {
|
||||
return u.key
|
||||
}
|
@ -1,47 +1,83 @@
|
||||
module b612.me/apps/b612
|
||||
|
||||
go 1.19
|
||||
go 1.21.2
|
||||
|
||||
toolchain go1.22.4
|
||||
|
||||
require (
|
||||
b612.me/notify v1.2.5
|
||||
b612.me/notify v1.2.6
|
||||
b612.me/sdk/whois v0.0.0-20240816133027-129514a15991
|
||||
b612.me/starcrypto v0.0.5
|
||||
b612.me/stario v0.0.9
|
||||
b612.me/starlog v1.3.3
|
||||
b612.me/starnet v0.1.8
|
||||
b612.me/staros v1.1.7
|
||||
b612.me/stario v0.0.10
|
||||
b612.me/starlog v1.3.4
|
||||
b612.me/starmap v1.2.4
|
||||
b612.me/starnet v0.2.1
|
||||
b612.me/staros v1.1.8
|
||||
b612.me/starssh v0.0.2
|
||||
b612.me/startext v0.0.0-20220314043758-22c6d5e5b1cd
|
||||
b612.me/wincmd v0.0.3
|
||||
b612.me/wincmd v0.0.4
|
||||
github.com/elazarl/goproxy v0.0.0-20231117061959-7cc037d33fb5
|
||||
github.com/elazarl/goproxy/ext v0.0.0-20190711103511-473e67f1d7d2
|
||||
github.com/emersion/go-smtp v0.20.2
|
||||
github.com/go-acme/lego/v4 v4.16.1
|
||||
github.com/goftp/file-driver v0.0.0-20180502053751-5d604a0fc0c9
|
||||
github.com/goftp/server v0.0.0-20200708154336-f64f7c2d8a42
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
||||
github.com/huin/goupnp v1.3.0
|
||||
github.com/inconshreveable/mousetrap v1.1.0
|
||||
github.com/likexian/whois v1.15.1
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/spf13/cobra v1.8.0
|
||||
github.com/things-go/go-socks5 v0.0.5
|
||||
golang.org/x/crypto v0.21.0
|
||||
github.com/vbauerster/mpb/v8 v8.8.3
|
||||
golang.org/x/crypto v0.26.0
|
||||
golang.org/x/net v0.28.0
|
||||
golang.org/x/sys v0.24.0
|
||||
software.sslmate.com/src/go-pkcs12 v0.4.0
|
||||
|
||||
)
|
||||
|
||||
require (
|
||||
b612.me/starmap v1.2.4 // indirect
|
||||
b612.me/win32api v0.0.2 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns v1.1.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.1.0 // indirect
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 // indirect
|
||||
github.com/VividCortex/ewma v1.2.0 // indirect
|
||||
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect
|
||||
github.com/aliyun/alibaba-cloud-sdk-go v1.61.1755 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
|
||||
github.com/cloudflare/cloudflare-go v0.86.0 // indirect
|
||||
github.com/cpu/goacmedns v0.1.1 // indirect
|
||||
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.0.1 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/uuid v1.3.1 // indirect
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||
github.com/hashicorp/go-retryablehttp v0.7.5 // indirect
|
||||
github.com/jlaffaye/ftp v0.1.0 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
|
||||
github.com/pkg/sftp v1.13.4 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.490 // indirect
|
||||
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.490 // indirect
|
||||
golang.org/x/image v0.6.0 // indirect
|
||||
golang.org/x/mod v0.14.0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/term v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/tools v0.17.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/sync v0.8.0 // indirect
|
||||
golang.org/x/term v0.23.0 // indirect
|
||||
golang.org/x/text v0.17.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
@ -0,0 +1,852 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// HTTP reverse proxy handler
|
||||
|
||||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
func lower(b byte) byte {
|
||||
if 'A' <= b && b <= 'Z' {
|
||||
return b + ('a' - 'A')
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func EqualFold(s, t string) bool {
|
||||
if len(s) != len(t) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(s); i++ {
|
||||
if lower(s[i]) != lower(t[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func IsPrint(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] < ' ' || s[i] > '~' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// A ProxyRequest contains a request to be rewritten by a ReverseProxy.
|
||||
type ProxyRequest struct {
|
||||
// In is the request received by the proxy.
|
||||
// The Rewrite function must not modify In.
|
||||
In *http.Request
|
||||
|
||||
// Out is the request which will be sent by the proxy.
|
||||
// The Rewrite function may modify or replace this request.
|
||||
// Hop-by-hop headers are removed from this request
|
||||
// before Rewrite is called.
|
||||
Out *http.Request
|
||||
}
|
||||
|
||||
// SetURL routes the outbound request to the scheme, host, and base path
|
||||
// provided in target. If the target's path is "/base" and the incoming
|
||||
// request was for "/dir", the target request will be for "/base/dir".
|
||||
//
|
||||
// SetURL rewrites the outbound Host header to match the target's host.
|
||||
// To preserve the inbound request's Host header (the default behavior
|
||||
// of NewSingleHostReverseProxy):
|
||||
//
|
||||
// rewriteFunc := func(r *httputil.ProxyRequest) {
|
||||
// r.SetURL(url)
|
||||
// r.Out.Host = r.In.Host
|
||||
// }
|
||||
func (r *ProxyRequest) SetURL(target *url.URL) {
|
||||
rewriteRequestURL(r.Out, target)
|
||||
r.Out.Host = ""
|
||||
}
|
||||
|
||||
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
|
||||
// X-Forwarded-Proto headers of the outbound request.
|
||||
//
|
||||
// - The X-Forwarded-For header is set to the client IP address.
|
||||
// - The X-Forwarded-Host header is set to the host name requested
|
||||
// by the client.
|
||||
// - The X-Forwarded-Proto header is set to "http" or "https", depending
|
||||
// on whether the inbound request was made on a TLS-enabled connection.
|
||||
//
|
||||
// If the outbound request contains an existing X-Forwarded-For header,
|
||||
// SetXForwarded appends the client IP address to it. To append to the
|
||||
// inbound request's X-Forwarded-For header (the default behavior of
|
||||
// ReverseProxy when using a Director function), copy the header
|
||||
// from the inbound request before calling SetXForwarded:
|
||||
//
|
||||
// rewriteFunc := func(r *httputil.ProxyRequest) {
|
||||
// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||
// r.SetXForwarded()
|
||||
// }
|
||||
func (r *ProxyRequest) SetXForwarded() {
|
||||
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
|
||||
if err == nil {
|
||||
prior := r.Out.Header["X-Forwarded-For"]
|
||||
if len(prior) > 0 {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
} else {
|
||||
r.Out.Header.Del("X-Forwarded-For")
|
||||
}
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
if r.In.TLS == nil {
|
||||
r.Out.Header.Set("X-Forwarded-Proto", "http")
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-Proto", "https")
|
||||
}
|
||||
}
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client.
|
||||
//
|
||||
// 1xx responses are forwarded to the client if the underlying
|
||||
// transport supports ClientTrace.Got1xxResponse.
|
||||
type ReverseProxy struct {
|
||||
// Rewrite must be a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
// Rewrite must not access the provided ProxyRequest
|
||||
// or its contents after returning.
|
||||
//
|
||||
// The Forwarded, X-Forwarded, X-Forwarded-Host,
|
||||
// and X-Forwarded-Proto headers are removed from the
|
||||
// outbound request before Rewrite is called. See also
|
||||
// the ProxyRequest.SetXForwarded method.
|
||||
//
|
||||
// Unparsable query parameters are removed from the
|
||||
// outbound request before Rewrite is called.
|
||||
// The Rewrite function may copy the inbound URL's
|
||||
// RawQuery to the outbound URL to preserve the original
|
||||
// parameter string. Note that this can lead to security
|
||||
// issues if the proxy's interpretation of query parameters
|
||||
// does not match that of the downstream server.
|
||||
//
|
||||
// At most one of Rewrite or Director may be set.
|
||||
Rewrite func(*ProxyRequest)
|
||||
|
||||
// Director is a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
// Director must not access the provided Request
|
||||
// after returning.
|
||||
//
|
||||
// By default, the X-Forwarded-For header is set to the
|
||||
// value of the client IP address. If an X-Forwarded-For
|
||||
// header already exists, the client IP is appended to the
|
||||
// existing values. As a special case, if the header
|
||||
// exists in the Request.Header map but has a nil value
|
||||
// (such as when set by the Director func), the X-Forwarded-For
|
||||
// header is not modified.
|
||||
//
|
||||
// To prevent IP spoofing, be sure to delete any pre-existing
|
||||
// X-Forwarded-For header coming from the client or
|
||||
// an untrusted proxy.
|
||||
//
|
||||
// Hop-by-hop headers are removed from the request after
|
||||
// Director returns, which can remove headers added by
|
||||
// Director. Use a Rewrite function instead to ensure
|
||||
// modifications to the request are preserved.
|
||||
//
|
||||
// Unparsable query parameters are removed from the outbound
|
||||
// request if Request.Form is set after Director returns.
|
||||
//
|
||||
// At most one of Rewrite or Director may be set.
|
||||
Director func(*http.Request)
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// FlushInterval specifies the flush interval
|
||||
// to flush to the client while copying the
|
||||
// response body.
|
||||
// If zero, no periodic flushing is done.
|
||||
// A negative value means to flush immediately
|
||||
// after each write to the client.
|
||||
// The FlushInterval is ignored when ReverseProxy
|
||||
// recognizes a response as a streaming response, or
|
||||
// if its ContentLength is -1; for such responses, writes
|
||||
// are flushed to the client immediately.
|
||||
FlushInterval time.Duration
|
||||
|
||||
// ErrorLog specifies an optional logger for errors
|
||||
// that occur when attempting to proxy the request.
|
||||
// If nil, logging is done via the log package's standard logger.
|
||||
ErrorLog *log.Logger
|
||||
|
||||
// BufferPool optionally specifies a buffer pool to
|
||||
// get byte slices for use by io.CopyBuffer when
|
||||
// copying HTTP response bodies.
|
||||
BufferPool BufferPool
|
||||
|
||||
// ModifyResponse is an optional function that modifies the
|
||||
// Response from the backend. It is called if the backend
|
||||
// returns a response at all, with any HTTP status code.
|
||||
// If the backend is unreachable, the optional ErrorHandler is
|
||||
// called without any call to ModifyResponse.
|
||||
//
|
||||
// If ModifyResponse returns an error, ErrorHandler is called
|
||||
// with its error value. If ErrorHandler is nil, its default
|
||||
// implementation is used.
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
||||
// ErrorHandler is an optional function that handles errors
|
||||
// reaching the backend or errors from ModifyResponse.
|
||||
//
|
||||
// If nil, the default is to log the provided error and return
|
||||
// a 502 Status Bad Gateway response.
|
||||
ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||
}
|
||||
|
||||
// A BufferPool is an interface for getting and returning temporary
|
||||
// byte slices for use by io.CopyBuffer.
|
||||
type BufferPool interface {
|
||||
Get() []byte
|
||||
Put([]byte)
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||
if a.RawPath == "" && b.RawPath == "" {
|
||||
return singleJoiningSlash(a.Path, b.Path), ""
|
||||
}
|
||||
// Same as singleJoiningSlash, but uses EscapedPath to determine
|
||||
// whether a slash should be added
|
||||
apath := a.EscapedPath()
|
||||
bpath := b.EscapedPath()
|
||||
|
||||
aslash := strings.HasSuffix(apath, "/")
|
||||
bslash := strings.HasPrefix(bpath, "/")
|
||||
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a.Path + b.Path[1:], apath + bpath[1:]
|
||||
case !aslash && !bslash:
|
||||
return a.Path + "/" + b.Path, apath + "/" + bpath
|
||||
}
|
||||
return a.Path + b.Path, apath + bpath
|
||||
}
|
||||
|
||||
// NewSingleHostReverseProxy returns a new ReverseProxy that routes
|
||||
// URLs to the scheme, host, and base path provided in target. If the
|
||||
// target's path is "/base" and the incoming request was for "/dir",
|
||||
// the target request will be for /base/dir.
|
||||
//
|
||||
// NewSingleHostReverseProxy does not rewrite the Host header.
|
||||
//
|
||||
// To customize the ReverseProxy behavior beyond what
|
||||
// NewSingleHostReverseProxy provides, use ReverseProxy directly
|
||||
// with a Rewrite function. The ProxyRequest SetURL method
|
||||
// may be used to route the outbound request. (Note that SetURL,
|
||||
// unlike NewSingleHostReverseProxy, rewrites the Host header
|
||||
// of the outbound request by default.)
|
||||
//
|
||||
// proxy := &ReverseProxy{
|
||||
// Rewrite: func(r *ProxyRequest) {
|
||||
// r.SetURL(target)
|
||||
// r.Out.Host = r.In.Host // if desired
|
||||
// }
|
||||
// }
|
||||
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
|
||||
director := func(req *http.Request) {
|
||||
rewriteRequestURL(req, target)
|
||||
}
|
||||
return &ReverseProxy{Director: director}
|
||||
}
|
||||
|
||||
func rewriteRequestURL(req *http.Request, target *url.URL) {
|
||||
targetQuery := target.RawQuery
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
|
||||
if p.ErrorHandler != nil {
|
||||
return p.ErrorHandler
|
||||
}
|
||||
return p.defaultErrorHandler
|
||||
}
|
||||
|
||||
// modifyResponse conditionally runs the optional ModifyResponse hook
|
||||
// and reports whether the request should proceed.
|
||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
|
||||
if p.ModifyResponse == nil {
|
||||
return true
|
||||
}
|
||||
if err := p.ModifyResponse(res); err != nil {
|
||||
res.Body.Close()
|
||||
p.getErrorHandler()(rw, req, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
if ctx.Done() != nil {
|
||||
// CloseNotifier predates context.Context, and has been
|
||||
// entirely superseded by it. If the request contains
|
||||
// a Context that carries a cancellation signal, don't
|
||||
// bother spinning up a goroutine to watch the CloseNotify
|
||||
// channel (if any).
|
||||
//
|
||||
// If the request Context has a nil Done channel (which
|
||||
// means it is either context.Background, or a custom
|
||||
// Context implementation with no cancellation signal),
|
||||
// then consult the CloseNotifier if available.
|
||||
} else if cn, ok := rw.(http.CloseNotifier); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
notifyChan := cn.CloseNotify()
|
||||
go func() {
|
||||
select {
|
||||
case <-notifyChan:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
outreq := req.Clone(ctx)
|
||||
if req.ContentLength == 0 {
|
||||
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
|
||||
}
|
||||
if outreq.Body != nil {
|
||||
// Reading from the request body after returning from a handler is not
|
||||
// allowed, and the RoundTrip goroutine that reads the Body can outlive
|
||||
// this handler. This can lead to a crash if the handler panics (see
|
||||
// Issue 46866). Although calling Close doesn't guarantee there isn't
|
||||
// any Read in flight after the handle returns, in practice it's safe to
|
||||
// read after closing it.
|
||||
defer outreq.Body.Close()
|
||||
}
|
||||
if outreq.Header == nil {
|
||||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||
}
|
||||
|
||||
if (p.Director != nil) == (p.Rewrite != nil) {
|
||||
p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
|
||||
return
|
||||
}
|
||||
|
||||
if p.Director != nil {
|
||||
p.Director(outreq)
|
||||
if outreq.Form != nil {
|
||||
outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := upgradeType(outreq.Header)
|
||||
if !IsPrint(reqUpType) {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
||||
return
|
||||
}
|
||||
removeHopByHopHeaders(outreq.Header)
|
||||
|
||||
// Issue 21096: tell backend applications that care about trailer support
|
||||
// that we support trailers. (We do, but we don't go out of our way to
|
||||
// advertise that unless the incoming client request thought it was worth
|
||||
// mentioning.) Note that we look at req.Header, not outreq.Header, since
|
||||
// the latter has passed through removeHopByHopHeaders.
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||
outreq.Header.Set("Te", "trailers")
|
||||
}
|
||||
|
||||
// After stripping all the hop-by-hop connection headers above, add back any
|
||||
// necessary for protocol upgrades, such as for websockets.
|
||||
if reqUpType != "" {
|
||||
outreq.Header.Set("Connection", "Upgrade")
|
||||
outreq.Header.Set("Upgrade", reqUpType)
|
||||
}
|
||||
|
||||
if p.Rewrite != nil {
|
||||
// Strip client-provided forwarding headers.
|
||||
// The Rewrite func may use SetXForwarded to set new values
|
||||
// for these or copy the previous values from the inbound request.
|
||||
outreq.Header.Del("Forwarded")
|
||||
outreq.Header.Del("X-Forwarded-For")
|
||||
outreq.Header.Del("X-Forwarded-Host")
|
||||
outreq.Header.Del("X-Forwarded-Proto")
|
||||
|
||||
// Remove unparsable query parameters from the outbound request.
|
||||
outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
|
||||
|
||||
pr := &ProxyRequest{
|
||||
In: req,
|
||||
Out: outreq,
|
||||
}
|
||||
p.Rewrite(pr)
|
||||
outreq = pr.Out
|
||||
}
|
||||
|
||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||
// If the outbound request doesn't have a User-Agent header set,
|
||||
// don't send the default Go HTTP client User-Agent.
|
||||
outreq.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
trace := &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
h := rw.Header()
|
||||
copyHeader(h, http.Header(header))
|
||||
rw.WriteHeader(code)
|
||||
|
||||
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
p.getErrorHandler()(rw, outreq, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
if !p.modifyResponse(rw, res, outreq) {
|
||||
return
|
||||
}
|
||||
p.handleUpgradeResponse(rw, outreq, res)
|
||||
return
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(res.Header)
|
||||
|
||||
if !p.modifyResponse(rw, res, outreq) {
|
||||
return
|
||||
}
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
// The "Trailer" header isn't included in the Transport's response,
|
||||
// at least for *http.Transport. Build it up from Trailer.
|
||||
announcedTrailers := len(res.Trailer)
|
||||
if announcedTrailers > 0 {
|
||||
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||
for k := range res.Trailer {
|
||||
trailerKeys = append(trailerKeys, k)
|
||||
}
|
||||
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||
}
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
|
||||
err = p.copyResponse(rw, res.Body, p.flushInterval(res))
|
||||
if err != nil {
|
||||
defer res.Body.Close()
|
||||
// Since we're streaming the response, if we run into an error all we can do
|
||||
// is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
|
||||
// on read error while copying body.
|
||||
if !shouldPanicOnCopyError(req) {
|
||||
p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
|
||||
return
|
||||
}
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
res.Body.Close() // close now, instead of defer, to populate res.Trailer
|
||||
|
||||
if len(res.Trailer) > 0 {
|
||||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
if fl, ok := rw.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if len(res.Trailer) == announcedTrailers {
|
||||
copyHeader(rw.Header(), res.Trailer)
|
||||
return
|
||||
}
|
||||
|
||||
for k, vv := range res.Trailer {
|
||||
k = http.TrailerPrefix + k
|
||||
for _, v := range vv {
|
||||
rw.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var inOurTests bool // whether we're in our own tests
|
||||
|
||||
// shouldPanicOnCopyError reports whether the reverse proxy should
|
||||
// panic with http.ErrAbortHandler. This is the right thing to do by
|
||||
// default, but Go 1.10 and earlier did not, so existing unit tests
|
||||
// weren't expecting panics. Only panic in our own tests, or when
|
||||
// running under the HTTP server.
|
||||
func shouldPanicOnCopyError(req *http.Request) bool {
|
||||
if inOurTests {
|
||||
// Our tests know to handle this panic.
|
||||
return true
|
||||
}
|
||||
if req.Context().Value(http.ServerContextKey) != nil {
|
||||
// We seem to be running under an HTTP server, so
|
||||
// it'll recover the panic.
|
||||
return true
|
||||
}
|
||||
// Otherwise act like Go 1.10 and earlier to not break
|
||||
// existing tests.
|
||||
return false
|
||||
}
|
||||
|
||||
// removeHopByHopHeaders removes hop-by-hop headers.
|
||||
func removeHopByHopHeaders(h http.Header) {
|
||||
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||
for _, f := range h["Connection"] {
|
||||
for _, sf := range strings.Split(f, ",") {
|
||||
if sf = textproto.TrimString(sf); sf != "" {
|
||||
h.Del(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||
// preserve it for backwards compatibility.
|
||||
for _, f := range hopHeaders {
|
||||
h.Del(f)
|
||||
}
|
||||
}
|
||||
|
||||
// flushInterval returns the p.FlushInterval value, conditionally
|
||||
// overriding its value for a specific request/response.
|
||||
func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
|
||||
resCT := res.Header.Get("Content-Type")
|
||||
|
||||
// For Server-Sent Events responses, flush immediately.
|
||||
// The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
|
||||
if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
|
||||
return -1 // negative means immediately
|
||||
}
|
||||
|
||||
// We might have the case of streaming for which Content-Length might be unset.
|
||||
if res.ContentLength == -1 {
|
||||
return -1
|
||||
}
|
||||
|
||||
return p.FlushInterval
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
|
||||
if flushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: wf,
|
||||
latency: flushInterval,
|
||||
}
|
||||
defer mlw.stop()
|
||||
|
||||
// set up initial timer so headers get flushed even if body writes are delayed
|
||||
mlw.flushPending = true
|
||||
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
|
||||
|
||||
dst = mlw
|
||||
}
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
if p.BufferPool != nil {
|
||||
buf = p.BufferPool.Get()
|
||||
defer p.BufferPool.Put(buf)
|
||||
}
|
||||
_, err := p.copyBuffer(dst, src, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// copyBuffer returns any write errors or non-EOF read errors, and the amount
|
||||
// of bytes written.
|
||||
func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
|
||||
if len(buf) == 0 {
|
||||
buf = make([]byte, 32*1024)
|
||||
}
|
||||
var written int64
|
||||
for {
|
||||
nr, rerr := src.Read(buf)
|
||||
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
|
||||
p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
|
||||
}
|
||||
if nr > 0 {
|
||||
nw, werr := dst.Write(buf[:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
}
|
||||
if werr != nil {
|
||||
return written, werr
|
||||
}
|
||||
if nr != nw {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if rerr != nil {
|
||||
if rerr == io.EOF {
|
||||
rerr = nil
|
||||
}
|
||||
return written, rerr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) logf(format string, args ...any) {
|
||||
if p.ErrorLog != nil {
|
||||
p.ErrorLog.Printf(format, args...)
|
||||
} else {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
io.Writer
|
||||
http.Flusher
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
dst writeFlusher
|
||||
latency time.Duration // non-zero; negative means to flush immediately
|
||||
|
||||
mu sync.Mutex // protects t, flushPending, and dst.Flush
|
||||
t *time.Timer
|
||||
flushPending bool
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
n, err = m.dst.Write(p)
|
||||
if m.latency < 0 {
|
||||
m.dst.Flush()
|
||||
return
|
||||
}
|
||||
if m.flushPending {
|
||||
return
|
||||
}
|
||||
if m.t == nil {
|
||||
m.t = time.AfterFunc(m.latency, m.delayedFlush)
|
||||
} else {
|
||||
m.t.Reset(m.latency)
|
||||
}
|
||||
m.flushPending = true
|
||||
return
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) delayedFlush() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
|
||||
return
|
||||
}
|
||||
m.dst.Flush()
|
||||
m.flushPending = false
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.flushPending = false
|
||||
if m.t != nil {
|
||||
m.t.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func upgradeType(h http.Header) string {
|
||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||
return ""
|
||||
}
|
||||
return h.Get("Upgrade")
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||
reqUpType := upgradeType(req.Header)
|
||||
resUpType := upgradeType(res.Header)
|
||||
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
|
||||
}
|
||||
if !EqualFold(reqUpType, resUpType) {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||
return
|
||||
}
|
||||
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
|
||||
return
|
||||
}
|
||||
|
||||
backConnCloseCh := make(chan bool)
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
res.Header = rw.Header()
|
||||
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||
if err := res.Write(brw); err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
|
||||
return
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
|
||||
return
|
||||
}
|
||||
errc := make(chan error, 1)
|
||||
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
<-errc
|
||||
}
|
||||
|
||||
// switchProtocolCopier exists so goroutines proxying data back and
|
||||
// forth have nice names in stacks.
|
||||
type switchProtocolCopier struct {
|
||||
user, backend io.ReadWriter
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||
_, err := io.Copy(c.user, c.backend)
|
||||
errc <- err
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||
_, err := io.Copy(c.backend, c.user)
|
||||
errc <- err
|
||||
}
|
||||
|
||||
func cleanQueryParams(s string) string {
|
||||
reencode := func(s string) string {
|
||||
v, _ := url.ParseQuery(s)
|
||||
return v.Encode()
|
||||
}
|
||||
for i := 0; i < len(s); {
|
||||
switch s[i] {
|
||||
case ';':
|
||||
return reencode(s)
|
||||
case '%':
|
||||
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
|
||||
return reencode(s)
|
||||
}
|
||||
i += 3
|
||||
default:
|
||||
i++
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func ishex(c byte) bool {
|
||||
switch {
|
||||
case '0' <= c && c <= '9':
|
||||
return true
|
||||
case 'a' <= c && c <= 'f':
|
||||
return true
|
||||
case 'A' <= c && c <= 'F':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
@ -0,0 +1,10 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHttpServer(t *testing.T) {
|
||||
http.ListenAndServe(":89", http.FileServer(http.Dir(`./`)))
|
||||
}
|
@ -0,0 +1,99 @@
|
||||
package mget
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/vbauerster/mpb/v8"
|
||||
"github.com/vbauerster/mpb/v8/decor"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (m *Mget) processMiddleware(base mpb.BarFiller) mpb.BarFiller {
|
||||
fn := func(w io.Writer, st decor.Statistics) error {
|
||||
var res string
|
||||
count := 0
|
||||
_, err := fmt.Fprintf(w, "\nFinished:%s Total Write:%d Speed:%v\n\n", m.Redo.FormatPercent(), m.Redo.Total(), m.Redo.FormatSpeed("MB"))
|
||||
for k := range m.threads {
|
||||
v := m.threads[len(m.threads)-1-k]
|
||||
if v != nil {
|
||||
count++
|
||||
res = fmt.Sprintf("Thread %v: %s %s\t", len(m.threads)-k, v.FormatSpeed("MB"), v.FormatPercent()) + res
|
||||
if count%3 == 0 {
|
||||
res = strings.TrimRight(res, "\t")
|
||||
fmt.Fprintf(w, "%s\n", res)
|
||||
res = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
if res != "" {
|
||||
res = strings.TrimRight(res, "\t")
|
||||
fmt.Fprintf(w, "%s\n", res)
|
||||
}
|
||||
return err
|
||||
}
|
||||
if base == nil {
|
||||
return mpb.BarFillerFunc(fn)
|
||||
}
|
||||
return mpb.BarFillerFunc(func(w io.Writer, st decor.Statistics) error {
|
||||
err := fn(w, st)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return base.Fill(w, st)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Mget) Process() {
|
||||
w.processEnable = true
|
||||
defer func() {
|
||||
w.processEnable = false
|
||||
}()
|
||||
fmt.Println()
|
||||
p := mpb.New()
|
||||
var filler mpb.BarFiller
|
||||
filler = w.processMiddleware(filler)
|
||||
bar := p.New(int64(w.ContentLength),
|
||||
mpb.BarStyle().Rbound("|"),
|
||||
mpb.BarExtender(filler, true), // all bars share same extender filler
|
||||
mpb.PrependDecorators(
|
||||
decor.Counters(decor.SizeB1024(0), "% .2f / % .2f"),
|
||||
),
|
||||
mpb.AppendDecorators(
|
||||
decor.EwmaETA(decor.ET_STYLE_GO, 30),
|
||||
decor.Name(" ] "),
|
||||
decor.EwmaSpeed(decor.SizeB1024(0), "% .2f ", 30),
|
||||
),
|
||||
)
|
||||
defer p.Wait()
|
||||
for {
|
||||
last := w.Redo.Total()
|
||||
lastTime := time.Now()
|
||||
bar.SetCurrent(int64(w.Redo.Total()))
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
bar.SetCurrent(int64(w.Redo.Total()))
|
||||
if w.dynLength {
|
||||
bar.SetTotal(int64(w.Redo.ContentLength), true)
|
||||
}
|
||||
bar.Abort(false)
|
||||
return
|
||||
case <-time.After(time.Second):
|
||||
if !w.writeEnable {
|
||||
bar.SetCurrent(int64(w.Redo.Total()))
|
||||
if w.dynLength {
|
||||
bar.SetTotal(int64(w.Redo.ContentLength), true)
|
||||
}
|
||||
bar.Abort(true)
|
||||
return
|
||||
}
|
||||
now := w.Redo.Total()
|
||||
bar.EwmaIncrInt64(int64(now-last), time.Since(lastTime))
|
||||
lastTime = time.Now()
|
||||
last = now
|
||||
if w.dynLength {
|
||||
bar.SetTotal(int64(w.Redo.ContentLength), false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,84 @@
|
||||
package mget
|
||||
|
||||
import "sort"
|
||||
|
||||
type Range struct {
|
||||
Min uint64 `json:"min"`
|
||||
Max uint64 `json:"max"`
|
||||
}
|
||||
type SortRange []Range
|
||||
|
||||
func (s SortRange) Len() int { return len(s) }
|
||||
func (s SortRange) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s SortRange) Less(i, j int) bool { return s[i].Min < s[j].Min }
|
||||
|
||||
func uniformRange(rg []Range) ([]Range, error) {
|
||||
newRg := make([]Range, 0, len(rg))
|
||||
sort.Sort(SortRange(rg))
|
||||
var last *Range = nil
|
||||
for _, v := range rg {
|
||||
if last != nil && v.Min <= last.Max+1 {
|
||||
if last.Max <= v.Max {
|
||||
last.Max = v.Max
|
||||
}
|
||||
continue
|
||||
}
|
||||
newRg = append(newRg, v)
|
||||
last = &newRg[len(newRg)-1]
|
||||
}
|
||||
return newRg, nil
|
||||
}
|
||||
|
||||
func singleSubRange(origin []Range, v Range) []Range {
|
||||
newRg := make([]Range, 0)
|
||||
sort.Sort(SortRange(origin))
|
||||
for i := 0; i < len(origin); i++ {
|
||||
ori := origin[i]
|
||||
res := make([]Range, 0)
|
||||
shouldAdd := true
|
||||
for j := 0; j < 1; j++ {
|
||||
if v.Min <= ori.Min && v.Max >= ori.Max {
|
||||
shouldAdd = false
|
||||
break
|
||||
}
|
||||
if v.Max < ori.Min {
|
||||
continue
|
||||
}
|
||||
if v.Min > ori.Max {
|
||||
break
|
||||
}
|
||||
ur1 := Range{
|
||||
Min: ori.Min,
|
||||
Max: v.Min - 1,
|
||||
}
|
||||
if v.Min == 0 {
|
||||
ur1.Min = 1
|
||||
ur1.Max = 0
|
||||
}
|
||||
ur2 := Range{
|
||||
Min: v.Max + 1,
|
||||
Max: ori.Max,
|
||||
}
|
||||
if ur1.Max >= ur1.Min {
|
||||
res = append(res, ur1)
|
||||
}
|
||||
if ur2.Max >= ur2.Min {
|
||||
res = append(res, ur2)
|
||||
}
|
||||
}
|
||||
if len(res) == 0 && shouldAdd {
|
||||
res = append(res, ori)
|
||||
}
|
||||
newRg = append(newRg, res...)
|
||||
}
|
||||
return newRg
|
||||
}
|
||||
|
||||
func subRange(origin, rg []Range) []Range {
|
||||
sort.Sort(SortRange(rg))
|
||||
sort.Sort(SortRange(origin))
|
||||
for _, v := range rg {
|
||||
origin = singleSubRange(origin, v)
|
||||
}
|
||||
return origin
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package mget
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRangePlus(t *testing.T) {
|
||||
var r = Redo{
|
||||
ContentLength: 100,
|
||||
rangeUpdated: true,
|
||||
Range: []Range{
|
||||
{10, 12},
|
||||
{13, 20},
|
||||
{17, 19},
|
||||
{30, 80},
|
||||
{90, 97},
|
||||
},
|
||||
}
|
||||
err := r.reform()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(r.Range, []Range{{10, 20}, {30, 80}, {90, 97}}) {
|
||||
t.Error("reform error")
|
||||
}
|
||||
fmt.Println(r.Range)
|
||||
fmt.Println(r.ReverseRange())
|
||||
}
|
@ -0,0 +1,138 @@
|
||||
package mget
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Redo struct {
|
||||
Is206 bool `json:"is_206"`
|
||||
OriginUri string `json:"origin_uri"`
|
||||
Date time.Time `json:"date"`
|
||||
Filename string `json:"filename"`
|
||||
ContentLength uint64 `json:"content_length"`
|
||||
Range []Range `json:"range"`
|
||||
rangeUpdated bool
|
||||
lastUpdate time.Time
|
||||
lastTotal uint64
|
||||
speed float64
|
||||
total uint64
|
||||
isRedo bool
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (r *Redo) CacheTotal() uint64 {
|
||||
return r.total
|
||||
}
|
||||
|
||||
func (r *Redo) Total() uint64 {
|
||||
var total uint64
|
||||
for {
|
||||
r.RLock()
|
||||
for _, v := range r.Range {
|
||||
total += v.Max - v.Min + 1
|
||||
}
|
||||
r.total = total
|
||||
r.RUnlock()
|
||||
if r.total > r.ContentLength && r.ContentLength > 0 {
|
||||
r.reform()
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (r *Redo) Update(start, end int) error {
|
||||
if start < 0 || end < 0 || start > end {
|
||||
return fmt.Errorf("invalid range: %d-%d", start, end)
|
||||
}
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.rangeUpdated = true
|
||||
r.Range = append(r.Range, Range{uint64(start), uint64(end)})
|
||||
now := time.Now()
|
||||
if now.Sub(r.lastUpdate) >= time.Millisecond*500 {
|
||||
var total uint64
|
||||
for _, v := range r.Range {
|
||||
total += v.Max - v.Min + 1
|
||||
}
|
||||
r.total = total
|
||||
r.speed = float64(total-r.lastTotal) / (float64(now.Sub(r.lastUpdate).Milliseconds()) / 1000.00)
|
||||
r.lastTotal = total
|
||||
r.lastUpdate = now
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Redo) Percent() float64 {
|
||||
return float64(r.Total()) / float64(r.ContentLength)
|
||||
}
|
||||
|
||||
func (r *Redo) FormatPercent() string {
|
||||
return fmt.Sprintf("%.2f%%", r.Percent()*100)
|
||||
}
|
||||
|
||||
func (r *Redo) FormatSpeed(unit string) string {
|
||||
switch strings.ToLower(unit) {
|
||||
case "kb":
|
||||
return fmt.Sprintf("%.2f KB/s", r.speed/1024)
|
||||
case "mb":
|
||||
return fmt.Sprintf("%.2f MB/s", r.speed/1024/1024)
|
||||
case "gb":
|
||||
return fmt.Sprintf("%.2f GB/s", r.speed/1024/1024/1024)
|
||||
default:
|
||||
return fmt.Sprintf("%.2f B/s", r.speed)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Redo) Speed() float64 {
|
||||
return r.speed
|
||||
}
|
||||
|
||||
func (r *Redo) Save() error {
|
||||
var err error
|
||||
err = r.reform()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.Filename != "" {
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return os.WriteFile(r.Filename+".bgrd", data, 0644)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Redo) reform() error {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
if !r.rangeUpdated {
|
||||
return nil
|
||||
}
|
||||
tmp, err := r.uniformRange(r.Range)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Range = tmp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Redo) uniformRange(rg []Range) ([]Range, error) {
|
||||
return uniformRange(rg)
|
||||
}
|
||||
|
||||
func (r *Redo) ReverseRange() ([]Range, error) {
|
||||
r.reform()
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
return r.uniformRange(subRange([]Range{{0, r.ContentLength - 1}}, r.Range))
|
||||
}
|
@ -0,0 +1,141 @@
|
||||
package mget
|
||||
|
||||
import (
|
||||
"b612.me/staros"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
func parseContentRange(contentRange string) (start, end, total int64, err error) {
|
||||
_, err = fmt.Sscanf(contentRange, "bytes %d-%d/%d", &start, &end, &total)
|
||||
return
|
||||
}
|
||||
|
||||
func GetFileName(resp *http.Response) string {
|
||||
fname := getFileName(resp)
|
||||
var idx = 0
|
||||
for {
|
||||
idx++
|
||||
if staros.Exists(fname) {
|
||||
if staros.Exists(fname + ".bgrd") {
|
||||
return fname
|
||||
}
|
||||
fname = fmt.Sprintf("%s.%d", fname, idx)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return fname
|
||||
}
|
||||
|
||||
func getFileName(resp *http.Response) string {
|
||||
// 尝试从Content-Disposition头中提取文件名
|
||||
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||
if contentDisposition != "" {
|
||||
// 使用正则表达式提取文件名
|
||||
re := regexp.MustCompile(`(?i)^attachment; filename="?(?P<filename>[^;"]+)`)
|
||||
matches := re.FindStringSubmatch(contentDisposition)
|
||||
if len(matches) > 1 {
|
||||
// 提取命名的捕获组
|
||||
for i, name := range re.SubexpNames() {
|
||||
if name == "filename" {
|
||||
return matches[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 提取路径中的最后一个元素作为文件名
|
||||
return path.Base(resp.Request.URL.Path)
|
||||
}
|
||||
|
||||
func IOWriter(stopCtx context.Context, ch chan Buffer, state *uint32, di *downloadinfo, reader io.ReadCloser, bufSize int, start *int64, end *int64) error {
|
||||
defer reader.Close()
|
||||
for {
|
||||
buf := make([]byte, bufSize)
|
||||
select {
|
||||
case <-stopCtx.Done():
|
||||
return nil
|
||||
default:
|
||||
if atomic.LoadUint32(state) == 1 {
|
||||
runtime.Gosched()
|
||||
time.Sleep(time.Millisecond)
|
||||
continue
|
||||
}
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
ch <- Buffer{Data: buf[:n], Start: uint64(*start)}
|
||||
*start += int64(n)
|
||||
di.AddCurrent(int64(n))
|
||||
}
|
||||
if *start >= *end {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createFileWithSize(filename string, size int64) (*os.File, error) {
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if size == 0 {
|
||||
return file, nil
|
||||
}
|
||||
// 调整文件指针到指定大小位置
|
||||
if _, err = file.Seek(size-1, 0); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 写入一个空字节,以确保文件达到所需大小
|
||||
if _, err = file.Write([]byte{0}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func CloneHeader(original http.Header) http.Header {
|
||||
newHeader := make(http.Header)
|
||||
for key, values := range original {
|
||||
copiedValues := make([]string, len(values))
|
||||
copy(copiedValues, values)
|
||||
newHeader[key] = copiedValues
|
||||
}
|
||||
return newHeader
|
||||
}
|
||||
|
||||
func CloneCookies(original []*http.Cookie) []*http.Cookie {
|
||||
cloned := make([]*http.Cookie, len(original))
|
||||
for i, cookie := range original {
|
||||
cloned[i] = &http.Cookie{
|
||||
Name: cookie.Name,
|
||||
Value: cookie.Value,
|
||||
Path: cookie.Path,
|
||||
Domain: cookie.Domain,
|
||||
Expires: cookie.Expires,
|
||||
RawExpires: cookie.RawExpires,
|
||||
MaxAge: cookie.MaxAge,
|
||||
Secure: cookie.Secure,
|
||||
HttpOnly: cookie.HttpOnly,
|
||||
SameSite: cookie.SameSite,
|
||||
Raw: cookie.Raw,
|
||||
Unparsed: append([]string(nil), cookie.Unparsed...),
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
@ -0,0 +1,526 @@
|
||||
package mget
|
||||
|
||||
import (
|
||||
"b612.me/stario"
|
||||
"b612.me/starnet"
|
||||
"b612.me/staros"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Mget struct {
|
||||
Setting starnet.Request
|
||||
Redo
|
||||
//本地文件地址
|
||||
Tareget string
|
||||
//本地文件大小
|
||||
TargetSize int64
|
||||
//redo文件最大丢数据量
|
||||
RedoRPO int
|
||||
//单个buffer大小
|
||||
BufferSize int
|
||||
//并发下载线程数
|
||||
dynLength bool
|
||||
Thread int `json:"thread"`
|
||||
tf *os.File
|
||||
ch chan Buffer
|
||||
ctx context.Context
|
||||
fn context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
threads []*downloader
|
||||
lastUndoInfo []Range
|
||||
writeError error
|
||||
writeEnable bool
|
||||
processEnable bool
|
||||
speedlimit int64
|
||||
}
|
||||
|
||||
type Buffer struct {
|
||||
Data []byte
|
||||
Start uint64
|
||||
}
|
||||
|
||||
func (w *Mget) Clone() *starnet.Request {
|
||||
req := starnet.NewSimpleRequest(w.Setting.Uri(), w.Setting.Method())
|
||||
req.SetHeaders(CloneHeader(w.Setting.Headers()))
|
||||
req.SetCookies(CloneCookies(w.Setting.Cookies()))
|
||||
req.SetSkipTLSVerify(w.Setting.SkipTLSVerify())
|
||||
req.SetProxy(w.Setting.Proxy())
|
||||
return req
|
||||
}
|
||||
|
||||
func (w *Mget) IsUrl206() (*starnet.Response, bool, error) {
|
||||
req := w.Clone()
|
||||
req.SetHeader("Range", "bytes=0-")
|
||||
res, err := req.Do()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if res.StatusCode == 206 {
|
||||
return res, true, nil
|
||||
}
|
||||
return res, false, nil
|
||||
}
|
||||
|
||||
func (w *Mget) prepareRun(res *starnet.Response, is206 bool) error {
|
||||
var err error
|
||||
|
||||
length := res.Header.Get("Content-Length")
|
||||
if length == "" {
|
||||
length = "0"
|
||||
w.dynLength = true
|
||||
is206 = false
|
||||
}
|
||||
w.TargetSize, err = strconv.ParseInt(length, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse content length error: %w", err)
|
||||
}
|
||||
if w.Tareget == "" {
|
||||
w.Tareget = GetFileName(res.Response)
|
||||
}
|
||||
fmt.Println("Will write to:", w.Tareget)
|
||||
fmt.Println("Size:", w.TargetSize)
|
||||
fmt.Println("Is206:", is206)
|
||||
w.Redo = Redo{
|
||||
Filename: w.Tareget,
|
||||
ContentLength: uint64(w.TargetSize),
|
||||
OriginUri: w.Setting.Uri(),
|
||||
Date: time.Now(),
|
||||
Is206: is206,
|
||||
}
|
||||
fmt.Println("Threads:", w.Thread)
|
||||
if staros.Exists(w.Tareget + ".bgrd") {
|
||||
fmt.Println("Found redo file, try to recover...")
|
||||
var redo Redo
|
||||
data, err := os.ReadFile(w.Tareget + ".bgrd")
|
||||
if err != nil {
|
||||
return fmt.Errorf("read redo file error: %w", err)
|
||||
}
|
||||
err = json.Unmarshal(data, &redo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal redo file error: %w", err)
|
||||
}
|
||||
redo.reform()
|
||||
if redo.ContentLength != w.Redo.ContentLength {
|
||||
fmt.Println("Content length not match, redo file may be invalid, ignore it")
|
||||
return nil
|
||||
}
|
||||
if redo.OriginUri != w.Redo.OriginUri {
|
||||
fmt.Println("Origin uri not match, redo file may be invalid, ignore it")
|
||||
return nil
|
||||
}
|
||||
w.Redo = redo
|
||||
w.Redo.isRedo = true
|
||||
w.lastUndoInfo, err = w.Redo.ReverseRange()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reverse redo range error: %w", err)
|
||||
}
|
||||
fmt.Println("Recover redo file success,process:", w.Redo.FormatPercent())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Mget) Run() error {
|
||||
var err error
|
||||
var res *starnet.Response
|
||||
var is206 bool
|
||||
w.ctx, w.fn = context.WithCancel(context.Background())
|
||||
w.ch = make(chan Buffer)
|
||||
defer w.fn()
|
||||
w.threads = make([]*downloader, w.Thread)
|
||||
if w.Setting.Uri() == "" {
|
||||
w.Setting = *starnet.NewSimpleRequest(w.OriginUri, "GET")
|
||||
}
|
||||
for {
|
||||
res, is206, err = w.IsUrl206()
|
||||
if err != nil {
|
||||
return fmt.Errorf("check 206 error: %w", err)
|
||||
}
|
||||
err = w.prepareRun(res, is206)
|
||||
if err != nil {
|
||||
return fmt.Errorf("prepare run error: %w", err)
|
||||
}
|
||||
if res.StatusCode != 206 && res.StatusCode != 200 {
|
||||
return fmt.Errorf("Server return %d", res.StatusCode)
|
||||
}
|
||||
if !is206 {
|
||||
var di = &downloader{
|
||||
alive: true,
|
||||
downloadinfo: &downloadinfo{
|
||||
Start: 0,
|
||||
End: w.TargetSize - 1,
|
||||
Size: w.TargetSize,
|
||||
},
|
||||
}
|
||||
w.threads[0] = di
|
||||
state := uint32(0)
|
||||
err = IOWriter(w.ctx, w.ch, &state, di.downloadinfo, res.Body().Reader(), w.BufferSize, &di.Start, &di.End)
|
||||
di.alive = false
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
res.Body().Close()
|
||||
}
|
||||
break
|
||||
}
|
||||
go func() {
|
||||
w.writeEnable = true
|
||||
w.writeError = w.WriteServer()
|
||||
w.writeEnable = false
|
||||
}()
|
||||
if w.TargetSize == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i < w.Thread; i++ {
|
||||
w.wg.Add(1)
|
||||
go w.dispatch(i)
|
||||
}
|
||||
go w.Process()
|
||||
w.wg.Wait()
|
||||
time.Sleep(2 * time.Microsecond)
|
||||
for {
|
||||
if w.writeEnable {
|
||||
w.fn()
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
continue
|
||||
}
|
||||
if w.writeError != nil {
|
||||
err = w.Redo.Save()
|
||||
return fmt.Errorf("write error: %w %v", w.writeError, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
w.fn()
|
||||
stario.WaitUntilTimeout(time.Second*2,
|
||||
func(c chan struct{}) error {
|
||||
for {
|
||||
if w.processEnable {
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
|
||||
r, err := w.ReverseRange()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(r) == 0 {
|
||||
return os.Remove(w.Tareget + ".bgrd")
|
||||
}
|
||||
return w.Redo.Save()
|
||||
}
|
||||
|
||||
func (w *Mget) dispatch(idx int) error {
|
||||
defer w.wg.Done()
|
||||
var start, end int64
|
||||
if len(w.lastUndoInfo) == 0 {
|
||||
count := w.TargetSize / int64(w.Thread)
|
||||
start = count * int64(idx)
|
||||
end = count*int64(idx+1) - 1
|
||||
if idx == w.Thread-1 {
|
||||
end = w.TargetSize - 1
|
||||
}
|
||||
} else {
|
||||
w.Lock()
|
||||
if len(w.lastUndoInfo) == 0 {
|
||||
d := &downloader{}
|
||||
w.threads[idx] = d
|
||||
w.Unlock()
|
||||
goto morejob
|
||||
}
|
||||
start = int64(w.lastUndoInfo[0].Min)
|
||||
end = int64(w.lastUndoInfo[0].Max)
|
||||
w.lastUndoInfo = w.lastUndoInfo[1:]
|
||||
w.Unlock()
|
||||
}
|
||||
for {
|
||||
req := w.Clone()
|
||||
req.SetCookies(CloneCookies(w.Setting.Cookies()))
|
||||
d := &downloader{
|
||||
Request: req,
|
||||
ch: w.ch,
|
||||
ctx: w.ctx,
|
||||
bufferSize: w.BufferSize,
|
||||
downloadinfo: &downloadinfo{
|
||||
Start: start,
|
||||
End: end,
|
||||
},
|
||||
}
|
||||
w.threads[idx] = d
|
||||
if err := d.Run(); err != nil {
|
||||
fmt.Printf("thread %d error: %v\n", idx, err)
|
||||
if d.Start >= d.End {
|
||||
break
|
||||
}
|
||||
start = d.Start
|
||||
end = d.End
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
morejob:
|
||||
for {
|
||||
w.Lock()
|
||||
if len(w.lastUndoInfo) > 0 {
|
||||
w.threads[idx].Start = int64(w.lastUndoInfo[idx].Min)
|
||||
w.threads[idx].End = int64(w.lastUndoInfo[idx].Max)
|
||||
w.lastUndoInfo = w.lastUndoInfo[1:]
|
||||
w.Unlock()
|
||||
} else {
|
||||
w.Unlock()
|
||||
if !w.RequestNewTask(w.threads[idx]) {
|
||||
break
|
||||
}
|
||||
}
|
||||
for {
|
||||
req := w.Clone()
|
||||
req.SetCookies(CloneCookies(w.Setting.Cookies()))
|
||||
d := &downloader{
|
||||
Request: req,
|
||||
ch: w.ch,
|
||||
ctx: w.ctx,
|
||||
bufferSize: w.BufferSize,
|
||||
downloadinfo: &downloadinfo{
|
||||
Start: w.threads[idx].Start,
|
||||
End: w.threads[idx].End,
|
||||
},
|
||||
}
|
||||
w.threads[idx] = d
|
||||
if err := d.Run(); err != nil {
|
||||
fmt.Printf("thread %d error: %v\n", idx, err)
|
||||
if d.Start >= d.End {
|
||||
break
|
||||
}
|
||||
start = d.Start
|
||||
end = d.End
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Mget) getSleepTime() time.Duration {
|
||||
if w.speedlimit == 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Nanosecond * time.Duration(16384*1000*1000*1000/w.speedlimit) / 2
|
||||
|
||||
}
|
||||
func (w *Mget) WriteServer() error {
|
||||
var err error
|
||||
defer w.fn()
|
||||
if !w.isRedo {
|
||||
w.tf, err = createFileWithSize(w.Tareget, w.TargetSize)
|
||||
} else {
|
||||
w.tf, err = os.OpenFile(w.Tareget, os.O_RDWR, 0666)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lastUpdateRange := 0
|
||||
currentRange := 0
|
||||
|
||||
currentCount := int64(0)
|
||||
lastDate := time.Now()
|
||||
lastCount := int64(0)
|
||||
speedControl := func(count int) {
|
||||
if w.speedlimit == 0 {
|
||||
return
|
||||
}
|
||||
currentCount += int64(count)
|
||||
for {
|
||||
if time.Since(lastDate) < time.Second {
|
||||
if currentCount-lastCount > w.speedlimit {
|
||||
time.Sleep(w.getSleepTime())
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
lastDate = time.Now()
|
||||
lastCount = currentCount
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return nil
|
||||
case b := <-w.ch:
|
||||
n, err := w.tf.WriteAt(b.Data, int64(b.Start))
|
||||
if err != nil {
|
||||
fmt.Println("write error:", err)
|
||||
return err
|
||||
}
|
||||
speedControl(n)
|
||||
if w.dynLength {
|
||||
w.ContentLength += uint64(n)
|
||||
}
|
||||
currentRange += n
|
||||
end := b.Start + uint64(n) - 1
|
||||
err = w.Update(int(b.Start), int(end))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if currentRange-lastUpdateRange >= w.RedoRPO {
|
||||
w.tf.Sync()
|
||||
go w.Redo.Save()
|
||||
lastUpdateRange = currentRange
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type downloader struct {
|
||||
*starnet.Request
|
||||
alive bool
|
||||
ch chan Buffer
|
||||
ctx context.Context
|
||||
state uint32
|
||||
bufferSize int
|
||||
*downloadinfo
|
||||
}
|
||||
|
||||
func (d *downloader) Run() error {
|
||||
d.alive = true
|
||||
defer func() {
|
||||
d.alive = false
|
||||
}()
|
||||
d.SetHeader("Range", fmt.Sprintf("bytes=%d-%d", d.Start, d.End))
|
||||
res, err := d.Do()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res.Header.Get("Content-Range") == "" {
|
||||
return fmt.Errorf("server not support range")
|
||||
}
|
||||
start, end, _, err := parseContentRange(res.Header.Get("Content-Range"))
|
||||
if d.Start != start {
|
||||
return fmt.Errorf("server not support range")
|
||||
}
|
||||
d.End = end
|
||||
d.downloadinfo = &downloadinfo{
|
||||
Start: d.Start,
|
||||
End: d.End,
|
||||
Size: d.End - d.Start + 1,
|
||||
}
|
||||
reader := res.Body().Reader()
|
||||
return IOWriter(d.ctx, d.ch, &d.state, d.downloadinfo, reader, d.bufferSize, &d.Start, &d.End)
|
||||
}
|
||||
|
||||
func (w *Mget) RequestNewTask(task *downloader) bool {
|
||||
//stop thhe world first
|
||||
w.Lock()
|
||||
defer w.Unlock()
|
||||
defer func() {
|
||||
for _, v := range w.threads {
|
||||
if v != nil {
|
||||
atomic.StoreUint32(&v.state, 0)
|
||||
}
|
||||
}
|
||||
}()
|
||||
var maxThread *downloader
|
||||
for _, v := range w.threads {
|
||||
if v != nil {
|
||||
atomic.StoreUint32(&v.state, 1)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Microsecond * 2)
|
||||
|
||||
for _, v := range w.threads {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
if maxThread == nil {
|
||||
maxThread = v
|
||||
continue
|
||||
}
|
||||
if v.End-v.Start > maxThread.End-maxThread.Start {
|
||||
maxThread = v
|
||||
}
|
||||
}
|
||||
if maxThread == nil || maxThread.End <= maxThread.Start {
|
||||
return false
|
||||
}
|
||||
if (maxThread.End-maxThread.Start)/2 < int64(w.BufferSize*2) || (maxThread.End-maxThread.Start)/2 < 100*1024 {
|
||||
return false
|
||||
}
|
||||
task.End = maxThread.End
|
||||
maxThread.End = maxThread.Start + (maxThread.End-maxThread.Start)/2
|
||||
task.Start = maxThread.End + 1
|
||||
//fmt.Printf("thread got new task %d-%d\n", task.Start, task.End)
|
||||
return true
|
||||
}
|
||||
|
||||
type downloadinfo struct {
|
||||
Start int64
|
||||
End int64
|
||||
Size int64
|
||||
current int64
|
||||
lastCurrent int64
|
||||
lastTime time.Time
|
||||
speed float64
|
||||
}
|
||||
|
||||
func (d *downloadinfo) Current() int64 {
|
||||
return d.current
|
||||
}
|
||||
|
||||
func (d *downloadinfo) Percent() float64 {
|
||||
return float64(d.current) / float64(d.Size)
|
||||
}
|
||||
|
||||
func (d *downloadinfo) FormatPercent() string {
|
||||
return fmt.Sprintf("%.2f%%", d.Percent()*100)
|
||||
}
|
||||
|
||||
func (d *downloadinfo) SetCurrent(info int64) {
|
||||
d.current = info
|
||||
now := time.Now()
|
||||
if now.Sub(d.lastTime) >= time.Millisecond*500 {
|
||||
d.speed = float64(d.current-d.lastCurrent) / (float64(now.Sub(d.lastTime).Milliseconds()) / 1000.00)
|
||||
d.lastCurrent = d.current
|
||||
d.lastTime = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *downloadinfo) AddCurrent(info int64) {
|
||||
d.current += info
|
||||
now := time.Now()
|
||||
if now.Sub(d.lastTime) >= time.Millisecond*500 {
|
||||
d.speed = float64(d.current-d.lastCurrent) / (float64(now.Sub(d.lastTime).Milliseconds()) / 1000.00)
|
||||
d.lastCurrent = d.current
|
||||
d.lastTime = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *downloadinfo) FormatSpeed(unit string) string {
|
||||
switch strings.ToLower(unit) {
|
||||
case "kb":
|
||||
return fmt.Sprintf("%.2f KB/s", d.speed/1024)
|
||||
case "mb":
|
||||
return fmt.Sprintf("%.2f MB/s", d.speed/1024/1024)
|
||||
case "gb":
|
||||
return fmt.Sprintf("%.2f GB/s", d.speed/1024/1024/1024)
|
||||
default:
|
||||
return fmt.Sprintf("%.2f B/s", d.speed)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *downloadinfo) Speed() float64 {
|
||||
return d.speed
|
||||
}
|
@ -0,0 +1,35 @@
|
||||
package mget
|
||||
|
||||
import (
|
||||
"b612.me/starnet"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWget(t *testing.T) {
|
||||
r := starnet.NewSimpleRequest("http://192.168.2.33:88/DJI_0746.MP4", "GET")
|
||||
w := Mget{
|
||||
Setting: *r,
|
||||
RedoRPO: 1048576,
|
||||
BufferSize: 8192,
|
||||
Thread: 8,
|
||||
}
|
||||
if err := w.Run(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM(t *testing.T) {
|
||||
a := map[string]string{
|
||||
"1": "1",
|
||||
"2": "2",
|
||||
}
|
||||
modify(a)
|
||||
fmt.Println(a)
|
||||
}
|
||||
|
||||
func modify(a map[string]string) {
|
||||
b := make(map[string]string)
|
||||
b = a
|
||||
b["1"] = "3"
|
||||
}
|
@ -0,0 +1,124 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"b612.me/starlog"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type NatTesterServer struct {
|
||||
MainPort string
|
||||
AltPort string
|
||||
MainIP string
|
||||
AltIP string
|
||||
LogPath string
|
||||
stopCtx context.Context
|
||||
stopFn context.CancelFunc
|
||||
maina *net.UDPConn
|
||||
mainb *net.UDPConn
|
||||
alt *net.UDPConn
|
||||
running int32
|
||||
}
|
||||
|
||||
func (n *NatTesterServer) Run() error {
|
||||
if atomic.LoadInt32(&n.running) > 0 {
|
||||
starlog.Errorln("already running")
|
||||
return fmt.Errorf("already running")
|
||||
}
|
||||
atomic.StoreInt32(&n.running, 1)
|
||||
defer atomic.StoreInt32(&n.running, 0)
|
||||
if n.LogPath != "" {
|
||||
starlog.SetLogFile(n.LogPath, starlog.Std, true)
|
||||
starlog.Infof("Log file set to %s\n", n.LogPath)
|
||||
}
|
||||
starlog.Infof("MainPort: %s\n", n.MainPort)
|
||||
starlog.Infof("AltPort: %s\n", n.AltPort)
|
||||
tmp, err := net.Dial("udp", "8.8.8.8:53")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
starlog.Infof("Current Output IP: %s\n", tmp.LocalAddr().(*net.UDPAddr).IP.String())
|
||||
tmp.Close()
|
||||
n.stopCtx, n.stopFn = context.WithCancel(context.Background())
|
||||
mainaaddr, err := net.ResolveUDPAddr("udp", n.MainIP+":"+n.MainPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mainbaddr, err := net.ResolveUDPAddr("udp", n.MainIP+":"+n.AltPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.maina, err = net.ListenUDP("udp", mainaaddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
starlog.Infof("UDP MainIP:MainPort Listening on %s\n", n.maina.LocalAddr().String())
|
||||
n.mainb, err = net.ListenUDP("udp", mainbaddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
starlog.Infof("UDP MainIP:AltPort Listening on %s\n", n.mainb.LocalAddr().String())
|
||||
altaddr, err := net.ResolveUDPAddr("udp", n.AltIP+":"+n.AltPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.alt, err = net.ListenUDP("udp", altaddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
starlog.Infof("UDP AltIP:AltPort Listening on %s\n", n.alt.LocalAddr().String())
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-n.stopCtx.Done():
|
||||
starlog.Infoln("Stopping,Reason: Context Done")
|
||||
return
|
||||
default:
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
num, r, e := n.alt.ReadFromUDP(buf)
|
||||
if e != nil {
|
||||
continue
|
||||
}
|
||||
go n.Analyse(n.alt, r, strings.Split(string(buf[:num]), "::"))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-n.stopCtx.Done():
|
||||
starlog.Infoln("Stopping,Reason: Context Done")
|
||||
n.maina.Close()
|
||||
n.mainb.Close()
|
||||
n.alt.Close()
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
num, r, e := n.maina.ReadFromUDP(buf)
|
||||
if e != nil {
|
||||
continue
|
||||
}
|
||||
go n.Analyse(n.maina, r, strings.Split(string(buf[:num]), "::"))
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NatTesterServer) Analyse(c *net.UDPConn, r *net.UDPAddr, cmds []string) error {
|
||||
switch cmds[0] {
|
||||
case "ip":
|
||||
c.WriteToUDP([]byte("ip::"+r.String()), r)
|
||||
starlog.Infof("Recv IP Request from %s,Local: %s\n", r.String(), c.LocalAddr().String())
|
||||
case "startnat1":
|
||||
n.alt.WriteToUDP([]byte("stage1"), r)
|
||||
starlog.Infof("Start NAT1 Test from %s,Recv Local:%s Send Local:%s\n", r.String(), c.LocalAddr().String(), n.alt.LocalAddr().String())
|
||||
case "startnat2":
|
||||
n.mainb.WriteToUDP([]byte("stage2"), r)
|
||||
starlog.Infof("Start NAT2 Test from %s,Recv Local:%s Send Local:%s\n", r.String(), c.LocalAddr().String(), n.mainb.LocalAddr().String())
|
||||
case "startnat3":
|
||||
n.maina.WriteToUDP([]byte("stage3"), r)
|
||||
starlog.Infof("Start NAT3 Test from %s,Recv Local:%s Send Local:%s\n", r.String(), c.LocalAddr().String(), n.maina.LocalAddr().String())
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,765 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"b612.me/apps/b612/netforward"
|
||||
"b612.me/starlog"
|
||||
"b612.me/starmap"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/huin/goupnp/dcps/internetgateway2"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NatThroughs struct {
|
||||
Lists []*NatThrough
|
||||
WebPort int
|
||||
AutoUPnP bool
|
||||
KeepAlivePeriod int
|
||||
KeepAliveIdel int
|
||||
KeepAliveCount int
|
||||
STUN string
|
||||
Remote string
|
||||
HealthCheckInterval int
|
||||
Type string
|
||||
}
|
||||
|
||||
func (n *NatThroughs) Close() {
|
||||
for _, v := range n.Lists {
|
||||
v.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NatThroughs) Parse(reqs []string) error {
|
||||
if n.KeepAlivePeriod == 0 {
|
||||
n.KeepAlivePeriod = 10
|
||||
}
|
||||
if n.KeepAliveIdel == 0 {
|
||||
n.KeepAliveIdel = 30
|
||||
}
|
||||
if n.KeepAliveCount == 0 {
|
||||
n.KeepAliveCount = 5
|
||||
}
|
||||
if n.STUN == "" {
|
||||
n.STUN = "turn.b612.me:3478"
|
||||
}
|
||||
if n.Type == "" {
|
||||
n.Type = "tcp"
|
||||
}
|
||||
for _, v := range reqs {
|
||||
var req = NatThrough{
|
||||
Forward: netforward.NetForward{
|
||||
LocalAddr: "0.0.0.0",
|
||||
DialTimeout: time.Second * 5,
|
||||
UDPTimeout: time.Second * 30,
|
||||
KeepAlivePeriod: n.KeepAlivePeriod,
|
||||
KeepAliveIdel: n.KeepAliveIdel,
|
||||
KeepAliveCount: n.KeepAliveCount,
|
||||
UsingKeepAlive: true,
|
||||
EnableTCP: true,
|
||||
EnableUDP: true,
|
||||
},
|
||||
Type: n.Type,
|
||||
STUN: n.STUN,
|
||||
Remote: n.Remote,
|
||||
KeepAlivePeriod: n.KeepAlivePeriod,
|
||||
KeepAliveIdel: n.KeepAliveIdel,
|
||||
KeepAliveCount: n.KeepAliveCount,
|
||||
AutoUPnP: n.AutoUPnP,
|
||||
HealthCheckInterval: n.HealthCheckInterval,
|
||||
}
|
||||
strs := strings.Split(v, ",")
|
||||
switch len(strs) {
|
||||
case 1:
|
||||
req.Forward.RemoteURI = strs[0]
|
||||
case 2:
|
||||
ipport := strings.Split(strs[0], ":")
|
||||
if len(ipport) == 1 {
|
||||
port, err := strconv.Atoi(ipport[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Forward.LocalPort = port
|
||||
} else {
|
||||
req.Forward.LocalAddr = ipport[0]
|
||||
port, err := strconv.Atoi(ipport[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Forward.LocalPort = port
|
||||
}
|
||||
req.Forward.RemoteURI = strs[1]
|
||||
case 3:
|
||||
ipport := strings.Split(strs[1], ":")
|
||||
if len(ipport) == 1 {
|
||||
port, err := strconv.Atoi(ipport[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Forward.LocalPort = port
|
||||
} else {
|
||||
req.Forward.LocalAddr = ipport[0]
|
||||
port, err := strconv.Atoi(ipport[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Forward.LocalPort = port
|
||||
}
|
||||
req.Forward.RemoteURI = strs[2]
|
||||
req.Name = strs[0]
|
||||
case 4:
|
||||
ipport := strings.Split(strs[2], ":")
|
||||
if len(ipport) == 1 {
|
||||
port, err := strconv.Atoi(ipport[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Forward.LocalPort = port
|
||||
} else {
|
||||
req.Forward.LocalAddr = ipport[0]
|
||||
port, err := strconv.Atoi(ipport[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Forward.LocalPort = port
|
||||
}
|
||||
req.Type = strings.ToLower(strs[0])
|
||||
req.Forward.RemoteURI = strs[3]
|
||||
req.Name = strs[1]
|
||||
}
|
||||
n.Lists = append(n.Lists, &req)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NatThroughs) Run() error {
|
||||
go n.WebService()
|
||||
wg := sync.WaitGroup{}
|
||||
for _, v := range n.Lists {
|
||||
wg.Add(1)
|
||||
go func(v *NatThrough) {
|
||||
defer wg.Done()
|
||||
if err := v.Run(); err != nil {
|
||||
starlog.Errorf("Failed to run natThrough: %v\n", err)
|
||||
}
|
||||
v.HealthCheck()
|
||||
}(v)
|
||||
}
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
type nattinfo struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Ext string `json:"ext"`
|
||||
Local string `json:"local"`
|
||||
Forward string `json:"forward"`
|
||||
}
|
||||
|
||||
func (n *NatThroughs) WebService() error {
|
||||
if n.WebPort == 0 {
|
||||
return nil
|
||||
}
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", n.WebPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
starlog.Infof("Web service listen on %d\n", n.WebPort)
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
var str string
|
||||
for k, v := range n.Lists {
|
||||
str += fmt.Sprintf("id:%d name:%s : %s <----> %s <-----> %s\n", k, v.Name, v.ExtUrl, v.localipport, v.Forward.RemoteURI)
|
||||
}
|
||||
w.Write([]byte(str))
|
||||
})
|
||||
http.HandleFunc("/json", func(w http.ResponseWriter, r *http.Request) {
|
||||
var res []nattinfo
|
||||
for k, v := range n.Lists {
|
||||
res = append(res, nattinfo{
|
||||
Id: k,
|
||||
Name: v.Name,
|
||||
Ext: v.ExtUrl,
|
||||
Local: v.localipport,
|
||||
Forward: v.Forward.RemoteURI,
|
||||
})
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
data, _ := json.Marshal(res)
|
||||
w.Write(data)
|
||||
return
|
||||
})
|
||||
http.HandleFunc("/jump", func(w http.ResponseWriter, r *http.Request) {
|
||||
types := "https://"
|
||||
name := r.URL.Query().Get("name")
|
||||
if name == "" {
|
||||
w.Write([]byte("id is empty"))
|
||||
return
|
||||
}
|
||||
if r.URL.Query().Get("type") == "http" {
|
||||
types = "http://"
|
||||
}
|
||||
for _, v := range n.Lists {
|
||||
if v.Name == name {
|
||||
http.Redirect(w, r, types+v.ExtUrl, http.StatusFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return http.Serve(listener, nil)
|
||||
}
|
||||
|
||||
// NatThrough 类似于natter.py 是一个用于Full Cone NAT直接穿透的工具
|
||||
type NatThrough struct {
|
||||
Name string
|
||||
OriginLocalPort int
|
||||
Forward netforward.NetForward
|
||||
Type string
|
||||
STUN string
|
||||
Remote string
|
||||
KeepAlivePeriod int
|
||||
KeepAliveIdel int
|
||||
KeepAliveCount int
|
||||
AutoUPnP bool
|
||||
isOk bool
|
||||
ExtUrl string
|
||||
localipport string
|
||||
keepaliveConn net.Conn
|
||||
HealthCheckInterval int
|
||||
stopFn context.CancelFunc
|
||||
stopCtx context.Context
|
||||
}
|
||||
|
||||
func (n *NatThrough) Close() {
|
||||
n.stopFn()
|
||||
n.Forward.Close()
|
||||
}
|
||||
|
||||
func (c *NatThrough) Run() error {
|
||||
c.isOk = false
|
||||
c.stopCtx, c.stopFn = context.WithCancel(context.Background())
|
||||
c.OriginLocalPort = c.Forward.LocalPort
|
||||
if c.Forward.LocalPort == 0 {
|
||||
listener, err := net.Listen(c.Type, c.Forward.LocalAddr+":0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to listen on %s: %v", c.Forward.LocalAddr, err)
|
||||
}
|
||||
if c.Type == "tcp" {
|
||||
c.Forward.LocalPort = listener.Addr().(*net.TCPAddr).Port
|
||||
} else {
|
||||
c.Forward.LocalPort = listener.Addr().(*net.UDPAddr).Port
|
||||
}
|
||||
listener.Close()
|
||||
}
|
||||
if c.Type == "tcp" {
|
||||
c.Forward.EnableTCP = true
|
||||
c.Forward.EnableUDP = false
|
||||
} else {
|
||||
c.Forward.EnableTCP = false
|
||||
c.Forward.EnableUDP = true
|
||||
}
|
||||
starlog.Infof("NatThrough Type: %s\n", c.Type)
|
||||
starlog.Infof("Local Port: %d\n", c.Forward.LocalPort)
|
||||
starlog.Infof("Keepalive To: %s\n", c.Remote)
|
||||
starlog.Infof("Forward To: %s\n", c.Forward.RemoteURI)
|
||||
|
||||
innerIp, extIp, err := c.GetIPPortFromSTUN(c.Type, c.Forward.LocalAddr, c.Forward.LocalPort, c.STUN)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get external IP and port: %v", err)
|
||||
}
|
||||
starlog.Infof("Internal Addr: %s \n", innerIp.String())
|
||||
starlog.Infof("External Addr: %s \n", extIp.String())
|
||||
getIP := func(ip net.Addr) string {
|
||||
switch ip.(type) {
|
||||
case *net.TCPAddr:
|
||||
return ip.(*net.TCPAddr).IP.String()
|
||||
case *net.UDPAddr:
|
||||
return ip.(*net.UDPAddr).IP.String()
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
getPort := func(ip net.Addr) int {
|
||||
switch ip.(type) {
|
||||
case *net.TCPAddr:
|
||||
return ip.(*net.TCPAddr).Port
|
||||
case *net.UDPAddr:
|
||||
return ip.(*net.UDPAddr).Port
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
if err := c.KeepAlive(c.Forward.LocalAddr, c.Forward.LocalPort); err != nil {
|
||||
starlog.Errorf("Failed to run keepalive: %v\n", err)
|
||||
c.Forward.Close()
|
||||
c.stopFn()
|
||||
}
|
||||
}()
|
||||
innerIp, extIp, err = c.GetIPPortFromSTUN(c.Type, c.Forward.LocalAddr, c.Forward.LocalPort, c.STUN)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get external IP and port: %v", err)
|
||||
}
|
||||
starlog.Infof("Retest Internal Addr: %s \n", innerIp.String())
|
||||
starlog.Infof("Retest External Addr: %s \n", extIp.String())
|
||||
if c.AutoUPnP {
|
||||
go c.HandleUPnP(getIP(innerIp), uint16(getPort(extIp)))
|
||||
}
|
||||
err = c.Forward.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to run forward: %v", err)
|
||||
}
|
||||
c.isOk = true
|
||||
c.localipport = fmt.Sprintf("%s:%d", getIP(innerIp), getPort(innerIp))
|
||||
c.ExtUrl = fmt.Sprintf("%s:%d", getIP(extIp), getPort(extIp))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NatThrough) HealthCheck() {
|
||||
getIP := func(ip net.Addr) string {
|
||||
switch ip.(type) {
|
||||
case *net.TCPAddr:
|
||||
return ip.(*net.TCPAddr).IP.String()
|
||||
case *net.UDPAddr:
|
||||
return ip.(*net.UDPAddr).IP.String()
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
getPort := func(ip net.Addr) int {
|
||||
switch ip.(type) {
|
||||
case *net.TCPAddr:
|
||||
return ip.(*net.TCPAddr).Port
|
||||
case *net.UDPAddr:
|
||||
return ip.(*net.UDPAddr).Port
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
count := 0
|
||||
if c.HealthCheckInterval == 0 {
|
||||
c.HealthCheckInterval = 30
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCtx.Done():
|
||||
return
|
||||
case <-time.After(time.Second * time.Duration(c.HealthCheckInterval)):
|
||||
}
|
||||
if c.Type == "udp" {
|
||||
_, extIp, err := c.UdpKeppaliveSTUN(c.Forward.UdpListener(), c.STUN)
|
||||
if err != nil {
|
||||
count++
|
||||
starlog.Errorf("Health Check Error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
extUrl := fmt.Sprintf("%s:%d", getIP(extIp), getPort(extIp))
|
||||
if c.ExtUrl != extUrl {
|
||||
count++
|
||||
} else {
|
||||
count = 0
|
||||
}
|
||||
starlog.Noticef("Health Check:Origin %s,Current %s\n", c.ExtUrl, extUrl)
|
||||
} else {
|
||||
conn, err := net.DialTimeout("tcp", c.ExtUrl, time.Second*2)
|
||||
if err != nil {
|
||||
starlog.Warningf("Health Check Fail: %v\n", err)
|
||||
count++
|
||||
} else {
|
||||
count = 0
|
||||
starlog.Infof("Health Check Ok\n")
|
||||
conn.(*net.TCPConn).SetLinger(0)
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
if count >= 3 {
|
||||
count = 0
|
||||
starlog.Errorf("Failed to connect to remote, close connection retrying\n")
|
||||
c.stopFn()
|
||||
c.keepaliveConn.Close()
|
||||
c.Forward.Close()
|
||||
forward := netforward.NetForward{
|
||||
LocalAddr: c.Forward.LocalAddr,
|
||||
LocalPort: c.OriginLocalPort,
|
||||
RemoteURI: c.Forward.RemoteURI,
|
||||
KeepAlivePeriod: c.KeepAlivePeriod,
|
||||
KeepAliveIdel: c.KeepAliveIdel,
|
||||
KeepAliveCount: c.KeepAliveCount,
|
||||
UsingKeepAlive: true,
|
||||
}
|
||||
time.Sleep(time.Second * 22)
|
||||
c.Forward = forward
|
||||
c.Run()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NatThrough) KeepAlive(localAddr string, localPort int) error {
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCtx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
if c.Type == "tcp" {
|
||||
dialer := net.Dialer{
|
||||
Control: netforward.ControlSetReUseAddr,
|
||||
LocalAddr: &net.TCPAddr{IP: net.ParseIP(localAddr), Port: localPort},
|
||||
}
|
||||
conn, err := dialer.Dial("tcp", c.Remote)
|
||||
if err != nil {
|
||||
starlog.Errorf("Failed to dial remote: %v\n", err)
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
}
|
||||
c.keepaliveConn = conn
|
||||
conn.(*net.TCPConn).SetLinger(0)
|
||||
netforward.SetTcpInfo(conn.(*net.TCPConn), true, c.KeepAliveIdel, c.KeepAlivePeriod, c.KeepAliveCount, 0)
|
||||
starlog.Infof("Keepalive local:%s remote: %s\n", conn.LocalAddr().String(), conn.RemoteAddr().String())
|
||||
go func() {
|
||||
for {
|
||||
str := fmt.Sprintf("HEAD /keep-alive HTTP/1.1\r\n"+
|
||||
"Host: %s\r\n"+
|
||||
"User-Agent: curl/8.0.0 (B612)\r\n"+
|
||||
"Accept: */*\r\n"+
|
||||
"Connection: keep-alive\r\n\r\n", strings.Split(c.Remote, ":")[0])
|
||||
//fmt.Println(str)
|
||||
if _, err = conn.Write([]byte(str)); err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
time.Sleep(time.Second * 20)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
_, err := conn.Read(make([]byte, 4096))
|
||||
if err != nil {
|
||||
starlog.Warningf("Failed to keepalive remote: %v\n", err)
|
||||
conn.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if c.Type == "udp" {
|
||||
rmtUdpAddr, err := net.ResolveUDPAddr("udp", c.Remote)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Forward.UdpListener() == nil {
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
}
|
||||
c.keepaliveConn = c.Forward.UdpListener()
|
||||
for {
|
||||
_, err = c.Forward.UdpListener().WriteTo([]byte("b612 udp nat through"), rmtUdpAddr)
|
||||
if err != nil {
|
||||
c.keepaliveConn.Close()
|
||||
starlog.Warningf("Failed to keepalive remote: %v\n", err)
|
||||
time.Sleep(time.Second * 30)
|
||||
break
|
||||
}
|
||||
starlog.Infof("UDP Keepalive Ok! %v\n", rmtUdpAddr.String())
|
||||
time.Sleep(time.Second * 30)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NatThrough) HandleUPnP(localaddr string, extPort uint16) {
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
client, err := c.FoundUsableUPnP()
|
||||
if err != nil {
|
||||
starlog.Errorf("Failed to find UPnP device: %v\n", err)
|
||||
time.Sleep(time.Second * 20)
|
||||
continue
|
||||
}
|
||||
starlog.Infof("Found UPnP device!\n")
|
||||
_, _, _, _, _, err = client.GetSpecificPortMappingEntry("", uint16(c.Forward.LocalPort), "TCP")
|
||||
if err == nil {
|
||||
starlog.Infof("Port mapping Ok\n")
|
||||
time.Sleep(time.Second * 20)
|
||||
continue
|
||||
}
|
||||
err = client.AddPortMapping("", uint16(c.Forward.LocalPort), strings.ToUpper(c.Type), uint16(c.Forward.LocalPort), localaddr, true, "B612 TCP Nat PassThrough", 75)
|
||||
if err != nil {
|
||||
starlog.Errorf("Failed to add port mapping: %v\n", err)
|
||||
time.Sleep(time.Second * 20)
|
||||
continue
|
||||
}
|
||||
starlog.Infof("Port mapping added:externalPort %d,localAddr %s,localPort %d\n", extPort, localaddr, c.Forward.LocalPort)
|
||||
time.Sleep(time.Second * 20)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NatThrough) GetIPPortFromSTUN(netType string, localip string, localPort int, stunServer string) (net.Addr, net.Addr, error) {
|
||||
// 替换为你的 TURN 服务器地址
|
||||
stunAddr, err := net.ResolveUDPAddr("udp", stunServer)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to resolve STUN server address: %v", err)
|
||||
}
|
||||
var conn net.Conn
|
||||
if netType == "tcp" {
|
||||
dialer := net.Dialer{
|
||||
Control: netforward.ControlSetReUseAddr,
|
||||
LocalAddr: &net.TCPAddr{IP: net.ParseIP(localip), Port: localPort},
|
||||
}
|
||||
conn, err = dialer.Dial("tcp", stunAddr.String())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn.(*net.TCPConn).SetLinger(0)
|
||||
}
|
||||
if netType == "udp" {
|
||||
|
||||
conn, err = net.DialUDP(netType, &net.UDPAddr{IP: net.ParseIP(localip), Port: localPort}, stunAddr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to connect to STUN server: %v", err)
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
innerAddr := conn.LocalAddr()
|
||||
|
||||
// Create STUN request
|
||||
transactionID := make([]byte, 12)
|
||||
rand.Read(transactionID)
|
||||
stunRequest := make([]byte, 20)
|
||||
binary.BigEndian.PutUint16(stunRequest[0:], 0x0001) // Message Type: Binding Request
|
||||
binary.BigEndian.PutUint16(stunRequest[2:], 0x0000) // Message Length
|
||||
copy(stunRequest[4:], []byte{0x21, 0x12, 0xa4, 0x42}) // Magic Cookie
|
||||
copy(stunRequest[8:], transactionID) // Transaction ID
|
||||
|
||||
_, err = conn.Write(stunRequest)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to send STUN request: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to receive STUN response: %v", err)
|
||||
}
|
||||
|
||||
// Parse STUN response
|
||||
if n < 20 {
|
||||
return nil, nil, fmt.Errorf("invalid STUN response")
|
||||
}
|
||||
|
||||
payload := buf[20:n]
|
||||
var ip uint32
|
||||
var port uint16
|
||||
for len(payload) > 0 {
|
||||
attrType := binary.BigEndian.Uint16(payload[0:])
|
||||
attrLen := binary.BigEndian.Uint16(payload[2:])
|
||||
if len(payload) < int(4+attrLen) {
|
||||
return nil, nil, fmt.Errorf("invalid STUN attribute length")
|
||||
}
|
||||
|
||||
if attrType == 0x0001 || attrType == 0x0020 {
|
||||
port = binary.BigEndian.Uint16(payload[6:])
|
||||
ip = binary.BigEndian.Uint32(payload[8:])
|
||||
if attrType == 0x0020 {
|
||||
port ^= 0x2112
|
||||
ip ^= 0x2112a442
|
||||
}
|
||||
break
|
||||
}
|
||||
payload = payload[4+attrLen:]
|
||||
}
|
||||
|
||||
if ip == 0 || port == 0 {
|
||||
return nil, nil, fmt.Errorf("invalid STUN response")
|
||||
}
|
||||
|
||||
outerAddr := &net.UDPAddr{
|
||||
IP: net.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)),
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
return innerAddr, outerAddr, nil
|
||||
}
|
||||
|
||||
func (c *NatThrough) UdpKeppaliveSTUN(conn *net.UDPConn, stunServer string) (net.Addr, net.Addr, error) {
|
||||
// 替换为你的 TURN 服务器地址
|
||||
var target *starmap.StarStack
|
||||
{
|
||||
tmpConn, err := net.Dial("udp", stunServer)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to connect to STUN server: %v", err)
|
||||
}
|
||||
if c.Forward.UdpHooks == nil {
|
||||
c.Forward.UdpHooks = make(map[string]*starmap.StarStack)
|
||||
}
|
||||
if c.Forward.UdpHooks[tmpConn.RemoteAddr().String()] == nil {
|
||||
c.Forward.UdpHooks[tmpConn.RemoteAddr().String()] = starmap.NewStarStack(16)
|
||||
}
|
||||
target = c.Forward.UdpHooks[tmpConn.RemoteAddr().String()]
|
||||
tmpConn.Close()
|
||||
}
|
||||
stunAddr, err := net.ResolveUDPAddr("udp", stunServer)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to resolve STUN server address: %v", err)
|
||||
}
|
||||
innerAddr := conn.LocalAddr()
|
||||
|
||||
// Create STUN request
|
||||
transactionID := make([]byte, 12)
|
||||
rand.Read(transactionID)
|
||||
stunRequest := make([]byte, 20)
|
||||
binary.BigEndian.PutUint16(stunRequest[0:], 0x0001) // Message Type: Binding Request
|
||||
binary.BigEndian.PutUint16(stunRequest[2:], 0x0000) // Message Length
|
||||
copy(stunRequest[4:], []byte{0x21, 0x12, 0xa4, 0x42}) // Magic Cookie
|
||||
copy(stunRequest[8:], transactionID) // Transaction ID
|
||||
|
||||
_, err = conn.WriteToUDP(stunRequest, stunAddr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to send STUN request: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 2500)
|
||||
tmp, err := target.Pop()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to receive STUN response: %v", err)
|
||||
}
|
||||
buf := tmp.([]byte)
|
||||
n := len(buf)
|
||||
|
||||
// Parse STUN response
|
||||
if n < 20 {
|
||||
return nil, nil, fmt.Errorf("invalid STUN response")
|
||||
}
|
||||
|
||||
payload := buf[20:n]
|
||||
var ip uint32
|
||||
var port uint16
|
||||
for len(payload) > 0 {
|
||||
attrType := binary.BigEndian.Uint16(payload[0:])
|
||||
attrLen := binary.BigEndian.Uint16(payload[2:])
|
||||
if len(payload) < int(4+attrLen) {
|
||||
return nil, nil, fmt.Errorf("invalid STUN attribute length")
|
||||
}
|
||||
|
||||
if attrType == 0x0001 || attrType == 0x0020 {
|
||||
port = binary.BigEndian.Uint16(payload[6:])
|
||||
ip = binary.BigEndian.Uint32(payload[8:])
|
||||
if attrType == 0x0020 {
|
||||
port ^= 0x2112
|
||||
ip ^= 0x2112a442
|
||||
}
|
||||
break
|
||||
}
|
||||
payload = payload[4+attrLen:]
|
||||
}
|
||||
|
||||
if ip == 0 || port == 0 {
|
||||
return nil, nil, fmt.Errorf("invalid STUN response")
|
||||
}
|
||||
|
||||
outerAddr := &net.UDPAddr{
|
||||
IP: net.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)),
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
return innerAddr, outerAddr, nil
|
||||
}
|
||||
|
||||
func (c *NatThrough) GetMyOutIP() string {
|
||||
tmp, err := net.Dial("udp", "8.8.8.8:53")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return tmp.LocalAddr().(*net.UDPAddr).IP.String()
|
||||
}
|
||||
func (c *NatThrough) FoundUsableUPnP() (RouterClient, error) {
|
||||
wg := sync.WaitGroup{}
|
||||
found := false
|
||||
result := make(chan RouterClient, 3)
|
||||
defer close(result)
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
clients, errors, err := internetgateway2.NewWANIPConnection2Clients()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
return
|
||||
}
|
||||
if len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
starlog.Infof("Found WANIPConnection2 clients:%s\n", clients[0].Location.String())
|
||||
found = true
|
||||
|
||||
result <- clients[0]
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
clients, errors, err := internetgateway2.NewWANIPConnection1Clients()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
return
|
||||
}
|
||||
if len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
starlog.Infof("Found WANIPConnection1 clients:%s\n", clients[0].Location.String())
|
||||
found = true
|
||||
result <- clients[0]
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
clients, errors, err := internetgateway2.NewWANPPPConnection1Clients()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
return
|
||||
}
|
||||
if len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
starlog.Infof("Found WANPPPConnection1 clients:%s\n", clients[0].Location.String())
|
||||
found = true
|
||||
result <- clients[0]
|
||||
}()
|
||||
wg.Wait()
|
||||
if found {
|
||||
return <-result, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no UPnP devices discovered")
|
||||
}
|
||||
|
||||
type RouterClient interface {
|
||||
AddPortMapping(
|
||||
NewRemoteHost string,
|
||||
NewExternalPort uint16,
|
||||
NewProtocol string,
|
||||
NewInternalPort uint16,
|
||||
NewInternalClient string,
|
||||
NewEnabled bool,
|
||||
NewPortMappingDescription string,
|
||||
NewLeaseDuration uint32,
|
||||
) (err error)
|
||||
|
||||
GetExternalIPAddress() (
|
||||
NewExternalIPAddress string,
|
||||
err error,
|
||||
)
|
||||
|
||||
DeletePortMapping(NewRemoteHost string, NewExternalPort uint16, NewProtocol string) (err error)
|
||||
GetSpecificPortMappingEntry(NewRemoteHost string, NewExternalPort uint16, NewProtocol string) (NewInternalPort uint16, NewInternalClient string, NewEnabled bool, NewPortMappingDescription string, NewLeaseDuration uint32, err error)
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"b612.me/apps/b612/netforward"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNathrough(t *testing.T) {
|
||||
var n = NatThrough{
|
||||
Forward: netforward.NetForward{
|
||||
LocalAddr: "0.0.0.0",
|
||||
LocalPort: 0,
|
||||
RemoteURI: "127.0.0.1:88",
|
||||
EnableTCP: true,
|
||||
EnableUDP: false,
|
||||
DelayMilSec: 0,
|
||||
DelayToward: 0,
|
||||
StdinMode: false,
|
||||
IgnoreEof: false,
|
||||
DialTimeout: 3000,
|
||||
UDPTimeout: 3000,
|
||||
KeepAlivePeriod: 30,
|
||||
KeepAliveIdel: 30,
|
||||
KeepAliveCount: 5,
|
||||
UserTimeout: 0,
|
||||
UsingKeepAlive: true,
|
||||
},
|
||||
Type: "tcp",
|
||||
STUN: "turn.b612.me:3478",
|
||||
Remote: "baidu.com:80",
|
||||
KeepAlivePeriod: 3000,
|
||||
KeepAliveIdel: 3000,
|
||||
KeepAliveCount: 5,
|
||||
AutoUPnP: true,
|
||||
stopFn: nil,
|
||||
stopCtx: nil,
|
||||
}
|
||||
go func() {
|
||||
time.Sleep(time.Second * 10)
|
||||
fmt.Println(n.ExtUrl)
|
||||
}()
|
||||
if err := n.Run(); err != nil {
|
||||
fmt.Println(err)
|
||||
t.Error(err)
|
||||
}
|
||||
n.HealthCheck()
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestScan(t *testing.T) {
|
||||
s := ScanPort{
|
||||
Host: "192.168.2.109",
|
||||
Timeout: 2000,
|
||||
Threads: 5000,
|
||||
}
|
||||
if err := s.Parse("1-65535"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := s.Run(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanIP(t *testing.T) {
|
||||
s := ScanIP{
|
||||
Host: "192.168.2.1",
|
||||
CIDR: 23,
|
||||
Timeout: 2000,
|
||||
Threads: 5000,
|
||||
ScanType: "icmp",
|
||||
WithHostname: true,
|
||||
}
|
||||
if err := s.ICMP(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
@ -0,0 +1,279 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"b612.me/apps/b612/netforward"
|
||||
"b612.me/stario"
|
||||
"b612.me/starlog"
|
||||
"b612.me/starnet"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ScanIP struct {
|
||||
Host string
|
||||
CIDR int
|
||||
Port int
|
||||
Mask string
|
||||
Threads int
|
||||
Timeout int
|
||||
ScanType string
|
||||
ipNet *net.IPNet
|
||||
Log string
|
||||
Retry int
|
||||
WithHostname bool
|
||||
}
|
||||
|
||||
func (s *ScanIP) Parse() error {
|
||||
if s.CIDR == 0 && s.Mask == "" {
|
||||
return fmt.Errorf("CIDR or Mask must be set")
|
||||
|
||||
}
|
||||
if s.CIDR != 0 {
|
||||
return nil
|
||||
}
|
||||
//mask to cidr
|
||||
ipMask := net.IPMask(net.ParseIP(s.Mask).To4())
|
||||
if ipMask == nil {
|
||||
return fmt.Errorf("invalid mask")
|
||||
}
|
||||
s.CIDR, _ = ipMask.Size()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ScanIP) nextIP(ipStr string) (net.IP, error) {
|
||||
var err error
|
||||
if s.ipNet == nil {
|
||||
_, s.ipNet, err = net.ParseCIDR(s.Host + "/" + fmt.Sprintf("%d", s.CIDR))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid CIDR: %v", err)
|
||||
}
|
||||
}
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid IP: %v", ipStr)
|
||||
}
|
||||
|
||||
// Convert IP to 4-byte representation
|
||||
ip = ip.To4()
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("non-IPv4 address: %v", ipStr)
|
||||
}
|
||||
|
||||
// Increment IP
|
||||
for i := len(ip) - 1; i >= 0; i-- {
|
||||
ip[i]++
|
||||
if ip[i] > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check if incremented IP is still in range
|
||||
if !s.ipNet.Contains(ip) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
func (s *ScanIP) NetSize() (int, error) {
|
||||
var err error
|
||||
if s.ipNet == nil {
|
||||
_, s.ipNet, err = net.ParseCIDR(s.Host + "/" + fmt.Sprintf("%d", s.CIDR))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid CIDR: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
maskSize, _ := s.ipNet.Mask.Size()
|
||||
return int(math.Pow(2, float64(32-maskSize))) - 2, nil
|
||||
}
|
||||
|
||||
func (s *ScanIP) FirstLastIP() (net.IP, net.IP, error) {
|
||||
var err error
|
||||
if s.ipNet == nil {
|
||||
_, s.ipNet, err = net.ParseCIDR(s.Host + "/" + fmt.Sprintf("%d", s.CIDR))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid CIDR: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
firstIP := s.ipNet.IP.Mask(s.ipNet.Mask)
|
||||
lastIP := make(net.IP, len(firstIP))
|
||||
copy(lastIP, firstIP)
|
||||
for i := range firstIP {
|
||||
lastIP[i] = firstIP[i] | ^s.ipNet.Mask[i]
|
||||
}
|
||||
|
||||
return firstIP, lastIP, nil
|
||||
}
|
||||
|
||||
func (s *ScanIP) ICMP() error {
|
||||
if s.ScanType != "icmp" {
|
||||
return fmt.Errorf("scan type must be icmp")
|
||||
}
|
||||
if err := s.Parse(); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.Log != "" {
|
||||
starlog.SetLogFile(s.Log, starlog.Std, true)
|
||||
}
|
||||
firstIP, lastIP, err := s.FirstLastIP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns, _ := s.NetSize()
|
||||
starlog.Infof("Scan %s/%d\n", s.Host, s.CIDR)
|
||||
starlog.Infof("Scan %s-%s\n", firstIP.String(), lastIP.String())
|
||||
starlog.Infof("There are %d hosts\n", ns)
|
||||
starlog.Infof("Threads: %d\n", s.Threads)
|
||||
|
||||
wg := stario.NewWaitGroup(s.Threads)
|
||||
count := int32(0)
|
||||
allcount := int32(0)
|
||||
interrupt := make(chan [2]string)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Second * 2):
|
||||
fmt.Printf("scan %d ips, %d up\r", allcount, count)
|
||||
case ip, opened := <-interrupt:
|
||||
if !opened {
|
||||
return
|
||||
}
|
||||
if s.WithHostname {
|
||||
starlog.Infof("Host %v is up, Name:%v\n", ip[0], ip[1])
|
||||
} else {
|
||||
starlog.Infof("Host %v is up\n", ip[0])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
idx := 0
|
||||
for {
|
||||
ip := firstIP.String()
|
||||
if ip == lastIP.String() {
|
||||
break
|
||||
}
|
||||
idx++
|
||||
wg.Add(1)
|
||||
go func(ip string, idx int) {
|
||||
defer func() {
|
||||
atomic.AddInt32(&allcount, 1)
|
||||
}()
|
||||
defer wg.Done()
|
||||
for i := 0; i < s.Retry+1; i++ {
|
||||
_, err := starnet.Ping(ip, idx, time.Duration(s.Timeout)*time.Millisecond)
|
||||
if err == nil {
|
||||
atomic.AddInt32(&count, 1)
|
||||
if s.WithHostname {
|
||||
hostname, err := net.LookupAddr(ip)
|
||||
if err == nil {
|
||||
interrupt <- [2]string{ip, hostname[0]}
|
||||
return
|
||||
}
|
||||
}
|
||||
interrupt <- [2]string{ip, ""}
|
||||
return
|
||||
}
|
||||
}
|
||||
}(ip, idx)
|
||||
firstIP, _ = s.nextIP(ip)
|
||||
}
|
||||
wg.Wait()
|
||||
close(interrupt)
|
||||
starlog.Infof("scan %d ips, %d up\n", ns, count)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ScanIP) TCP(port int) error {
|
||||
if s.ScanType != "tcp" {
|
||||
return fmt.Errorf("scan type must be tcp")
|
||||
}
|
||||
if err := s.Parse(); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.Log != "" {
|
||||
starlog.SetLogFile(s.Log, starlog.Std, true)
|
||||
}
|
||||
firstIP, lastIP, err := s.FirstLastIP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns, _ := s.NetSize()
|
||||
starlog.Infof("Scan %s/%d\n", s.Host, s.CIDR)
|
||||
starlog.Infof("Scan %s-%s\n", firstIP.String(), lastIP.String())
|
||||
starlog.Infof("There are %d hosts\n", ns)
|
||||
starlog.Infof("Threads: %d\n", s.Threads)
|
||||
|
||||
wg := stario.NewWaitGroup(s.Threads)
|
||||
count := int32(0)
|
||||
allcount := int32(0)
|
||||
interrupt := make(chan [2]string)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Second * 2):
|
||||
fmt.Printf("scan %d ips, %d up\r", allcount, count)
|
||||
case ip, opened := <-interrupt:
|
||||
if !opened {
|
||||
return
|
||||
}
|
||||
if s.WithHostname {
|
||||
starlog.Infof("Host %v is up, Name:%v\n", ip[0], ip[1])
|
||||
} else {
|
||||
starlog.Infof("Host %v is up\n", ip[0])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
idx := 0
|
||||
localAddr, err := net.ResolveTCPAddr("tcp", ":0")
|
||||
if err != nil {
|
||||
starlog.Errorln("ResolveTCPAddr error, ", err)
|
||||
return err
|
||||
}
|
||||
for {
|
||||
ip := firstIP.String()
|
||||
if ip == lastIP.String() {
|
||||
break
|
||||
}
|
||||
idx++
|
||||
wg.Add(1)
|
||||
go func(ip string, idx int) {
|
||||
defer func() {
|
||||
atomic.AddInt32(&allcount, 1)
|
||||
}()
|
||||
defer wg.Done()
|
||||
for i := 0; i < s.Retry+1; i++ {
|
||||
dialer := net.Dialer{
|
||||
LocalAddr: localAddr,
|
||||
Timeout: time.Duration(s.Timeout) * time.Millisecond,
|
||||
Control: netforward.ControlSetReUseAddr,
|
||||
}
|
||||
_, err := dialer.Dial("tcp", fmt.Sprintf("%s:%d", ip, port))
|
||||
if err == nil {
|
||||
atomic.AddInt32(&count, 1)
|
||||
if s.WithHostname {
|
||||
hostname, err := net.LookupAddr(ip)
|
||||
if err == nil {
|
||||
interrupt <- [2]string{ip, hostname[0]}
|
||||
return
|
||||
}
|
||||
}
|
||||
interrupt <- [2]string{ip, ""}
|
||||
return
|
||||
}
|
||||
}
|
||||
}(ip, idx)
|
||||
firstIP, _ = s.nextIP(ip)
|
||||
}
|
||||
wg.Wait()
|
||||
close(interrupt)
|
||||
starlog.Infof("scan %d ips, %d up\n", ns, count)
|
||||
return nil
|
||||
}
|
@ -0,0 +1,131 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"b612.me/apps/b612/netforward"
|
||||
"b612.me/stario"
|
||||
"b612.me/starlog"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ScanPort struct {
|
||||
Host string
|
||||
Ports []int
|
||||
Timeout int
|
||||
Threads int
|
||||
Log string
|
||||
Retry int
|
||||
}
|
||||
|
||||
func (s *ScanPort) Parse(potStr string) error {
|
||||
ports := strings.Split(potStr, ",")
|
||||
for _, port := range ports {
|
||||
port = strings.TrimSpace(port)
|
||||
if strings.Contains(port, "-") {
|
||||
// range
|
||||
r := strings.Split(port, "-")
|
||||
if len(r) != 2 {
|
||||
continue
|
||||
}
|
||||
start, err := strconv.Atoi(r[0])
|
||||
if err != nil {
|
||||
starlog.Warningf("invalid port: %s\n", r[0])
|
||||
continue
|
||||
}
|
||||
end, err := strconv.Atoi(r[1])
|
||||
if err != nil {
|
||||
starlog.Warningf("invalid port: %s\n", r[1])
|
||||
continue
|
||||
}
|
||||
for i := start; i <= end; i++ {
|
||||
if i < 1 || i > 65535 {
|
||||
starlog.Warningf("invalid port: %d\n", i)
|
||||
continue
|
||||
}
|
||||
s.Ports = append(s.Ports, i)
|
||||
}
|
||||
} else {
|
||||
// single port
|
||||
tmp, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
starlog.Warningf("invalid port: %s\n", port)
|
||||
continue
|
||||
}
|
||||
if tmp < 1 || tmp > 65535 {
|
||||
starlog.Warningf("invalid port: %d\n", tmp)
|
||||
continue
|
||||
}
|
||||
s.Ports = append(s.Ports, tmp)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ScanPort) Run() error {
|
||||
if s.Threads < 1 {
|
||||
s.Threads = 1
|
||||
}
|
||||
if s.Log != "" {
|
||||
starlog.SetLogFile(s.Log, starlog.Std, true)
|
||||
}
|
||||
sort.Ints(s.Ports)
|
||||
starlog.Infof("scan count %d ports for host %v\n", len(s.Ports), s.Host)
|
||||
wg := stario.NewWaitGroup(s.Threads)
|
||||
localAddr, err := net.ResolveTCPAddr("tcp", ":0")
|
||||
if err != nil {
|
||||
starlog.Errorln("ResolveTCPAddr error, ", err)
|
||||
return err
|
||||
}
|
||||
count := int32(0)
|
||||
allcount := int32(0)
|
||||
interrupt := make(chan int)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Second * 2):
|
||||
fmt.Printf("scan %d ports, %d open\r", atomic.LoadInt32(&allcount), count)
|
||||
case port, opened := <-interrupt:
|
||||
if !opened {
|
||||
return
|
||||
}
|
||||
starlog.Infof("port %d is open\n", port)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
for _, port := range s.Ports {
|
||||
wg.Add(1)
|
||||
go func(port int) {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
atomic.AddInt32(&allcount, 1)
|
||||
}()
|
||||
for i := 0; i < s.Retry+1; i++ {
|
||||
dialer := net.Dialer{
|
||||
LocalAddr: localAddr,
|
||||
Timeout: time.Duration(s.Timeout) * time.Millisecond,
|
||||
Control: netforward.ControlSetReUseAddr,
|
||||
}
|
||||
conn, err := dialer.Dial("tcp", net.JoinHostPort(s.Host, strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.(*net.TCPConn).SetLinger(0)
|
||||
conn.Close()
|
||||
interrupt <- port
|
||||
atomic.AddInt32(&count, 1)
|
||||
return
|
||||
}
|
||||
}(port)
|
||||
|
||||
}
|
||||
wg.Wait()
|
||||
close(interrupt)
|
||||
starlog.Infof("scan %d ports, %d open\n", len(s.Ports), count)
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue