package tls

import (
	"b612.me/starlog"
	"crypto/ecdsa"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"github.com/spf13/cobra"
	"golang.org/x/net/idna"
	"golang.org/x/net/proxy"
	"net"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"
)

// Command-line options of the tls command. They are bound to flags in init
// and read when the command runs.
var (
	// hideDetail suppresses the per-certificate detail output.
	hideDetail bool
	// showCA includes CA certificates in the detail output.
	showCA bool

	// dump is the directory the PEM-encoded certificate chain is written to;
	// empty disables dumping.
	dump string
	// socks5 is the address of a SOCKS5 proxy, e.g. "127.0.0.1:1080".
	socks5 string
	// socks5Auth holds the proxy credentials as "username:password".
	socks5Auth string

	// reqRawIP, when > 0, selects (1-based) which resolved IP address to
	// connect to.
	reqRawIP int
	// timeoutMillSec is the connection timeout in milliseconds.
	timeoutMillSec int
)

// init registers the flags of Cmd and binds each one to its package-level
// option variable.
func init() {
	flags := Cmd.Flags()

	flags.BoolVarP(&hideDetail, "hide-detail", "H", false, "隐藏证书详细信息")
	flags.StringVarP(&dump, "dump", "d", "", "将证书保存到文件")
	flags.IntVarP(&reqRawIP, "resolve-ip", "r", 0, "使用解析到的IP地址进行连接,输入数字表示使用解析到的第几个IP地址")
	flags.IntVarP(&timeoutMillSec, "timeout", "t", 5000, "连接超时时间(毫秒)")
	flags.StringVarP(&socks5, "socks5", "p", "", "socks5代理,示例:127.0.0.1:1080")
	flags.StringVarP(&socks5Auth, "socks5-auth", "A", "", "socks5代理认证,示例:username:password")
	flags.BoolVarP(&showCA, "show-ca", "c", false, "显示CA证书")
}

// Cmd is the "tls" sub-command. Every positional argument is treated as a
// target and inspected in turn with the options given on the command line.
var Cmd = &cobra.Command{
	Use:   "tls",
	Short: "查看TLS证书信息",
	Long:  "查看TLS证书信息",
	Run: func(_ *cobra.Command, targets []string) {
		// The flag value is in milliseconds; convert it once for all targets.
		timeout := time.Duration(timeoutMillSec) * time.Millisecond
		for i := range targets {
			showTls(targets[i], !hideDetail, showCA, reqRawIP, dump, timeout)
		}
	},
}

// showTls connects to target, reports the negotiated TLS version, the cipher
// suite and the certificate-verification result, and optionally prints and
// dumps the peer certificate chain.
//
// target may be "host", "host:port" or "addr:servername:port". Port 443 is
// appended when the last ":"-separated field is not a number. In the
// three-field form the connection goes to addr:port while servername is used
// as the TLS ServerName.
//
// showDetail prints every certificate of the chain (CA certificates only when
// showCA is set). reqRawIP > 0 resolves the host and connects to the
// reqRawIP-th address returned (1-based, clamped to the last one) on the
// requested port; that IP is then also used as ServerName. dumpPath, when
// non-empty, is the directory the PEM-encoded chain is written to. A zero
// timeout falls back to 5 seconds.
//
// The package-level socks5 and socks5Auth options select an optional SOCKS5
// proxy. All failures are logged through starlog; nothing is returned.
func showTls(target string, showDetail, showCA bool, reqRawIP int, dumpPath string, timeout time.Duration) {
	// Append the default HTTPS port unless the last field already is a number.
	if idx := strings.LastIndex(target, ":"); idx < 0 {
		target += ":443"
	} else if _, err := strconv.Atoi(target[idx+1:]); err != nil {
		target += ":443"
	}
	if timeout == 0 {
		timeout = 5 * time.Second
	}

	fields := strings.Split(target, ":")
	hostname := fields[0]
	if len(fields) == 3 {
		// "addr:servername:port": dial addr:port but present servername.
		target = fields[0] + ":" + fields[2]
		hostname = fields[1]
	}

	if reqRawIP > 0 {
		domain := fields[0]
		// The normalisation above guarantees the last field is the port.
		port := fields[len(fields)-1]
		ips, err := net.LookupIP(domain)
		if err != nil {
			starlog.Errorln("failed to resolve domain: " + err.Error())
			return
		}
		if len(ips) == 0 {
			starlog.Errorln("no ip found for domain")
			return
		}
		for _, v := range ips {
			starlog.Infof("解析到的IP地址为: %s\n", v.String())
		}
		if reqRawIP > len(ips) {
			reqRawIP = len(ips)
		}
		ip := ips[reqRawIP-1].String()
		// Keep the requested port and let JoinHostPort bracket IPv6 literals.
		target = net.JoinHostPort(ip, port)
		hostname = ip
		starlog.Noticeln("使用解析到的IP地址进行连接:", target)
	}

	starlog.Noticef("将使用如下地址连接:%s ; ServerName: %s\n", target, hostname)
	punyCode, err := idna.ToASCII(hostname)
	if err == nil && punyCode != hostname {
		starlog.Infoln("检测到域名中含有非ASCII字符,PunyCode转换后为:", punyCode)
		hostname = punyCode
	}
	starlog.Infof("正在连接服务器: %s\n", target)

	netDialer := &net.Dialer{Timeout: timeout}
	var socksDialer proxy.Dialer
	if socks5 != "" {
		var auth *proxy.Auth
		if socks5Auth != "" {
			up := strings.SplitN(socks5Auth, ":", 2)
			if len(up) != 2 {
				starlog.Errorln("socks5认证格式错误")
				return
			}
			auth = &proxy.Auth{User: up[0], Password: up[1]}
		}
		// Forward through netDialer so the timeout also covers the TCP
		// connection to the proxy itself.
		s5Dial, err := proxy.SOCKS5("tcp", socks5, auth, netDialer)
		if err != nil {
			starlog.Errorln("socks5代理错误:", err)
			return
		}
		socksDialer = s5Dial
	}

	// newConfig builds the client configuration; only the verification mode
	// differs between the two connections made below. TLS 1.0 is the oldest
	// version crypto/tls can negotiate, so legacy servers are still reported.
	newConfig := func(skipVerify bool) *tls.Config {
		return &tls.Config{
			InsecureSkipVerify: skipVerify,
			ServerName:         hostname,
			MinVersion:         tls.VersionTLS10,
		}
	}

	// A second, verifying connection runs concurrently so that a verification
	// failure does not prevent the certificate chain from being displayed.
	var verifyErr error
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		if socksDialer == nil {
			c, err := tls.DialWithDialer(netDialer, "tcp", target, newConfig(false))
			if err != nil {
				verifyErr = err
				return
			}
			c.Close()
			return
		}
		raw, err := socksDialer.Dial("tcp", target)
		if err != nil {
			verifyErr = err
			return
		}
		defer raw.Close()
		// Bound the handshake; without a deadline a stalled peer would block
		// wg.Wait forever. Best effort: a failure here only loses the bound.
		_ = raw.SetDeadline(time.Now().Add(timeout))
		verifyErr = tls.Client(raw, newConfig(false)).Handshake()
	}()

	var conn *tls.Conn
	if socksDialer == nil {
		conn, err = tls.DialWithDialer(netDialer, "tcp", target, newConfig(true))
		if err != nil {
			starlog.Errorln("failed to connect: " + err.Error())
			return
		}
	} else {
		raw, err := socksDialer.Dial("tcp", target)
		if err != nil {
			starlog.Errorln("failed to connect: " + err.Error())
			return
		}
		defer raw.Close()
		_ = raw.SetDeadline(time.Now().Add(timeout))
		conn = tls.Client(raw, newConfig(true))
		if err = conn.Handshake(); err != nil {
			starlog.Errorln("failed to handshake: " + err.Error())
			return
		}
	}
	defer conn.Close()

	starlog.Infof("连接成功,对方IP:%s,正在获取证书信息\n", conn.RemoteAddr().String())
	state := conn.ConnectionState()
	certs := state.PeerCertificates
	if len(certs) == 0 {
		starlog.Errorln("no certificate found")
		return
	}
	starlog.Infof("证书获取成功,证书链上共有%d个证书\n", len(certs))

	// Versions older than TLS 1.2 are reported at warning level.
	switch state.Version {
	case tls.VersionTLS10:
		starlog.Warningln("当前TLS版本: TLS 1.0")
	case tls.VersionTLS11:
		starlog.Warningln("当前TLS版本: TLS 1.1")
	case tls.VersionTLS12:
		starlog.Infoln("当前TLS版本: TLS 1.2")
	case tls.VersionTLS13:
		starlog.Infoln("当前TLS版本: TLS 1.3")
	}
	starlog.Infoln("当前加密套件:", tls.CipherSuiteName(state.CipherSuite))
	starlog.Infoln("服务器名称:", state.ServerName)

	wg.Wait()
	if verifyErr != nil {
		starlog.Red("证书验证失败: " + verifyErr.Error())
	} else {
		starlog.Green("证书验证成功")
	}

	if showDetail {
		for _, c := range certs {
			if c.IsCA && !showCA {
				continue
			}
			printCertDetail(c)
		}
	}
	// Dumping is independent of the detail output, so --dump also works
	// together with --hide-detail.
	if dumpPath != "" {
		dumpCertChain(certs, dumpPath)
	}
}

// printCertDetail writes a human-readable description of c to stdout.
func printCertDetail(c *x509.Certificate) {
	fmt.Printf("----------\n")
	if c.IsCA {
		fmt.Println("!这是一个CA证书!")
	}
	fmt.Printf("证书基础信息: %+v\n", c.Subject)
	fmt.Printf("证书颁发者: %+v\n", c.Issuer)
	fmt.Printf("证书生效时间: %+v 距今:%.1f天\n", c.NotBefore.In(time.Local), time.Since(c.NotBefore).Hours()/24)
	fmt.Printf("证书过期时间: %+v 剩余:%.1f天\n", c.NotAfter.In(time.Local), time.Until(c.NotAfter).Hours()/24)
	fmt.Printf("证书序列号: %s\n", c.SerialNumber.Text(16))
	fmt.Printf("证书签名算法: %s\n", c.SignatureAlgorithm)
	fmt.Printf("证书公钥算法: %s\n", c.PublicKeyAlgorithm)
	switch pub := c.PublicKey.(type) {
	case *rsa.PublicKey:
		// Size reports the modulus length in bytes.
		fmt.Printf("RSA公钥位数: %d\n", pub.Size()*8)
	case *ecdsa.PublicKey:
		fmt.Printf("ECDSA Curve位数: %d\n", pub.Curve.Params().BitSize)
	}
	if len(c.DNSNames) != 0 {
		fmt.Printf("可选使用的DNS: %s\n", strings.Join(c.DNSNames, ", "))
	}
	if len(c.IPAddresses) != 0 {
		addrs := make([]string, 0, len(c.IPAddresses))
		for _, ip := range c.IPAddresses {
			addrs = append(addrs, ip.String())
		}
		fmt.Printf("可选使用的IP: %s\n", strings.Join(addrs, ", "))
	}
	if len(c.EmailAddresses) != 0 {
		fmt.Printf("可选使用的Email: %s\n", strings.Join(c.EmailAddresses, ", "))
	}
	if len(c.URIs) != 0 {
		fmt.Printf("可选使用的URI: %v\n", c.URIs)
	}
	if len(c.PermittedDNSDomains) != 0 {
		fmt.Printf("批准使用的DNS: %s\n", strings.Join(c.PermittedDNSDomains, ", "))
	}
	if len(c.PermittedIPRanges) != 0 {
		ranges := make([]string, 0, len(c.PermittedIPRanges))
		for _, ipNet := range c.PermittedIPRanges {
			ranges = append(ranges, ipNet.String())
		}
		fmt.Printf("批准使用的IP: %s\n", strings.Join(ranges, ", "))
	}
	if len(c.PermittedEmailAddresses) != 0 {
		fmt.Printf("批准使用的Email: %s\n", strings.Join(c.PermittedEmailAddresses, ", "))
	}
	if len(c.PermittedURIDomains) != 0 {
		fmt.Printf("批准使用的URI: %s\n", strings.Join(c.PermittedURIDomains, ", "))
	}
	fmt.Printf("证书密钥用途: %s\n", strings.Join(KeyUsageToString(c.KeyUsage), ", "))
	fmt.Printf("证书扩展密钥用途: %s\n", strings.Join(extKeyUsageToString(c.ExtKeyUsage), ", "))
	fmt.Printf("证书版本: %d\n----------\n", c.Version)
}

// extKeyUsageToString maps each extended key usage to its display name, in
// input order; unrecognised values are rendered with their numeric code.
func extKeyUsageToString(usages []x509.ExtKeyUsage) []string {
	names := make([]string, 0, len(usages))
	for _, v := range usages {
		switch v {
		case x509.ExtKeyUsageAny:
			names = append(names, "任何用途")
		case x509.ExtKeyUsageServerAuth:
			names = append(names, "服务器认证")
		case x509.ExtKeyUsageClientAuth:
			names = append(names, "客户端认证")
		case x509.ExtKeyUsageCodeSigning:
			names = append(names, "代码签名")
		case x509.ExtKeyUsageEmailProtection:
			names = append(names, "电子邮件保护")
		case x509.ExtKeyUsageIPSECEndSystem:
			names = append(names, "IPSEC终端系统")
		case x509.ExtKeyUsageIPSECTunnel:
			names = append(names, "IPSEC隧道")
		case x509.ExtKeyUsageIPSECUser:
			names = append(names, "IPSEC用户")
		case x509.ExtKeyUsageTimeStamping:
			names = append(names, "时间戳")
		case x509.ExtKeyUsageOCSPSigning:
			names = append(names, "OCSP签名")
		case x509.ExtKeyUsageMicrosoftServerGatedCrypto:
			names = append(names, "Microsoft服务器门控加密")
		case x509.ExtKeyUsageNetscapeServerGatedCrypto:
			names = append(names, "Netscape服务器门控加密")
		case x509.ExtKeyUsageMicrosoftCommercialCodeSigning:
			names = append(names, "Microsoft商业代码签名")
		case x509.ExtKeyUsageMicrosoftKernelCodeSigning:
			names = append(names, "Microsoft内核代码签名")
		default:
			names = append(names, fmt.Sprintf("未知用途(%d)", v))
		}
	}
	return names
}

// dumpCertChain writes all certificates of the chain, PEM-encoded and in
// order, into a single "<CommonName>.crt" file inside dumpPath. The name is
// taken from the first certificate; certs must not be empty.
func dumpCertChain(certs []*x509.Certificate, dumpPath string) {
	var data []byte
	for _, c := range certs {
		data = append(data, pem.EncodeToMemory(&pem.Block{
			Type:  "CERTIFICATE",
			Bytes: c.Raw,
		})...)
	}
	// The common name is controlled by the remote server: strip path
	// separators so the file cannot be written outside of dumpPath.
	name := strings.NewReplacer("/", "_", "\\", "_").Replace(certs[0].Subject.CommonName) + ".crt"
	path := filepath.Join(dumpPath, name)
	if err := os.WriteFile(path, data, 0644); err != nil {
		starlog.Errorln("failed to write file: " + err.Error())
		return
	}
	starlog.Infoln("dumped to " + path)
}

// KeyUsageToString returns the display names of all key-usage bits set in ku,
// in the order the bits are declared in crypto/x509. The result is never nil;
// it is empty when no known bit is set.
func KeyUsageToString(ku x509.KeyUsage) []string {
	table := [...]struct {
		bit   x509.KeyUsage
		label string
	}{
		{bit: x509.KeyUsageDigitalSignature, label: "数字签名"},
		{bit: x509.KeyUsageContentCommitment, label: "内容承诺"},
		{bit: x509.KeyUsageKeyEncipherment, label: "密钥加密"},
		{bit: x509.KeyUsageDataEncipherment, label: "数据加密"},
		{bit: x509.KeyUsageKeyAgreement, label: "密钥协商"},
		{bit: x509.KeyUsageCertSign, label: "证书签名"},
		{bit: x509.KeyUsageCRLSign, label: "CRL签名"},
		{bit: x509.KeyUsageEncipherOnly, label: "仅加密"},
		{bit: x509.KeyUsageDecipherOnly, label: "仅解密"},
	}

	result := make([]string, 0, len(table))
	for i := range table {
		// Every entry is a single bit, so a full match means the bit is set.
		if ku&table[i].bit == table[i].bit {
			result = append(result, table[i].label)
		}
	}
	return result
}