From 9b64ee99d39b558188d125a859310e7fd15cf18b Mon Sep 17 00:00:00 2001 From: Emman Date: Tue, 26 Jan 2021 15:50:31 +0800 Subject: [PATCH] MAGIC - csr support, but failed to check ali kms generated csr --- sm2/sm2.go | 6 + sm2/x509.go | 660 ++++++++++++++++++++++++++++++++++++++++++++++- sm2/x509_test.go | 47 ++++ 3 files changed, 707 insertions(+), 6 deletions(-) diff --git a/sm2/sm2.go b/sm2/sm2.go index c27d9f3..ae5118c 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -332,6 +332,9 @@ func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { // SignWithSM2 follow sm2 dsa standards for hash part func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s *big.Int, err error) { + if len(uid) == 0 { + uid = defaultUID + } za, err := CalculateZA(&priv.PublicKey, uid) if err != nil { return nil, nil, err @@ -378,6 +381,9 @@ func Verify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { // VerifyWithSM2 verifies the signature in r, s of hash using the public key, pub. Its // return value records whether the signature is valid. func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool { + if len(uid) == 0 { + uid = defaultUID + } za, err := CalculateZA(pub, uid) if err != nil { return false diff --git a/sm2/x509.go b/sm2/x509.go index 4a884cd..3158a67 100644 --- a/sm2/x509.go +++ b/sm2/x509.go @@ -1,13 +1,22 @@ package sm2 import ( + "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "errors" + "fmt" + "io" "math/big" + "net" + "net/url" + "strconv" + "strings" + + "github.com/emmansun/gmsm/sm3" ) // pkixPublicKey reflects a PKIX public key structure. See SubjectPublicKeyInfo @@ -37,10 +46,50 @@ type ecdsaSignature dsaSignature // http://gmssl.org/docs/oid.html var ( - oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33} + oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} + oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34} + oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} oidNamedCurveP256SM2 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301} + + oidSignatureSM2WithSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 501} + oidSignatureSM2WithSHA1 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 502} + oidSignatureSM2WithSHA256 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 503} ) +func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) { + switch curve { + case elliptic.P224(): + return oidNamedCurveP224, true + case elliptic.P256(): + return oidNamedCurveP256, true + case elliptic.P384(): + return oidNamedCurveP384, true + case elliptic.P521(): + return oidNamedCurveP521, true + case P256(): + return oidNamedCurveP256SM2, true + } + + return nil, false +} + +func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve { + switch { + case oid.Equal(oidNamedCurveP224): + return elliptic.P224() + case oid.Equal(oidNamedCurveP256): + return elliptic.P256() + case oid.Equal(oidNamedCurveP384): + return elliptic.P384() + case oid.Equal(oidNamedCurveP521): + return elliptic.P521() + case oid.Equal(oidNamedCurveP256SM2): + return P256() + } + return nil +} + // ParsePKIXPublicKey parses a public key in PKIX, ASN.1 DER form. // // It returns a *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, or @@ -100,10 +149,10 @@ func MarshalPKIXPublicKey(pub interface{}) ([]byte, error) { return x509.MarshalPKIXPublicKey(pub) } - publicKeyAlgorithm := pkix.AlgorithmIdentifier{Algorithm: oidPublicKeyECDSA} - publicKeyBytes := elliptic.Marshal(ecdPub.Curve, ecdPub.X, ecdPub.Y) - paramBytes, _ := asn1.Marshal(oidNamedCurveP256SM2) - publicKeyAlgorithm.Parameters.FullBytes = paramBytes + publicKeyBytes, publicKeyAlgorithm, err := marshalPublicKey(ecdPub) + if err != nil { + return nil, err + } pkix := pkixPublicKey{ Algo: publicKeyAlgorithm, @@ -113,6 +162,605 @@ func MarshalPKIXPublicKey(pub interface{}) ([]byte, error) { }, } - ret, _ := asn1.Marshal(pkix) + return asn1.Marshal(pkix) +} + +func marshalPublicKey(pub *ecdsa.PublicKey) (publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier, err error) { + publicKeyAlgorithm = pkix.AlgorithmIdentifier{Algorithm: oidPublicKeyECDSA} + publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y) + paramBytes, err := asn1.Marshal(oidNamedCurveP256SM2) + publicKeyAlgorithm.Parameters.FullBytes = paramBytes + return +} + +// CreateCertificateRequest creates a new certificate request based on a +// template. The following members of template are used: +// +// - SignatureAlgorithm +// - Subject +// - DNSNames +// - EmailAddresses +// - IPAddresses +// - URIs +// - ExtraExtensions +// - Attributes (deprecated) +// +// priv is the private key to sign the CSR with, and the corresponding public +// key will be included in the CSR. It must implement crypto.Signer and its +// Public() method must return a *rsa.PublicKey or a *ecdsa.PublicKey or a +// ed25519.PublicKey. (A *rsa.PrivateKey, *ecdsa.PrivateKey or +// ed25519.PrivateKey satisfies this.) +// +// The returned slice is the certificate request in DER encoding. +func CreateCertificateRequest(rand io.Reader, template *x509.CertificateRequest, priv interface{}) (csr []byte, err error) { + key, ok := priv.(crypto.Signer) + if !ok { + return nil, errors.New("x509: certificate private key does not implement crypto.Signer") + } + privKey, ok := key.(*PrivateKey) + if !ok { + return x509.CreateCertificateRequest(rand, template, priv) + } + var sigAlgo = pkix.AlgorithmIdentifier{} + sigAlgo.Algorithm = oidSignatureSM2WithSM3 + + var publicKeyBytes []byte + var publicKeyAlgorithm pkix.AlgorithmIdentifier + publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(key.Public().(*ecdsa.PublicKey)) + if err != nil { + return nil, err + } + + var extensions []pkix.Extension + + if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) && + !oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) { + sanBytes, err := marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs) + if err != nil { + return nil, err + } + + extensions = append(extensions, pkix.Extension{ + Id: oidExtensionSubjectAltName, + Value: sanBytes, + }) + } + + extensions = append(extensions, template.ExtraExtensions...) + + // Make a copy of template.Attributes because we may alter it below. + attributes := make([]pkix.AttributeTypeAndValueSET, 0, len(template.Attributes)) + for _, attr := range template.Attributes { + values := make([][]pkix.AttributeTypeAndValue, len(attr.Value)) + copy(values, attr.Value) + attributes = append(attributes, pkix.AttributeTypeAndValueSET{ + Type: attr.Type, + Value: values, + }) + } + + extensionsAppended := false + if len(extensions) > 0 { + // Append the extensions to an existing attribute if possible. + for _, atvSet := range attributes { + if !atvSet.Type.Equal(oidExtensionRequest) || len(atvSet.Value) == 0 { + continue + } + + // specifiedExtensions contains all the extensions that we + // found specified via template.Attributes. + specifiedExtensions := make(map[string]bool) + + for _, atvs := range atvSet.Value { + for _, atv := range atvs { + specifiedExtensions[atv.Type.String()] = true + } + } + + newValue := make([]pkix.AttributeTypeAndValue, 0, len(atvSet.Value[0])+len(extensions)) + newValue = append(newValue, atvSet.Value[0]...) + + for _, e := range extensions { + if specifiedExtensions[e.Id.String()] { + // Attributes already contained a value for + // this extension and it takes priority. + continue + } + + newValue = append(newValue, pkix.AttributeTypeAndValue{ + // There is no place for the critical + // flag in an AttributeTypeAndValue. + Type: e.Id, + Value: e.Value, + }) + } + + atvSet.Value[0] = newValue + extensionsAppended = true + break + } + } + + rawAttributes, err := newRawAttributes(attributes) + if err != nil { + return + } + + // If not included in attributes, add a new attribute for the + // extensions. + if len(extensions) > 0 && !extensionsAppended { + attr := struct { + Type asn1.ObjectIdentifier + Value [][]pkix.Extension `asn1:"set"` + }{ + Type: oidExtensionRequest, + Value: [][]pkix.Extension{extensions}, + } + + b, err := asn1.Marshal(attr) + if err != nil { + return nil, errors.New("x509: failed to serialise extensions attribute: " + err.Error()) + } + + var rawValue asn1.RawValue + if _, err := asn1.Unmarshal(b, &rawValue); err != nil { + return nil, err + } + + rawAttributes = append(rawAttributes, rawValue) + } + + asn1Subject := template.RawSubject + if len(asn1Subject) == 0 { + asn1Subject, err = asn1.Marshal(template.Subject.ToRDNSequence()) + if err != nil { + return nil, err + } + } + + tbsCSR := tbsCertificateRequest{ + Version: 0, // PKCS #10, RFC 2986 + Subject: asn1.RawValue{FullBytes: asn1Subject}, + PublicKey: publicKeyInfo{ + Algorithm: publicKeyAlgorithm, + PublicKey: asn1.BitString{ + Bytes: publicKeyBytes, + BitLength: len(publicKeyBytes) * 8, + }, + }, + RawAttributes: rawAttributes, + } + + tbsCSRContents, err := asn1.Marshal(tbsCSR) + if err != nil { + return + } + tbsCSR.Raw = tbsCSRContents + + signed := tbsCSRContents + za, err := CalculateZA(&privKey.PublicKey, defaultUID) //Emman, use template.Subject as UID? + if err != nil { + return + } + h := sm3.New() + h.Write(za) + h.Write(signed) + signed = h.Sum(nil) + + var signature []byte + signature, err = privKey.Sign(rand, signed, nil) + if err != nil { + return + } + + return asn1.Marshal(certificateRequest{ + TBSCSR: tbsCSR, + SignatureAlgorithm: sigAlgo, + SignatureValue: asn1.BitString{ + Bytes: signature, + BitLength: len(signature) * 8, + }, + }) +} + +// These structures reflect the ASN.1 structure of X.509 certificate +// signature requests (see RFC 2986): + +type tbsCertificateRequest struct { + Raw asn1.RawContent + Version int + Subject asn1.RawValue + PublicKey publicKeyInfo + RawAttributes []asn1.RawValue `asn1:"tag:0"` +} + +type certificateRequest struct { + Raw asn1.RawContent + TBSCSR tbsCertificateRequest + SignatureAlgorithm pkix.AlgorithmIdentifier + SignatureValue asn1.BitString +} + +// oidExtensionRequest is a PKCS#9 OBJECT IDENTIFIER that indicates requested +// extensions in a CSR. +var oidExtensionRequest = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 14} +var ( + oidExtensionSubjectKeyId = []int{2, 5, 29, 14} + oidExtensionKeyUsage = []int{2, 5, 29, 15} + oidExtensionExtendedKeyUsage = []int{2, 5, 29, 37} + oidExtensionAuthorityKeyId = []int{2, 5, 29, 35} + oidExtensionBasicConstraints = []int{2, 5, 29, 19} + oidExtensionSubjectAltName = []int{2, 5, 29, 17} + oidExtensionCertificatePolicies = []int{2, 5, 29, 32} + oidExtensionNameConstraints = []int{2, 5, 29, 30} + oidExtensionCRLDistributionPoints = []int{2, 5, 29, 31} + oidExtensionAuthorityInfoAccess = []int{1, 3, 6, 1, 5, 5, 7, 1, 1} +) + +// newRawAttributes converts AttributeTypeAndValueSETs from a template +// CertificateRequest's Attributes into tbsCertificateRequest RawAttributes. +func newRawAttributes(attributes []pkix.AttributeTypeAndValueSET) ([]asn1.RawValue, error) { + var rawAttributes []asn1.RawValue + b, err := asn1.Marshal(attributes) + if err != nil { + return nil, err + } + rest, err := asn1.Unmarshal(b, &rawAttributes) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: failed to unmarshal raw CSR Attributes") + } + return rawAttributes, nil +} + +// oidNotInExtensions reports whether an extension with the given oid exists in +// extensions. +func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool { + for _, e := range extensions { + if e.Id.Equal(oid) { + return true + } + } + return false +} + +const ( + nameTypeEmail = 1 + nameTypeDNS = 2 + nameTypeURI = 6 + nameTypeIP = 7 +) + +// marshalSANs marshals a list of addresses into a the contents of an X.509 +// SubjectAlternativeName extension. +func marshalSANs(dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL) (derBytes []byte, err error) { + var rawValues []asn1.RawValue + for _, name := range dnsNames { + rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeDNS, Class: 2, Bytes: []byte(name)}) + } + for _, email := range emailAddresses { + rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeEmail, Class: 2, Bytes: []byte(email)}) + } + for _, rawIP := range ipAddresses { + // If possible, we always want to encode IPv4 addresses in 4 bytes. + ip := rawIP.To4() + if ip == nil { + ip = rawIP + } + rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeIP, Class: 2, Bytes: ip}) + } + for _, uri := range uris { + rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeURI, Class: 2, Bytes: []byte(uri.String())}) + } + return asn1.Marshal(rawValues) +} + +// ParseCertificateRequest parses a single certificate request from the +// given ASN.1 DER data. +func ParseCertificateRequest(asn1Data []byte) (*x509.CertificateRequest, error) { + var csr certificateRequest + + rest, err := asn1.Unmarshal(asn1Data, &csr) + if err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, asn1.SyntaxError{Msg: "trailing data"} + } + if !csr.SignatureAlgorithm.Algorithm.Equal(oidSignatureSM2WithSM3) { + return x509.ParseCertificateRequest(asn1Data) + } + return parseCertificateRequest(&csr) +} + +func parseCertificateRequest(in *certificateRequest) (*x509.CertificateRequest, error) { + if !oidSignatureSM2WithSM3.Equal(in.SignatureAlgorithm.Algorithm) { + return nil, errors.New("unsupport signature algorithm") + } + out := &x509.CertificateRequest{ + Raw: in.Raw, + RawTBSCertificateRequest: in.TBSCSR.Raw, + RawSubjectPublicKeyInfo: in.TBSCSR.PublicKey.Raw, + RawSubject: in.TBSCSR.Subject.FullBytes, + + Signature: in.SignatureValue.RightAlign(), + + PublicKeyAlgorithm: getPublicKeyAlgorithmFromOID(in.TBSCSR.PublicKey.Algorithm.Algorithm), + + Version: in.TBSCSR.Version, + Attributes: parseRawAttributes(in.TBSCSR.RawAttributes), + } + + var err error + out.PublicKey, err = parsePublicKey(&in.TBSCSR.PublicKey) + if err != nil { + return nil, err + } + + var subject pkix.RDNSequence + if rest, err := asn1.Unmarshal(in.TBSCSR.Subject.FullBytes, &subject); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 Subject") + } + + out.Subject.FillFromRDNSequence(&subject) + + if out.Extensions, err = parseCSRExtensions(in.TBSCSR.RawAttributes); err != nil { + return nil, err + } + + for _, extension := range out.Extensions { + if extension.Id.Equal(oidExtensionSubjectAltName) { + out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(extension.Value) + if err != nil { + return nil, err + } + } + } + + return out, nil +} + +func parsePublicKey(keyData *publicKeyInfo) (interface{}, error) { + asn1Data := keyData.PublicKey.RightAlign() + paramsData := keyData.Algorithm.Parameters.FullBytes + namedCurveOID := new(asn1.ObjectIdentifier) + rest, err := asn1.Unmarshal(paramsData, namedCurveOID) + if err != nil { + return nil, errors.New("x509: failed to parse ECDSA parameters as named curve") + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after ECDSA parameters") + } + namedCurve := namedCurveFromOID(*namedCurveOID) + if namedCurve == nil { + return nil, errors.New("x509: unsupported elliptic curve") + } + x, y := elliptic.Unmarshal(namedCurve, asn1Data) + if x == nil { + return nil, errors.New("x509: failed to unmarshal elliptic curve point") + } + pub := &ecdsa.PublicKey{ + Curve: namedCurve, + X: x, + Y: y, + } + return pub, nil +} + +// parseRawAttributes Unmarshals RawAttributes into AttributeTypeAndValueSETs. +func parseRawAttributes(rawAttributes []asn1.RawValue) []pkix.AttributeTypeAndValueSET { + var attributes []pkix.AttributeTypeAndValueSET + for _, rawAttr := range rawAttributes { + var attr pkix.AttributeTypeAndValueSET + rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr) + // Ignore attributes that don't parse into pkix.AttributeTypeAndValueSET + // (i.e.: challengePassword or unstructuredName). + if err == nil && len(rest) == 0 { + attributes = append(attributes, attr) + } + } + return attributes +} + +// parseCSRExtensions parses the attributes from a CSR and extracts any +// requested extensions. +func parseCSRExtensions(rawAttributes []asn1.RawValue) ([]pkix.Extension, error) { + // pkcs10Attribute reflects the Attribute structure from RFC 2986, Section 4.1. + type pkcs10Attribute struct { + Id asn1.ObjectIdentifier + Values []asn1.RawValue `asn1:"set"` + } + + var ret []pkix.Extension + for _, rawAttr := range rawAttributes { + var attr pkcs10Attribute + if rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr); err != nil || len(rest) != 0 || len(attr.Values) == 0 { + // Ignore attributes that don't parse. + continue + } + + if !attr.Id.Equal(oidExtensionRequest) { + continue + } + + var extensions []pkix.Extension + if _, err := asn1.Unmarshal(attr.Values[0].FullBytes, &extensions); err != nil { + return nil, err + } + ret = append(ret, extensions...) + } + return ret, nil } + +func forEachSAN(extension []byte, callback func(tag int, data []byte) error) error { + // RFC 5280, 4.2.1.6 + + // SubjectAltName ::= GeneralNames + // + // GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName + // + // GeneralName ::= CHOICE { + // otherName [0] OtherName, + // rfc822Name [1] IA5String, + // dNSName [2] IA5String, + // x400Address [3] ORAddress, + // directoryName [4] Name, + // ediPartyName [5] EDIPartyName, + // uniformResourceIdentifier [6] IA5String, + // iPAddress [7] OCTET STRING, + // registeredID [8] OBJECT IDENTIFIER } + var seq asn1.RawValue + rest, err := asn1.Unmarshal(extension, &seq) + if err != nil { + return err + } else if len(rest) != 0 { + return errors.New("x509: trailing data after X.509 extension") + } + if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 { + return asn1.StructuralError{Msg: "bad SAN sequence"} + } + + rest = seq.Bytes + for len(rest) > 0 { + var v asn1.RawValue + rest, err = asn1.Unmarshal(rest, &v) + if err != nil { + return err + } + + if err := callback(v.Tag, v.Bytes); err != nil { + return err + } + } + + return nil +} + +// domainToReverseLabels converts a textual domain name like foo.example.com to +// the list of labels in reverse order, e.g. ["com", "example", "foo"]. +func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) { + for len(domain) > 0 { + if i := strings.LastIndexByte(domain, '.'); i == -1 { + reverseLabels = append(reverseLabels, domain) + domain = "" + } else { + reverseLabels = append(reverseLabels, domain[i+1:]) + domain = domain[:i] + } + } + + if len(reverseLabels) > 0 && len(reverseLabels[0]) == 0 { + // An empty label at the end indicates an absolute value. + return nil, false + } + + for _, label := range reverseLabels { + if len(label) == 0 { + // Empty labels are otherwise invalid. + return nil, false + } + + for _, c := range label { + if c < 33 || c > 126 { + // Invalid character. + return nil, false + } + } + } + + return reverseLabels, true +} + +func parseSANExtension(value []byte) (dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL, err error) { + err = forEachSAN(value, func(tag int, data []byte) error { + switch tag { + case nameTypeEmail: + emailAddresses = append(emailAddresses, string(data)) + case nameTypeDNS: + dnsNames = append(dnsNames, string(data)) + case nameTypeURI: + uri, err := url.Parse(string(data)) + if err != nil { + return fmt.Errorf("x509: cannot parse URI %q: %s", string(data), err) + } + if len(uri.Host) > 0 { + if _, ok := domainToReverseLabels(uri.Host); !ok { + return fmt.Errorf("x509: cannot parse URI %q: invalid domain", string(data)) + } + } + uris = append(uris, uri) + case nameTypeIP: + switch len(data) { + case net.IPv4len, net.IPv6len: + ipAddresses = append(ipAddresses, data) + default: + return errors.New("x509: cannot parse IP address of length " + strconv.Itoa(len(data))) + } + } + + return nil + }) + + return +} + +// RFC 3279, 2.3 Public Key Algorithms +// +// pkcs-1 OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840) +// rsadsi(113549) pkcs(1) 1 } +// +// rsaEncryption OBJECT IDENTIFIER ::== { pkcs1-1 1 } +// +// id-dsa OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840) +// x9-57(10040) x9cm(4) 1 } +// +// RFC 5480, 2.1.1 Unrestricted Algorithm Identifier and Parameters +// +// id-ecPublicKey OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 } +var ( + oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} + oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} + oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} +) + +func getPublicKeyAlgorithmFromOID(oid asn1.ObjectIdentifier) x509.PublicKeyAlgorithm { + if oid.Equal(oidPublicKeyECDSA) { + return x509.ECDSA + } + return x509.UnknownPublicKeyAlgorithm +} + +// CheckSignature reports whether the signature on c is valid. +func CheckSignature(c *x509.CertificateRequest) error { + if c.PublicKeyAlgorithm == x509.ECDSA { + pub, ok := c.PublicKey.(*ecdsa.PublicKey) + if ok && strings.EqualFold(P256().Params().Name, pub.Curve.Params().Name) { + return checkSignature(c, pub) + } + } + return c.CheckSignature() +} + +// CheckSignature verifies that signature is a valid signature over signed from +// a crypto.PublicKey. +func checkSignature(c *x509.CertificateRequest, publicKey *ecdsa.PublicKey) (err error) { + signed := c.RawTBSCertificateRequest + ecdsaSig := new(ecdsaSignature) + if rest, err := asn1.Unmarshal(c.Signature, ecdsaSig); err != nil { + return err + } else if len(rest) != 0 { + return errors.New("x509: trailing data after ECDSA signature") + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return errors.New("x509: ECDSA signature contained zero or negative values") + } + if !VerifyWithSM2(publicKey, nil, signed, ecdsaSig.R, ecdsaSig.S) { + return errors.New("x509: ECDSA verification failure") + } + return +} diff --git a/sm2/x509_test.go b/sm2/x509_test.go index a7dc2ee..fa5b704 100644 --- a/sm2/x509_test.go +++ b/sm2/x509_test.go @@ -3,6 +3,8 @@ package sm2 import ( "crypto/ecdsa" "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" "encoding/asn1" "encoding/base64" "encoding/hex" @@ -24,8 +26,19 @@ MFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAERrsLH25zLm2LIo6tivZM9afLprSX 6TCKAmQJArAO7VOtZyW4PQwfaTsUIF7IXEFG4iI8bNuTQwMykUzLu2ypEA== -----END PUBLIC KEY----- ` + const hashBase64 = `Zsfw9GLu7dnR8tRr3BDk4kFnxIdc8veiKX2gK49LqOA=` const signature = `MEUCIHV5hOCgYzlO4HkrUhct1Cc8BeKmbXNP+ASje5rGOcCYAiEA2XOajXo3/IihtCEJmNpImtWw3uHIy5CX5TIxit7V0gQ=` +const csrFromAli = `-----BEGIN CERTIFICATE REQUEST----- +MIIBaTCCAQ8CAQAwRzELMAkGA1UEBhMCQ04xEzARBgNVBAMMCkNhcmdvU21hcnQx +DzANBgNVBAcMBlpodWhhaTESMBAGA1UECAwJR3Vhbmdkb25nMFkwEwYHKoZIzj0C +AQYIKoEcz1UBgi0DQgAERrsLH25zLm2LIo6tivZM9afLprSX6TCKAmQJArAO7VOt +ZyW4PQwfaTsUIF7IXEFG4iI8bNuTQwMykUzLu2ypEKBmMC4GCSqGSIb3DQEJDjEh +MB8wHQYDVR0OBBYEFA3FO8vT+8qZBfGZa2TRhLRbme+9MDQGCSqGSIb3DQEJDjEn +MCUwIwYDVR0RBBwwGoEYZW1tYW4uc3VuQGNhcmdvc21hcnQuY29tMAoGCCqBHM9V +AYN1A0gAMEUCIDFGOqXaAIBBc1vQNlCXx8QJcKb5C1+NT+Ij6lSbWCgwAiEAlDIk +PUbitEIHvhvdoOmNGzPDV1LgCwGD5OHVO9Kpy08= +-----END CERTIFICATE REQUEST-----` func getPublicKey(pemContent []byte) (interface{}, error) { block, _ := pem.Decode(pemContent) @@ -35,6 +48,40 @@ func getPublicKey(pemContent []byte) (interface{}, error) { return ParsePKIXPublicKey(block.Bytes) } +func parseAndCheckCsr(csrblock []byte) error { + csr, err := ParseCertificateRequest(csrblock) + if err != nil { + return err + } + fmt.Printf("%v\n", csr) + return CheckSignature(csr) +} + +func TestParseCertificateRequest(t *testing.T) { + block, _ := pem.Decode([]byte(csrFromAli)) + if block == nil { + t.Fatal("Failed to parse PEM block") + } + err := parseAndCheckCsr(block.Bytes) + if err != nil { + t.Fatal(err) + } +} + +func TestCreateCertificateRequest(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + names := pkix.Name{CommonName: "TestName"} + var template = x509.CertificateRequest{Subject: names} + csrblock, err := CreateCertificateRequest(rand.Reader, &template, priv) + if err != nil { + t.Fatal(err) + } + err = parseAndCheckCsr(csrblock) + if err != nil { + t.Fatal(err) + } +} + func TestSignByAliVerifyAtLocal(t *testing.T) { var rs = &ecdsaSignature{} dig, err := base64.StdEncoding.DecodeString(signature)