diff --git a/sm2/sm2.go b/sm2/sm2_dsa.go similarity index 61% rename from sm2/sm2.go rename to sm2/sm2_dsa.go index d33d7fb..7db8168 100644 --- a/sm2/sm2.go +++ b/sm2/sm2_dsa.go @@ -1,21 +1,12 @@ -// Package sm2 implements ShangMi(SM) sm2 digital signature, public key encryption and key exchange algorithms. package sm2 -// Further references: -// [NSA]: Suite B implementer's guide to FIPS 186-3 -// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.182.4503&rep=rep1&type=pdf -// [SECG]: SECG, SEC1 -// http://www.secg.org/sec1-v2.pdf -// [GM/T]: SM2 GB/T 32918.2-2016, GB/T 32918.4-2016 -// - import ( "crypto" "crypto/ecdsa" "crypto/elliptic" _subtle "crypto/subtle" "errors" - "fmt" + "hash" "io" "math/big" "sync" @@ -24,91 +15,12 @@ import ( "github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/internal/randutil" _sm2ec "github.com/emmansun/gmsm/internal/sm2ec" - "github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/sm2/sm2ec" "github.com/emmansun/gmsm/sm3" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" ) -const ( - uncompressed byte = 0x04 - compressed02 byte = 0x02 - compressed03 byte = compressed02 | 0x01 - hybrid06 byte = 0x06 - hybrid07 byte = hybrid06 | 0x01 -) - -// PrivateKey represents an ECDSA SM2 private key. -// It implemented both crypto.Decrypter and crypto.Signer interfaces. -type PrivateKey struct { - ecdsa.PrivateKey - // inverseOfKeyPlus1 is set under inverseOfKeyPlus1Once - inverseOfKeyPlus1 *bigmod.Nat - inverseOfKeyPlus1Once sync.Once -} - -type pointMarshalMode byte - -const ( - //MarshalUncompressed uncompressed mashal mode - MarshalUncompressed pointMarshalMode = iota - //MarshalCompressed compressed mashal mode - MarshalCompressed - //MarshalHybrid hybrid mashal mode - MarshalHybrid -) - -type ciphertextSplicingOrder byte - -const ( - C1C3C2 ciphertextSplicingOrder = iota - C1C2C3 -) - -type ciphertextEncoding byte - -const ( - ENCODING_PLAIN ciphertextEncoding = iota - ENCODING_ASN1 -) - -// EncrypterOpts encryption options -type EncrypterOpts struct { - ciphertextEncoding ciphertextEncoding - pointMarshalMode pointMarshalMode - ciphertextSplicingOrder ciphertextSplicingOrder -} - -// DecrypterOpts decryption options -type DecrypterOpts struct { - ciphertextEncoding ciphertextEncoding - cipherTextSplicingOrder ciphertextSplicingOrder -} - -// NewPlainEncrypterOpts creates a SM2 non-ASN1 encrypter options. -func NewPlainEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder ciphertextSplicingOrder) *EncrypterOpts { - return &EncrypterOpts{ENCODING_PLAIN, marhsalMode, splicingOrder} -} - -// NewPlainDecrypterOpts creates a SM2 non-ASN1 decrypter options. -func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts { - return &DecrypterOpts{ENCODING_PLAIN, splicingOrder} -} - -func toBytes(curve elliptic.Curve, value *big.Int) []byte { - byteLen := (curve.Params().BitSize + 7) >> 3 - result := make([]byte, byteLen) - value.FillBytes(result) - return result -} - -var defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2} - -var ASN1EncrypterOpts = &EncrypterOpts{ENCODING_ASN1, MarshalUncompressed, C1C3C2} - -var ASN1DecrypterOpts = &DecrypterOpts{ENCODING_ASN1, C1C3C2} - // directSigning is a standard Hash value that signals that no pre-hashing // should be performed. var directSigning crypto.Hash = 0 @@ -118,7 +30,7 @@ type Signer interface { SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, error) } -// SM2SignerOption implements crypto.SignerOpts interface. +// SM2SignerOption implements crypto.SignerOpts interface and is used for SM2-specific signing options. // It is specific for SM2, used in private key's Sign method. type SM2SignerOption struct { uid []byte @@ -146,11 +58,28 @@ func (*SM2SignerOption) HashFunc() crypto.Hash { return directSigning } +var ( + errInvalidPrivateKey = errors.New("sm2: invalid private key") + errInvalidPublicKey = errors.New("sm2: invalid public key") +) + +// PrivateKey represents an ECDSA SM2 private key. +// It embeds ecdsa.PrivateKey and includes additional fields for SM2-specific operations. +// It implements both crypto.Decrypter and crypto.Signer interfaces. +type PrivateKey struct { + ecdsa.PrivateKey + // inverseOfKeyPlus1 stores the modular inverse of (private key + 1) modulo the curve order. + // It is computed lazily and cached using sync.Once to ensure it is only calculated once. + inverseOfKeyPlus1 *bigmod.Nat + inverseOfKeyPlus1Once sync.Once +} + // FromECPrivateKey convert an ecdsa private key to SM2 private key. func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) { if key.Curve != sm2ec.P256() { - return nil, errors.New("sm2: it's NOT a sm2 curve private key") + return nil, errors.New("sm2: not an SM2 curve private key") } + // Copy the ECDSA private key fields to the SM2 private key priv.PrivateKey = *key return priv, nil } @@ -160,13 +89,7 @@ func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool { if !ok { return false } - return priv.PublicKey.Equal(&xx.PublicKey) && bigIntEqual(priv.D, xx.D) -} - -// bigIntEqual reports whether a and b are equal leaking only their bit length -// through timing side-channels. -func bigIntEqual(a, b *big.Int) bool { - return _subtle.ConstantTimeCompare(a.Bytes(), b.Bytes()) == 1 + return priv.PublicKey.Equal(&xx.PublicKey) && _subtle.ConstantTimeCompare(priv.D.Bytes(), xx.D.Bytes()) == 1 } // Sign signs digest with priv, reading randomness from rand. Compliance with GB/T 32918.2-2016. @@ -186,124 +109,6 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er return priv.Sign(rand, msg, NewSM2SignerOption(true, uid)) } -// Decrypt decrypts ciphertext msg to plaintext. -// The opts argument should be appropriate for the primitive used. -// Compliance with GB/T 32918.4-2016 chapter 7. -func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) { - var sm2Opts *DecrypterOpts - sm2Opts, _ = opts.(*DecrypterOpts) - return decrypt(priv, msg, sm2Opts) -} - -const maxRetryLimit = 100 - -var ( - errCiphertextTooShort = errors.New("sm2: ciphertext too short") -) - -// EncryptASN1 sm2 encrypt and output ASN.1 result, compliance with GB/T 32918.4-2016. -// -// The random parameter is used as a source of entropy to ensure that -// encrypting the same message twice doesn't result in the same ciphertext. -// Most applications should use [crypto/rand.Reader] as random. -func EncryptASN1(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) { - return Encrypt(random, pub, msg, ASN1EncrypterOpts) -} - -// Encrypt sm2 encrypt implementation, compliance with GB/T 32918.4-2016. -// -// The random parameter is used as a source of entropy to ensure that -// encrypting the same message twice doesn't result in the same ciphertext. -// Most applications should use [crypto/rand.Reader] as random. -func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) { - //A3, requirement is to check if h*P is infinite point, h is 1 - if pub.X.Sign() == 0 && pub.Y.Sign() == 0 { - return nil, errors.New("sm2: public key point is the infinity") - } - if len(msg) == 0 { - return nil, nil - } - if opts == nil { - opts = defaultEncrypterOpts - } - switch pub.Curve.Params() { - case P256().Params(): - return encryptSM2EC(p256(), pub, random, msg, opts) - default: - return encryptLegacy(random, pub, msg, opts) - } -} - -func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byte, opts *EncrypterOpts) ([]byte, error) { - Q, err := c.pointFromAffine(pub.X, pub.Y) - if err != nil { - return nil, err - } - var retryCount int = 0 - for { - k, C1, err := randomPoint(c, random, false) - if err != nil { - return nil, err - } - C2, err := Q.ScalarMult(Q, k.Bytes(c.N)) - if err != nil { - return nil, err - } - C2Bytes := C2.Bytes()[1:] - c2 := sm3.Kdf(C2Bytes, len(msg)) - if subtle.ConstantTimeAllZero(c2) == 1 { - retryCount++ - if retryCount > maxRetryLimit { - return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount) - } - continue - } - //A6, C2 = M + t; - subtle.XORBytes(c2, msg, c2) - - //A7, C3 = hash(x2||M||y2) - md := sm3.New() - md.Write(C2Bytes[:len(C2Bytes)/2]) - md.Write(msg) - md.Write(C2Bytes[len(C2Bytes)/2:]) - c3 := md.Sum(nil) - - if opts.ciphertextEncoding == ENCODING_PLAIN { - return encodingCiphertext(opts, C1, c2, c3) - } - return encodingCiphertextASN1(C1, c2, c3) - } -} - -func encodingCiphertext(opts *EncrypterOpts, C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) { - var c1 []byte - switch opts.pointMarshalMode { - case MarshalCompressed: - c1 = C1.BytesCompressed() - default: - c1 = C1.Bytes() - } - - if opts.ciphertextSplicingOrder == C1C3C2 { - // c1 || c3 || c2 - return append(append(c1, c3...), c2...), nil - } - // c1 || c2 || c3 - return append(append(c1, c2...), c3...), nil -} - -func encodingCiphertextASN1(C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) { - c1 := C1.Bytes() - var b cryptobyte.Builder - b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { - addASN1IntBytes(b, c1[1:len(c1)/2+1]) - addASN1IntBytes(b, c1[len(c1)/2+1:]) - b.AddASN1OctetString(c3) - b.AddASN1OctetString(c2) - }) - return b.Bytes() -} - // GenerateKey generates a new SM2 private key. // // Most applications should use [crypto/rand.Reader] as rand. Note that the @@ -358,23 +163,27 @@ func NewPrivateKey(key []byte) (*PrivateKey, error) { return priv, nil } -// NewPrivateKeyFromInt checks that key is valid and returns a SM2 PrivateKey. +// NewPrivateKeyFromInt creates a new SM2 private key from a given big integer. +// It returns an error if the provided key is nil. func NewPrivateKeyFromInt(key *big.Int) (*PrivateKey, error) { if key == nil { - return nil, errors.New("sm2: invalid private key size") + return nil, errors.New("sm2: private key is nil") } keyBytes := make([]byte, p256().N.Size()) return NewPrivateKey(key.FillBytes(keyBytes)) } -// NewPublicKey checks that key is valid and returns a PublicKey. +// NewPublicKey checks that the provided key is valid and returns an SM2 PublicKey. // -// According GB/T 32918.1-2016, the private key must be in [1, n-2]. +// The key parameter is a byte slice representing the public key in uncompressed format. +// According to GB/T 32918.1-2016, the public key must be in the correct format and on the curve. func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) { c := p256() // Reject the point at infinity and compressed encodings. + // Points at infinity are invalid because they do not represent a valid point on the curve. + // Compressed encodings are not supported by this implementation, so they are also rejected. if len(key) == 0 || key[0] != 4 { - return nil, errors.New("sm2: invalid public key") + return nil, errInvalidPublicKey } // SetBytes also checks that the point is on the curve. p, err := c.newPoint().SetBytes(key) @@ -390,138 +199,6 @@ func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) { return k, nil } -// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}. -// Compliance with GB/T 32918.4-2016. -func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { - return decrypt(priv, ciphertext, nil) -} - -// ErrDecryption represents a failure to decrypt a message. -// It is deliberately vague to avoid adaptive attacks. -var ErrDecryption = errors.New("sm2: decryption error") - -func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { - ciphertextLen := len(ciphertext) - if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { - return nil, errCiphertextTooShort - } - switch priv.Curve.Params() { - case P256().Params(): - return decryptSM2EC(p256(), priv, ciphertext, opts) - default: - return decryptLegacy(priv, ciphertext, opts) - } -} - -func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { - C1, c2, c3, err := parseCiphertext(c, ciphertext, opts) - if err != nil { - return nil, ErrDecryption - } - d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) - if err != nil { - return nil, ErrDecryption - } - - C2, err := C1.ScalarMult(C1, d.Bytes(c.N)) - if err != nil { - return nil, ErrDecryption - } - C2Bytes := C2.Bytes()[1:] - msgLen := len(c2) - msg := sm3.Kdf(C2Bytes, msgLen) - if subtle.ConstantTimeAllZero(c2) == 1 { - return nil, ErrDecryption - } - - //B5, calculate msg = c2 ^ t - subtle.XORBytes(msg, c2, msg) - - md := sm3.New() - md.Write(C2Bytes[:len(C2Bytes)/2]) - md.Write(msg) - md.Write(C2Bytes[len(C2Bytes)/2:]) - u := md.Sum(nil) - - if _subtle.ConstantTimeCompare(u, c3) == 1 { - return msg, nil - } - return nil, ErrDecryption -} - -func parseCiphertext(c *sm2Curve, ciphertext []byte, opts *DecrypterOpts) (*_sm2ec.SM2P256Point, []byte, []byte, error) { - bitSize := c.curve.Params().BitSize - // Encode the coordinates and let SetBytes reject invalid points. - byteLen := (bitSize + 7) / 8 - splicingOrder := C1C3C2 - if opts != nil { - splicingOrder = opts.cipherTextSplicingOrder - } - - b := ciphertext[0] - switch b { - case uncompressed: - if len(ciphertext) <= 1+2*byteLen+sm3.Size { - return nil, nil, nil, errCiphertextTooShort - } - C1, err := c.newPoint().SetBytes(ciphertext[:1+2*byteLen]) - if err != nil { - return nil, nil, nil, err - } - c2, c3 := parseCiphertextC2C3(ciphertext[1+2*byteLen:], splicingOrder) - return C1, c2, c3, nil - case compressed02, compressed03: - C1, err := c.newPoint().SetBytes(ciphertext[:1+byteLen]) - if err != nil { - return nil, nil, nil, err - } - c2, c3 := parseCiphertextC2C3(ciphertext[1+byteLen:], splicingOrder) - return C1, c2, c3, nil - case byte(0x30): - return parseCiphertextASN1(c, ciphertext) - default: - return nil, nil, nil, errors.New("sm2: invalid/unsupport ciphertext format") - } -} - -func parseCiphertextC2C3(ciphertext []byte, order ciphertextSplicingOrder) ([]byte, []byte) { - if order == C1C3C2 { - return ciphertext[sm3.Size:], ciphertext[:sm3.Size] - } - return ciphertext[:len(ciphertext)-sm3.Size], ciphertext[len(ciphertext)-sm3.Size:] -} - -func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []byte, error) { - var ( - x1, y1 = &big.Int{}, &big.Int{} - c2, c3 []byte - inner cryptobyte.String - ) - input := cryptobyte.String(ciphertext) - if !input.ReadASN1(&inner, asn1.SEQUENCE) || - !input.Empty() || - !inner.ReadASN1Integer(x1) || - !inner.ReadASN1Integer(y1) || - !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || - !inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) || - !inner.Empty() { - return nil, nil, nil, nil, errors.New("sm2: invalid asn1 format ciphertext") - } - return x1, y1, c2, c3, nil -} - -func parseCiphertextASN1(c *sm2Curve, ciphertext []byte) (*_sm2ec.SM2P256Point, []byte, []byte, error) { - x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext) - if err != nil { - return nil, nil, nil, err - } - C1, err := c.pointFromAffine(x1, y1) - if err != nil { - return nil, nil, nil, err - } - return C1, c2, c3, nil -} - var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} // CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA). @@ -530,29 +207,53 @@ var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x // This function will not use default UID even the uid argument is empty. func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { uidLen := len(uid) - if uidLen >= 0x2000 { + if uidLen > 0x1fff { return nil, errors.New("sm2: the uid is too long") } - entla := uint16(uidLen) << 3 + uidBitLength := uint16(uidLen) << 3 md := sm3.New() - md.Write([]byte{byte(entla >> 8), byte(entla)}) + md.Write([]byte{byte(uidBitLength >> 8), byte(uidBitLength)}) if uidLen > 0 { md.Write(uid) } - a := new(big.Int).Sub(pub.Params().P, big.NewInt(3)) - md.Write(toBytes(pub.Curve, a)) - md.Write(toBytes(pub.Curve, pub.Params().B)) - md.Write(toBytes(pub.Curve, pub.Params().Gx)) - md.Write(toBytes(pub.Curve, pub.Params().Gy)) - md.Write(toBytes(pub.Curve, pub.X)) - md.Write(toBytes(pub.Curve, pub.Y)) + writeCurveParams(md, pub.Curve) + md.Write(bigIntToBytes(pub.Curve, pub.X)) + md.Write(bigIntToBytes(pub.Curve, pub.Y)) + // Return the calculated ZA value return md.Sum(nil), nil } -// CalculateSM2Hash calculates hash value for data including uid and public key parameters -// according standards. +// writeCurveParams writes the parameters of the given elliptic curve to the provided hash.Hash. +// It writes the following parameters in order: +// - a: P - 3 (where P is the prime specifying the base field of the curve) +// - B: the coefficient B of the curve equation +// - Gx: the x-coordinate of the base point G +// - Gy: the y-coordinate of the base point G // -// uid can be nil, then it will use default uid (1234567812345678) +// Parameters: +// - md: the hash.Hash to write the curve parameters to +// - curve: the elliptic.Curve whose parameters are to be written +func writeCurveParams(md hash.Hash, curve elliptic.Curve) { + a := new(big.Int).Sub(curve.Params().P, big.NewInt(3)) + md.Write(bigIntToBytes(curve, a)) + md.Write(bigIntToBytes(curve, curve.Params().B)) + md.Write(bigIntToBytes(curve, curve.Params().Gx)) + md.Write(bigIntToBytes(curve, curve.Params().Gy)) +} + +// bigIntToBytes converts a big integer value to a byte slice of the appropriate length for the given elliptic curve. +// The byte slice is zero-padded to the left if necessary to match the curve's byte length. +func bigIntToBytes(curve elliptic.Curve, value *big.Int) []byte { + byteLen := (curve.Params().BitSize + 7) >> 3 + byteArray := make([]byte, byteLen) + value.FillBytes(byteArray) + return byteArray +} + +// CalculateSM2Hash calculates the SM2 hash for the given public key, data, and user ID (UID). +// If the UID is not provided, a default UID (1234567812345678) is used. +// The public key must be valid, otherwise will be panic. +// This function is used to calculate the hash value for SM2 signature. func CalculateSM2Hash(pub *ecdsa.PublicKey, data, uid []byte) ([]byte, error) { if len(uid) == 0 { uid = defaultUID @@ -597,21 +298,24 @@ func SignASN1(rand io.Reader, priv *PrivateKey, hash []byte, opts crypto.SignerO } } +// inverseOfPrivateKeyPlus1 calculates and returns the modular inverse of (private key + 1) modulo the curve order. +// It uses lazy initialization and caching to ensure the calculation is performed only once. +// If the private key is invalid, it returns an error. func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, error) { var ( - err error - dp1Inv, oneNat *bigmod.Nat - dp1Bytes []byte + err error + oneNat = bigmod.NewNat().SetUint(1, c.N) + inverseDPlus1 *bigmod.Nat + dp1Bytes []byte ) priv.inverseOfKeyPlus1Once.Do(func() { - oneNat = bigmod.NewNat().SetUint(1, c.N) - dp1Inv, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) + inverseDPlus1, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) if err == nil { - dp1Inv.Add(oneNat, c.N) - if dp1Inv.IsZero() == 1 { // make sure private key is NOT N-1 + inverseDPlus1.Add(oneNat, c.N) + if inverseDPlus1.IsZero() == 1 { // make sure private key is NOT N-1 err = errInvalidPrivateKey } else { - dp1Bytes, err = _sm2ec.P256OrdInverse(dp1Inv.Bytes(c.N)) + dp1Bytes, err = _sm2ec.P256OrdInverse(inverseDPlus1.Bytes(c.N)) if err == nil { priv.inverseOfKeyPlus1, err = bigmod.NewNat().SetBytes(dp1Bytes, c.N) } @@ -624,9 +328,27 @@ func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, erro return priv.inverseOfKeyPlus1, nil } +// signSM2EC generates an SM2 digital signature using the provided private key and hash. +// It follows the SM2 signature algorithm as specified in the Chinese cryptographic standards. +// +// Parameters: +// - c: A pointer to the sm2Curve structure representing the elliptic curve parameters. +// - priv: A pointer to the PrivateKey structure containing the private key for signing. +// - rand: An io.Reader instance used to generate random values. +// - hash: A byte slice containing the hash of the message to be signed. +// +// Returns: +// - sig: A byte slice containing the generated signature. +// - err: An error value indicating any issues encountered during the signing process. +// +// The function performs the following steps: +// 1. Computes the inverse of (d + 1) where d is the private key. +// 2. Converts the hash to an integer. +// 3. Generates a random point on the elliptic curve and computes the signature components (r, s). +// 4. Ensures that the signature components are non-zero and valid. +// 5. Encodes the signature components into a byte slice and returns it. func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) { - // dp1Inv = (d+1)⁻¹ - dp1Inv, err := priv.inverseOfPrivateKeyPlus1(c) + inverseDPlus1, err := priv.inverseOfPrivateKeyPlus1(c) if err != nil { return nil, err } @@ -675,7 +397,7 @@ func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig // k = [k - s] k.Sub(s, c.N) // k = [(d+1)⁻¹ * (k - r * d)] - k.Mul(dp1Inv, c.N) + k.Mul(inverseDPlus1, c.N) if k.IsZero() == 0 { break } @@ -713,91 +435,6 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { var ErrInvalidSignature = errors.New("sm2: invalid signature") -// RecoverPublicKeysFromSM2Signature recovers two or four SM2 public keys from a given signature and hash. -// It takes the hash and signature as input and returns the recovered public keys as []*ecdsa.PublicKey. -// If the signature or hash is invalid, it returns an error. -// The function follows the SM2 algorithm to recover the public keys. -func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, error) { - c := p256() - rBytes, sBytes, err := parseSignature(sig) - if err != nil { - return nil, err - } - r, err := bigmod.NewNat().SetBytes(rBytes, c.N) - if err != nil || r.IsZero() == 1 { - return nil, ErrInvalidSignature - } - s, err := bigmod.NewNat().SetBytes(sBytes, c.N) - if err != nil || s.IsZero() == 1 { - return nil, ErrInvalidSignature - } - - e := bigmod.NewNat() - hashToNat(c, e, hash) - - // p₁ = [-s]G - negS := bigmod.NewNat().ExpandFor(c.N).Sub(s, c.N) - p1, err := c.newPoint().ScalarBaseMult(negS.Bytes(c.N)) - if err != nil { - return nil, err - } - - // s = [r + s] - s.Add(r, c.N) - if s.IsZero() == 1 { - return nil, ErrInvalidSignature - } - // sBytes = (r+s)⁻¹ - sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N)) - if err != nil { - return nil, err - } - - // r = (Rx + e) mod N - // Rx = r - e - r.Sub(e, c.N) - if r.IsZero() == 1 { - return nil, ErrInvalidSignature - } - pointRx := make([]*bigmod.Nat, 0, 2) - pointRx = append(pointRx, r) - // check if Rx in (N, P), small probability event - s.Set(r) - s = s.Add(c.N.Nat(), c.P) - if s.CmpGeq(c.N.Nat()) == 1 { - pointRx = append(pointRx, s) - } - pubs := make([]*ecdsa.PublicKey, 0, 4) - bytes := make([]byte, 32+1) - compressFlags := []byte{compressed02, compressed03} - // Rx has one or two possible values, so point R has two or four possible values - for _, x := range pointRx { - rBytes = x.Bytes(c.N) - copy(bytes[1:], rBytes) - for _, flag := range compressFlags { - bytes[0] = flag - // p0 = R - p0, err := c.newPoint().SetBytes(bytes) - if err != nil { - return nil, err - } - // p0 = R - [s]G - p0.Add(p0, p1) - // Pub = [(r + s)⁻¹](R - [s]G) - p0.ScalarMult(p0, sBytes) - pub := new(ecdsa.PublicKey) - pub.Curve = c.curve - pub.X, pub.Y, err = c.pointToAffine(p0) - if err != nil { - return nil, err - } - pubs = append(pubs, pub) - } - } - - return pubs, nil -} - // VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the // public key, pub. Its return value records whether the signature is valid. // @@ -919,7 +556,9 @@ func hashToNat(c *sm2Curve, e *bigmod.Nat, hash []byte) { } } -// IsSM2PublicKey check if given public key is a SM2 public key or not +// IsSM2PublicKey checks if the provided public key is an SM2 public key. +// It takes an interface{} as input and attempts to assert it to an *ecdsa.PublicKey. +// The function returns true if the assertion is successful and the public key's curve is SM2 P-256. func IsSM2PublicKey(publicKey any) bool { pub, ok := publicKey.(*ecdsa.PublicKey) return ok && pub.Curve == sm2ec.P256() @@ -939,7 +578,7 @@ func PublicKeyToECDH(k *ecdsa.PublicKey) (*ecdh.PublicKey, error) { return nil, errors.New("sm2: unsupported curve by ecdh") } if !k.Curve.IsOnCurve(k.X, k.Y) { - return nil, errors.New("sm2: invalid public key") + return nil, errInvalidPublicKey } return c.NewPublicKey(elliptic.Marshal(k.Curve, k.X, k.Y)) } @@ -954,7 +593,7 @@ func (k *PrivateKey) ECDH() (*ecdh.PrivateKey, error) { } size := (k.Curve.Params().N.BitLen() + 7) / 8 if k.D.BitLen() > size*8 { - return nil, errors.New("sm2: invalid private key") + return nil, errInvalidPrivateKey } return c.NewPrivateKey(k.D.FillBytes(make([]byte, size))) } @@ -1011,6 +650,110 @@ func randomPoint(c *sm2Curve, rand io.Reader, checkOrderMinus1 bool) (k *bigmod. // randomPoint rejects a candidate for being higher than the modulus. var testingOnlyRejectionSamplingLooped func() + +// RecoverPublicKeysFromSM2Signature attempts to recover the public keys from an SM2 signature. +// This function takes a hash and a signature as input and returns a slice of possible public keys +// that could have generated the given signature. +// +// Parameters: +// - hash: The hash of the message that was signed. +// - sig: The SM2 signature. +// +// Returns: +// - A slice of pointers to ecdsa.PublicKey, representing the possible public keys. +// - An error if the signature is invalid or if any other error occurs during the recovery process. +// +// The function performs the following steps: +// 1. Parses the signature to extract the r and s values. +// 2. Converts the hash to a big integer (Nat). +// 3. Computes the point p₁ = [-s]G. +// 4. Computes s = [r + s] and its modular inverse. +// 5. Computes the possible x-coordinates (Rx) for the point R. +// 6. For each possible Rx, computes the corresponding point R and derives the public key. +// +// Note: The function handles the case where there are one or two possible values for Rx, +// resulting in two or four possible public keys. +func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, error) { + c := p256() + rBytes, sBytes, err := parseSignature(sig) + if err != nil { + return nil, err + } + r, err := bigmod.NewNat().SetBytes(rBytes, c.N) + if err != nil || r.IsZero() == 1 { + return nil, ErrInvalidSignature + } + s, err := bigmod.NewNat().SetBytes(sBytes, c.N) + if err != nil || s.IsZero() == 1 { + return nil, ErrInvalidSignature + } + + e := bigmod.NewNat() + hashToNat(c, e, hash) + + // p₁ = [-s]G + negS := bigmod.NewNat().ExpandFor(c.N).Sub(s, c.N) + p1, err := c.newPoint().ScalarBaseMult(negS.Bytes(c.N)) + if err != nil { + return nil, err + } + + // s = [r + s] + s.Add(r, c.N) + if s.IsZero() == 1 { + return nil, ErrInvalidSignature + } + // sBytes = (r+s)⁻¹ + sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N)) + if err != nil { + return nil, err + } + + // r = (Rx + e) mod N + // Rx = r - e + r.Sub(e, c.N) + if r.IsZero() == 1 { + return nil, ErrInvalidSignature + } + pointRx := make([]*bigmod.Nat, 0, 2) + pointRx = append(pointRx, r) + // check if Rx in (N, P), small probability event + s.Set(r) + s = s.Add(c.N.Nat(), c.P) + if s.CmpGeq(c.N.Nat()) == 1 { + pointRx = append(pointRx, s) + } + pubs := make([]*ecdsa.PublicKey, 0, 4) + bytes := make([]byte, 32+1) + compressFlags := []byte{compressed02, compressed03} + // Rx has one or two possible values, so point R has two or four possible values + for _, x := range pointRx { + rBytes = x.Bytes(c.N) + copy(bytes[1:], rBytes) + for _, flag := range compressFlags { + bytes[0] = flag + // p0 = R + p0, err := c.newPoint().SetBytes(bytes) + if err != nil { + return nil, err + } + // p0 = R - [s]G + p0.Add(p0, p1) + // Pub = [(r + s)⁻¹](R - [s]G) + p0.ScalarMult(p0, sBytes) + pub := new(ecdsa.PublicKey) + pub.Curve = c.curve + pub.X, pub.Y, err = c.pointToAffine(p0) + if err != nil { + return nil, err + } + pubs = append(pubs, pub) + } + } + + return pubs, nil +} + type sm2Curve struct { newPoint func() *_sm2ec.SM2P256Point curve elliptic.Curve @@ -1073,5 +816,3 @@ func precomputeParams(c *sm2Curve, curve elliptic.Curve) { c.nMinus1 = c.N.Nat().SubOne(c.N) c.nMinus2 = new(bigmod.Nat).Set(c.nMinus1).SubOne(c.N).Bytes(c.N) } - -var errInvalidPrivateKey = errors.New("sm2: invalid private key") diff --git a/sm2/sm2_test.go b/sm2/sm2_dsa_test.go similarity index 55% rename from sm2/sm2_test.go rename to sm2/sm2_dsa_test.go index 1311324..d57f010 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_dsa_test.go @@ -10,7 +10,6 @@ import ( "encoding/hex" "io" "math/big" - "reflect" "testing" "github.com/emmansun/gmsm/sm3" @@ -109,367 +108,6 @@ func TestNewPublicKey(t *testing.T) { } } -func TestSplicingOrder(t *testing.T) { - priv, _ := GenerateKey(rand.Reader) - tests := []struct { - name string - plainText string - from ciphertextSplicingOrder - to ciphertextSplicingOrder - }{ - // TODO: Add test cases. - {"less than 32 1", "encryption standard", C1C2C3, C1C3C2}, - {"less than 32 2", "encryption standard", C1C3C2, C1C2C3}, - {"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2}, - {"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3}, - {"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2}, - {"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewPlainEncrypterOpts(MarshalUncompressed, tt.from)) - if err != nil { - t.Fatalf("encrypt failed %v", err) - } - plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from)) - if err != nil { - t.Fatalf("decrypt failed %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - - //Adjust splicing order - ciphertext, err = AdjustCiphertextSplicingOrder(ciphertext, tt.from, tt.to) - if err != nil { - t.Fatalf("adjust splicing order failed %v", err) - } - plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to)) - if err != nil { - t.Fatalf("decrypt failed after adjust splicing order %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - }) - } -} - -func TestEncryptDecryptASN1(t *testing.T) { - priv, _ := GenerateKey(rand.Reader) - priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - key2 := new(PrivateKey) - key2.PrivateKey = *priv2 - tests := []struct { - name string - plainText string - priv *PrivateKey - }{ - // TODO: Add test cases. - {"less than 32", "encryption standard", priv}, - {"equals 32", "encryption standard encryption ", priv}, - {"long than 32", "encryption standard encryption standard", priv}, - {"less than 32", "encryption standard", key2}, - {"equals 32", "encryption standard encryption ", key2}, - {"long than 32", "encryption standard encryption standard", key2}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - encrypterOpts := ASN1EncrypterOpts - ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts) - if err != nil { - t.Fatalf("%v encrypt failed %v", tt.priv.Curve.Params().Name, err) - } - plaintext, err := tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) - if err != nil { - t.Fatalf("%v decrypt 1 failed %v", tt.priv.Curve.Params().Name, err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - plaintext, err = tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) - if err != nil { - t.Fatalf("%v decrypt 2 failed %v", tt.priv.Curve.Params().Name, err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - }) - } -} - -func TestPlainCiphertext2ASN1(t *testing.T) { - ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05") - _, err := PlainCiphertext2ASN1(append([]byte{0x30}, ciphertext...), C1C3C2) - if err == nil { - t.Fatalf("expected error") - } - _, err = PlainCiphertext2ASN1(ciphertext[:65], C1C3C2) - if err == nil { - t.Fatalf("expected error") - } - ciphertext[0] = 0x10 - _, err = PlainCiphertext2ASN1(ciphertext, C1C3C2) - if err == nil { - t.Fatalf("expected error") - } -} - -func TestAdjustCiphertextSplicingOrder(t *testing.T) { - ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05") - res, err := AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C3C2) - if err != nil || &res[0] != &ciphertext[0] { - t.Fatalf("should be same one") - } - _, err = AdjustCiphertextSplicingOrder(ciphertext[:65], C1C3C2, C1C2C3) - if err == nil { - t.Fatalf("expected error") - } - ciphertext[0] = 0x10 - _, err = AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C2C3) - if err == nil { - t.Fatalf("expected error") - } -} - -func TestCiphertext2ASN1(t *testing.T) { - priv, _ := GenerateKey(rand.Reader) - tests := []struct { - name string - plainText string - }{ - // TODO: Add test cases. - {"less than 32", "encryption standard"}, - {"equals 32", "encryption standard encryption "}, - {"long than 32", "encryption standard encryption standard"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ciphertext1, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil) - if err != nil { - t.Fatalf("encrypt failed %v", err) - } - - ciphertext, err := PlainCiphertext2ASN1(ciphertext1, C1C3C2) - if err != nil { - t.Fatalf("convert to ASN.1 failed %v", err) - } - plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) - if err != nil { - t.Fatalf("decrypt failed %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - - ciphertext2, err := AdjustCiphertextSplicingOrder(ciphertext1, C1C3C2, C1C2C3) - if err != nil { - t.Fatalf("adjust order failed %v", err) - } - ciphertext, err = PlainCiphertext2ASN1(ciphertext2, C1C2C3) - if err != nil { - t.Fatalf("convert to ASN.1 failed %v", err) - } - plaintext, err = priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) - if err != nil { - t.Fatalf("decrypt failed %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - }) - } -} - -func TestCiphertextASN12Plain(t *testing.T) { - priv, _ := GenerateKey(rand.Reader) - tests := []struct { - name string - plainText string - }{ - // TODO: Add test cases. - {"less than 32", "encryption standard"}, - {"equals 32", "encryption standard encryption "}, - {"long than 32", "encryption standard encryption standard"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ciphertext, err := EncryptASN1(rand.Reader, &priv.PublicKey, []byte(tt.plainText)) - if err != nil { - t.Fatalf("encrypt failed %v", err) - } - ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil) - if err != nil { - t.Fatalf("convert to plain failed %v", err) - } - plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil) - if err != nil { - t.Fatalf("decrypt failed %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - }) - } -} - -func TestEncryptWithInfinitePublicKey(t *testing.T) { - pub := new(ecdsa.PublicKey) - pub.Curve = P256() - pub.X = big.NewInt(0) - pub.Y = big.NewInt(0) - - _, err := Encrypt(rand.Reader, pub, []byte("sm2 encryption standard"), nil) - if err == nil { - t.Fatalf("should be failed") - } -} - -func TestEncryptEmptyPlaintext(t *testing.T) { - priv, _ := GenerateKey(rand.Reader) - ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, nil, nil) - if err != nil || ciphertext != nil { - t.Fatalf("nil plaintext should return nil") - } - ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte{}, nil) - if err != nil || ciphertext != nil { - t.Fatalf("empty plaintext should return nil") - } -} - -func TestEncryptDecrypt(t *testing.T) { - priv, _ := GenerateKey(rand.Reader) - priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - key2 := new(PrivateKey) - key2.PrivateKey = *priv2 - tests := []struct { - name string - plainText string - priv *PrivateKey - }{ - // TODO: Add test cases. - {"less than 32", "encryption standard", priv}, - {"equals 32", "encryption standard encryption ", priv}, - {"long than 32", "encryption standard encryption standard", priv}, - {"less than 32", "encryption standard", key2}, - {"equals 32", "encryption standard encryption ", key2}, - {"long than 32", "encryption standard encryption standard", key2}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), nil) - if err != nil { - t.Fatalf("encrypt failed %v", err) - } - plaintext, err := Decrypt(tt.priv, ciphertext) - if err != nil { - t.Fatalf("decrypt failed %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - // compress mode - encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2) - ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts) - if err != nil { - t.Fatalf("encrypt failed %v", err) - } - plaintext, err = Decrypt(tt.priv, ciphertext) - if err != nil { - t.Fatalf("decrypt failed %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - - // hybrid mode - encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2) - ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts) - if err != nil { - t.Fatalf("encrypt failed %v", err) - } - plaintext, err = Decrypt(tt.priv, ciphertext) - if err != nil { - t.Fatalf("decrypt failed %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - plaintext, err = Decrypt(tt.priv, ciphertext) - if err != nil { - t.Fatalf("decrypt failed %v", err) - } - if !reflect.DeepEqual(string(plaintext), tt.plainText) { - t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) - } - }) - } -} - -func TestInvalidCiphertext(t *testing.T) { - priv, _ := GenerateKey(rand.Reader) - tests := []struct { - name string - ciphertext []byte - }{ - // TODO: Add test cases. - {errCiphertextTooShort.Error(), make([]byte, 65)}, - {ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 96)...)}, - {ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 97)...)}, - {ErrDecryption.Error(), append([]byte{0x02}, make([]byte, 65)...)}, - {ErrDecryption.Error(), append([]byte{0x30}, make([]byte, 97)...)}, - {ErrDecryption.Error(), make([]byte, 97)}, - } - for i, tt := range tests { - _, err := Decrypt(priv, tt.ciphertext) - if err.Error() != tt.name { - t.Fatalf("case %v, expected %v, got %v\n", i, tt.name, err.Error()) - } - } -} - -func TestPrivateKeyPlus1WithOrderMinus1(t *testing.T) { - priv := new(PrivateKey) - priv.D = new(big.Int).Sub(P256().Params().N, big.NewInt(1)) - priv.Curve = P256() - priv.PublicKey.X, priv.PublicKey.Y = P256().ScalarBaseMult(priv.D.Bytes()) - - _, err := priv.inverseOfPrivateKeyPlus1(p256()) - if err == nil || err != errInvalidPrivateKey { - t.Errorf("expected invalid private key error") - } -} - -func TestSignVerify(t *testing.T) { - priv, _ := GenerateKey(rand.Reader) - tests := []struct { - name string - plainText string - }{ - // TODO: Add test cases. - {"less than 32", "encryption standard"}, - {"equals 32", "encryption standard encryption "}, - {"long than 32", "encryption standard encryption standard"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - hashed := sm3.Sum([]byte(tt.plainText)) - signature, err := priv.Sign(rand.Reader, hashed[:], nil) - if err != nil { - t.Fatalf("sign failed %v", err) - } - result := VerifyASN1(&priv.PublicKey, hashed[:], signature) - if !result { - t.Fatal("verify failed") - } - hashed[0] ^= 0xff - if VerifyASN1(&priv.PublicKey, hashed[:], signature) { - t.Errorf("VerifyASN1 always works!") - } - }) - } -} - func testRecoverPublicKeysFromSM2Signature(t *testing.T, priv *PrivateKey) { tests := []struct { name string @@ -774,6 +412,48 @@ func TestRandomPoint(t *testing.T) { } } +func TestPrivateKeyPlus1WithOrderMinus1(t *testing.T) { + priv := new(PrivateKey) + priv.D = new(big.Int).Sub(P256().Params().N, big.NewInt(1)) + priv.Curve = P256() + priv.PublicKey.X, priv.PublicKey.Y = P256().ScalarBaseMult(priv.D.Bytes()) + + _, err := priv.inverseOfPrivateKeyPlus1(p256()) + if err == nil || err != errInvalidPrivateKey { + t.Errorf("expected invalid private key error") + } +} + +func TestSignVerify(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + }{ + // TODO: Add test cases. + {"less than 32", "encryption standard"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hashed := sm3.Sum([]byte(tt.plainText)) + signature, err := priv.Sign(rand.Reader, hashed[:], nil) + if err != nil { + t.Fatalf("sign failed %v", err) + } + result := VerifyASN1(&priv.PublicKey, hashed[:], signature) + if !result { + t.Fatal("verify failed") + } + hashed[0] ^= 0xff + if VerifyASN1(&priv.PublicKey, hashed[:], signature) { + t.Errorf("VerifyASN1 always works!") + } + }) + } +} + func BenchmarkGenerateKey_SM2(b *testing.B) { r := bufio.NewReaderSize(rand.Reader, 1<<15) b.ReportAllocs() @@ -894,45 +574,3 @@ func BenchmarkVerify_SM2(b *testing.B) { } } } - -func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext []byte) { - r := bufio.NewReaderSize(rand.Reader, 1<<15) - priv, err := ecdsa.GenerateKey(curve, r) - if err != nil { - b.Fatal(err) - } - b.SetBytes(int64(len(plaintext))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil) - } -} - -func BenchmarkEncryptNoMoreThan32_P256(b *testing.B) { - benchmarkEncrypt(b, elliptic.P256(), make([]byte, 31)) -} - -func BenchmarkEncryptNoMoreThan32_SM2(b *testing.B) { - benchmarkEncrypt(b, P256(), make([]byte, 31)) -} - -func BenchmarkEncrypt128_P256(b *testing.B) { - benchmarkEncrypt(b, elliptic.P256(), make([]byte, 128)) -} - -func BenchmarkEncrypt128_SM2(b *testing.B) { - benchmarkEncrypt(b, P256(), make([]byte, 128)) -} - -func BenchmarkEncrypt512_SM2(b *testing.B) { - benchmarkEncrypt(b, P256(), make([]byte, 512)) -} - -func BenchmarkEncrypt1K_SM2(b *testing.B) { - benchmarkEncrypt(b, P256(), make([]byte, 1024)) -} - -func BenchmarkEncrypt8K_SM2(b *testing.B) { - benchmarkEncrypt(b, P256(), make([]byte, 8*1024)) -} diff --git a/sm2/sm2_envelopedkey.go b/sm2/sm2_envelopedkey.go index 59ce3f3..17ae09a 100644 --- a/sm2/sm2_envelopedkey.go +++ b/sm2/sm2_envelopedkey.go @@ -79,8 +79,26 @@ func MarshalEnvelopedPrivateKey(rand io.Reader, pub *ecdsa.PublicKey, tobeEnvelo return b.Bytes() } -// ParseEnvelopedPrivateKey, parses and decrypts the enveloped SM2 private key. -// This methed just supports SM4 cipher now. +// ParseEnvelopedPrivateKey parses an enveloped private key using the provided private key. +// The enveloped key is expected to be in ASN.1 format and encrypted with a symmetric cipher. +// +// Parameters: +// - priv: The private key used to decrypt the symmetric key. +// - enveloped: The ASN.1 encoded and encrypted enveloped private key. +// +// Returns: +// - A pointer to the decrypted PrivateKey. +// - An error if the parsing or decryption fails. +// +// The function performs the following steps: +// 1. Unmarshals the ASN.1 data to extract the symmetric algorithm identifier, encrypted symmetric key, public key, and encrypted private key. +// 2. Verifies that the symmetric algorithm is supported (SM4 or SM4ECB). +// 3. Parses the public key from the ASN.1 data. +// 4. Decrypts the symmetric key using the provided private key. +// 5. Decrypts the SM2 private key using the decrypted symmetric key. +// 6. Verifies that the decrypted private key matches the public key. +// +// Errors are returned if any of the steps fail, including invalid ASN.1 format, unsupported symmetric cipher, decryption failures, or key mismatches. func ParseEnvelopedPrivateKey(priv *PrivateKey, enveloped []byte) (*PrivateKey, error) { // unmarshal the asn.1 data var ( diff --git a/sm2/sm2_keyexchange.go b/sm2/sm2_keyexchange.go index 7831d26..6b9d331 100644 --- a/sm2/sm2_keyexchange.go +++ b/sm2/sm2_keyexchange.go @@ -149,34 +149,34 @@ func (ke *KeyExchange) InitKeyExchange(rand io.Reader) (*ecdsa.PublicKey, error) func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte { var buffer []byte hash := sm3.New() - hash.Write(toBytes(ke.privateKey, ke.v.X)) + hash.Write(bigIntToBytes(ke.privateKey, ke.v.X)) if isResponder { hash.Write(ke.peerZ) hash.Write(ke.z) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.X)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y)) - hash.Write(toBytes(ke.privateKey, ke.secret.X)) - hash.Write(toBytes(ke.privateKey, ke.secret.Y)) + hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.X)) + hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.Y)) + hash.Write(bigIntToBytes(ke.privateKey, ke.secret.X)) + hash.Write(bigIntToBytes(ke.privateKey, ke.secret.Y)) } else { hash.Write(ke.z) hash.Write(ke.peerZ) - hash.Write(toBytes(ke.privateKey, ke.secret.X)) - hash.Write(toBytes(ke.privateKey, ke.secret.Y)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.X)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y)) + hash.Write(bigIntToBytes(ke.privateKey, ke.secret.X)) + hash.Write(bigIntToBytes(ke.privateKey, ke.secret.Y)) + hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.X)) + hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.Y)) } buffer = hash.Sum(nil) hash.Reset() hash.Write([]byte{prefix}) - hash.Write(toBytes(ke.privateKey, ke.v.Y)) + hash.Write(bigIntToBytes(ke.privateKey, ke.v.Y)) hash.Write(buffer) return hash.Sum(nil) } func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { var buffer []byte - buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...) - buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...) + buffer = append(buffer, bigIntToBytes(ke.privateKey, ke.v.X)...) + buffer = append(buffer, bigIntToBytes(ke.privateKey, ke.v.Y)...) if isResponder { buffer = append(buffer, ke.peerZ...) buffer = append(buffer, ke.z...) diff --git a/sm2/sm2_keyexchange_sample_test.go b/sm2/sm2_keyexchange_sample_test.go index 6ef7078..a649623 100644 --- a/sm2/sm2_keyexchange_sample_test.go +++ b/sm2/sm2_keyexchange_sample_test.go @@ -301,12 +301,9 @@ func calculateSampleZA(pub *ecdsa.PublicKey, a *big.Int, uid []byte) ([]byte, er if uidLen > 0 { md.Write(uid) } - md.Write(toBytes(pub.Curve, a)) - md.Write(toBytes(pub.Curve, pub.Params().B)) - md.Write(toBytes(pub.Curve, pub.Params().Gx)) - md.Write(toBytes(pub.Curve, pub.Params().Gy)) - md.Write(toBytes(pub.Curve, pub.X)) - md.Write(toBytes(pub.Curve, pub.Y)) + writeCurveParams(md, pub.Curve.Params()) + md.Write(bigIntToBytes(pub.Curve, pub.X)) + md.Write(bigIntToBytes(pub.Curve, pub.Y)) return md.Sum(nil), nil } diff --git a/sm2/sm2_legacy.go b/sm2/sm2_legacy.go index 6b1ace9..b854cfd 100644 --- a/sm2/sm2_legacy.go +++ b/sm2/sm2_legacy.go @@ -259,7 +259,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) //A5, calculate t=KDF(x2||y2, klen) - c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + c2 := sm3.Kdf(append(bigIntToBytes(curve, x2), bigIntToBytes(curve, y2)...), msgLen) if subtle.ConstantTimeAllZero(c2) == 1 { retryCount++ if retryCount > maxRetryLimit { @@ -289,9 +289,9 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte { md := sm3.New() - md.Write(toBytes(curve, x2)) + md.Write(bigIntToBytes(curve, x2)) md.Write(msg) - md.Write(toBytes(curve, y2)) + md.Write(bigIntToBytes(curve, y2)) return md.Sum(nil) } @@ -306,95 +306,6 @@ func mashalASN1Ciphertext(x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) { return b.Bytes() } -// ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format -func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { - if opts == nil { - opts = defaultEncrypterOpts - } - x1, y1, c2, c3, err := unmarshalASN1Ciphertext((ciphertext)) - if err != nil { - return nil, err - } - curve := sm2ec.P256() - c1 := opts.pointMarshalMode.mashal(curve, x1, y1) - if opts.ciphertextSplicingOrder == C1C3C2 { - // c1 || c3 || c2 - return append(append(c1, c3...), c2...), nil - } - // c1 || c2 || c3 - return append(append(c1, c2...), c3...), nil -} - -// PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format -func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) { - if ciphertext[0] == 0x30 { - return nil, errors.New("sm2: invalid plain encoding ciphertext") - } - curve := sm2ec.P256() - ciphertextLen := len(ciphertext) - if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { - return nil, errCiphertextTooShort - } - // get C1, and check C1 - x1, y1, c3Start, err := bytes2Point(curve, ciphertext) - if err != nil { - return nil, err - } - - var c2, c3 []byte - - if from == C1C3C2 { - c2 = ciphertext[c3Start+sm3.Size:] - c3 = ciphertext[c3Start : c3Start+sm3.Size] - } else { - c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] - c3 = ciphertext[ciphertextLen-sm3.Size:] - } - return mashalASN1Ciphertext(x1, y1, c2, c3) -} - -// AdjustCiphertextSplicingOrder utility method to change c2 c3 order -func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) { - curve := sm2ec.P256() - if from == to { - return ciphertext, nil - } - ciphertextLen := len(ciphertext) - if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { - return nil, errCiphertextTooShort - } - - // get C1, and check C1 - _, _, c3Start, err := bytes2Point(curve, ciphertext) - if err != nil { - return nil, err - } - - var c1, c2, c3 []byte - - c1 = ciphertext[:c3Start] - if from == C1C3C2 { - c2 = ciphertext[c3Start+sm3.Size:] - c3 = ciphertext[c3Start : c3Start+sm3.Size] - } else { - c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] - c3 = ciphertext[ciphertextLen-sm3.Size:] - } - - result := make([]byte, ciphertextLen) - copy(result, c1) - if to == C1C3C2 { - // c1 || c3 || c2 - copy(result[c3Start:], c3) - copy(result[c3Start+sm3.Size:], c2) - } else { - // c1 || c2 || c3 - copy(result[c3Start:], c2) - copy(result[ciphertextLen-sm3.Size:], c3) - } - return result, nil -} - func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) { x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext) if err != nil { @@ -407,7 +318,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error curve := priv.Curve x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) msgLen := len(c2) - msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + msg := sm3.Kdf(append(bigIntToBytes(curve, x2), bigIntToBytes(curve, y2)...), msgLen) if subtle.ConstantTimeAllZero(c2) == 1 { return nil, ErrDecryption } @@ -428,7 +339,7 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([] if opts.ciphertextEncoding == ENCODING_ASN1 { return decryptASN1(priv, ciphertext) } - splicingOrder = opts.cipherTextSplicingOrder + splicingOrder = opts.ciphertextSplicingOrder } if ciphertext[0] == 0x30 { return decryptASN1(priv, ciphertext) @@ -436,7 +347,7 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([] ciphertextLen := len(ciphertext) curve := priv.Curve // B1, get C1, and check C1 - x1, y1, c3Start, err := bytes2Point(curve, ciphertext) + x1, y1, c3Start, err := bytesToPoint(curve, ciphertext) if err != nil { return nil, ErrDecryption } @@ -454,7 +365,20 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([] return rawDecrypt(priv, x1, y1, c2, c3) } -func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { +func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { + switch mode { + case MarshalCompressed: + return elliptic.MarshalCompressed(curve, x, y) + case MarshalHybrid: + buffer := elliptic.Marshal(curve, x, y) + buffer[0] = byte(y.Bit(0)) | hybrid06 + return buffer + default: + return elliptic.Marshal(curve, x, y) + } +} + +func bytesToPoint(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { if len(bytes) < 1+(curve.Params().BitSize/8) { return nil, nil, 0, fmt.Errorf("sm2: invalid bytes length %d", len(bytes)) } @@ -486,20 +410,7 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e } return x, y, 1 + byteLen, nil } - return nil, nil, 0, fmt.Errorf("sm2: unsupport point form %d, curve %s", format, curve.Params().Name) + return nil, nil, 0, fmt.Errorf("sm2: unsupported point form %d, curve %s", format, curve.Params().Name) } return nil, nil, 0, fmt.Errorf("sm2: unknown point form %d", format) } - -func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { - switch mode { - case MarshalCompressed: - return elliptic.MarshalCompressed(curve, x, y) - case MarshalHybrid: - buffer := elliptic.Marshal(curve, x, y) - buffer[0] = byte(y.Bit(0)) | hybrid06 - return buffer - default: - return elliptic.Marshal(curve, x, y) - } -} diff --git a/sm2/sm2_pke.go b/sm2/sm2_pke.go new file mode 100644 index 0000000..eda9ea2 --- /dev/null +++ b/sm2/sm2_pke.go @@ -0,0 +1,396 @@ +// Package sm2 implements ShangMi(SM) sm2 digital signature, public key encryption and key exchange algorithms. +package sm2 + +// Further references: +// [NSA]: Suite B implementer's guide to FIPS 186-3 +// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.182.4503&rep=rep1&type=pdf +// [SECG]: SECG, SEC1 +// http://www.secg.org/sec1-v2.pdf +// [GM/T]: SM2 GB/T 32918.2-2016, GB/T 32918.4-2016 +// + +import ( + "crypto" + "crypto/ecdsa" + _subtle "crypto/subtle" + "errors" + "fmt" + "io" + "math/big" + + "github.com/emmansun/gmsm/internal/bigmod" + _sm2ec "github.com/emmansun/gmsm/internal/sm2ec" + "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/sm3" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +const ( + uncompressed byte = 0x04 + compressed02 byte = 0x02 + compressed03 byte = compressed02 | 0x01 + hybrid06 byte = 0x06 + hybrid07 byte = hybrid06 | 0x01 +) + +type pointMarshalMode byte + +const ( + //MarshalUncompressed uncompressed marshal mode + MarshalUncompressed pointMarshalMode = iota + //MarshalCompressed compressed marshal mode + MarshalCompressed + //MarshalHybrid hybrid marshal mode + MarshalHybrid +) + +type ciphertextSplicingOrder byte + +const ( + C1C3C2 ciphertextSplicingOrder = iota + C1C2C3 +) + +// splitC2C3 splits the given ciphertext into two parts, C2 and C3, based on the splicing order. +// If the order is C1C3C2, it returns the first sm3.Size bytes as C3 and the rest as C2. +// Otherwise, it returns the first part as C2 and the last sm3.Size bytes as C3. +func (order ciphertextSplicingOrder) splitC2C3(ciphertext []byte) ([]byte, []byte) { + if order == C1C3C2 { + return ciphertext[sm3.Size:], ciphertext[:sm3.Size] + } + return ciphertext[:len(ciphertext)-sm3.Size], ciphertext[len(ciphertext)-sm3.Size:] +} + +// spliceCiphertext splices the given ciphertext components together based on the splicing order. +func (order ciphertextSplicingOrder) spliceCiphertext(c1, c2, c3 []byte) ([]byte, error) { + switch order { + case C1C3C2: + return append(append(c1, c3...), c2...), nil + case C1C2C3: + return append(append(c1, c2...), c3...), nil + default: + return nil, errors.New("sm2: invalid ciphertext splicing order") + } +} + +type ciphertextEncoding byte + +const ( + ENCODING_PLAIN ciphertextEncoding = iota + ENCODING_ASN1 +) + +// EncrypterOpts represents the options for the SM2 encryption process. +// It includes settings for ciphertext encoding, point marshaling mode, +// and the order in which the ciphertext components are spliced together. +type EncrypterOpts struct { + ciphertextEncoding ciphertextEncoding + pointMarshalMode pointMarshalMode + ciphertextSplicingOrder ciphertextSplicingOrder +} + +// DecrypterOpts represents the options for the decryption process. +// It includes settings for how the ciphertext is encoded and how the +// components of the ciphertext are spliced together. +// +// Fields: +// - ciphertextEncoding: Specifies the encoding format of the ciphertext. +// - ciphertextSplicingOrder: Defines the order in which the components +// of the ciphertext are spliced together. +type DecrypterOpts struct { + ciphertextEncoding ciphertextEncoding + ciphertextSplicingOrder ciphertextSplicingOrder +} + +// NewPlainEncrypterOpts creates a SM2 non-ASN1 encrypter options. +func NewPlainEncrypterOpts(marshalMode pointMarshalMode, splicingOrder ciphertextSplicingOrder) *EncrypterOpts { + return &EncrypterOpts{ENCODING_PLAIN, marshalMode, splicingOrder} +} + +// NewPlainDecrypterOpts creates a SM2 non-ASN1 decrypter options. +func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts { + return &DecrypterOpts{ENCODING_PLAIN, splicingOrder} +} + +var ( + defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2} + + ASN1EncrypterOpts = &EncrypterOpts{ENCODING_ASN1, MarshalUncompressed, C1C3C2} + + ASN1DecrypterOpts = &DecrypterOpts{ENCODING_ASN1, C1C3C2} +) + +const maxRetryLimit = 100 + +var errCiphertextTooShort = errors.New("sm2: ciphertext too short") + +// EncryptASN1 sm2 encrypt and output ASN.1 result, compliance with GB/T 32918.4-2016. +// +// The random parameter is used as a source of entropy to ensure that +// encrypting the same message twice doesn't result in the same ciphertext. +// Most applications should use [crypto/rand.Reader] as random. +func EncryptASN1(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) { + return Encrypt(random, pub, msg, ASN1EncrypterOpts) +} + +// Encrypt sm2 encrypt implementation, compliance with GB/T 32918.4-2016. +// +// The random parameter is used as a source of entropy to ensure that +// encrypting the same message twice doesn't result in the same ciphertext. +// Most applications should use [crypto/rand.Reader] as random. +func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) { + //A3, requirement is to check if h*P is infinite point, h is 1 + if pub.X.Sign() == 0 && pub.Y.Sign() == 0 { + return nil, errors.New("sm2: public key point is the infinity") + } + if len(msg) == 0 { + return nil, nil + } + if opts == nil { + opts = defaultEncrypterOpts + } + switch pub.Curve.Params() { + case P256().Params(): + return encryptSM2EC(p256(), pub, random, msg, opts) + default: + return encryptLegacy(random, pub, msg, opts) + } +} + +func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byte, opts *EncrypterOpts) ([]byte, error) { + Q, err := c.pointFromAffine(pub.X, pub.Y) + if err != nil { + return nil, err + } + retryCount := 0 + for { + k, C1, err := randomPoint(c, random, false) + if err != nil { + return nil, err + } + C2, err := Q.ScalarMult(Q, k.Bytes(c.N)) + if err != nil { + return nil, err + } + C2Bytes := C2.Bytes()[1:] + c2 := sm3.Kdf(C2Bytes, len(msg)) + if subtle.ConstantTimeAllZero(c2) == 1 { + retryCount++ + if retryCount > maxRetryLimit { + return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount) + } + continue + } + //A6, C2 = M + t; + subtle.XORBytes(c2, msg, c2) + + //A7, C3 = hash(x2||M||y2) + md := sm3.New() + md.Write(C2Bytes[:len(C2Bytes)/2]) + md.Write(msg) + md.Write(C2Bytes[len(C2Bytes)/2:]) + c3 := md.Sum(nil) + + if opts.ciphertextEncoding == ENCODING_PLAIN { + return encodeCiphertext(opts, C1, c2, c3) + } + return encodingCiphertextASN1(C1, c2, c3) + } +} + +func encodeCiphertext(opts *EncrypterOpts, C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) { + var c1 []byte + switch opts.pointMarshalMode { + case MarshalCompressed: + c1 = C1.BytesCompressed() + default: + c1 = C1.Bytes() + } + return opts.ciphertextSplicingOrder.spliceCiphertext(c1, c2, c3) +} + +func encodingCiphertextASN1(C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) { + c1 := C1.Bytes() + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + addASN1IntBytes(b, c1[1:len(c1)/2+1]) + addASN1IntBytes(b, c1[len(c1)/2+1:]) + b.AddASN1OctetString(c3) + b.AddASN1OctetString(c2) + }) + return b.Bytes() +} + +// Decrypt decrypts ciphertext msg to plaintext. +// The opts argument should be appropriate for the primitive used. +// Compliance with GB/T 32918.4-2016 chapter 7. +func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) { + var sm2Opts *DecrypterOpts + sm2Opts, _ = opts.(*DecrypterOpts) + return decrypt(priv, msg, sm2Opts) +} + +// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}. +// Compliance with GB/T 32918.4-2016. +func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { + return decrypt(priv, ciphertext, nil) +} + +// ErrDecryption represents a failure to decrypt a message. +// It is deliberately vague to avoid adaptive attacks. +var ErrDecryption = errors.New("sm2: decryption error") + +func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { + ciphertextLen := len(ciphertext) + if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { + return nil, errCiphertextTooShort + } + switch priv.Curve.Params() { + case P256().Params(): + return decryptSM2EC(p256(), priv, ciphertext, opts) + default: + return decryptLegacy(priv, ciphertext, opts) + } +} + +func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { + C1, c2, c3, err := parseCiphertext(c, ciphertext, opts) + if err != nil { + return nil, ErrDecryption + } + d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) + if err != nil { + return nil, ErrDecryption + } + + C2, err := C1.ScalarMult(C1, d.Bytes(c.N)) + if err != nil { + return nil, ErrDecryption + } + C2Bytes := C2.Bytes()[1:] + msgLen := len(c2) + msg := sm3.Kdf(C2Bytes, msgLen) + if subtle.ConstantTimeAllZero(c2) == 1 { + return nil, ErrDecryption + } + + //B5, calculate msg = c2 ^ t + subtle.XORBytes(msg, c2, msg) + + md := sm3.New() + md.Write(C2Bytes[:len(C2Bytes)/2]) + md.Write(msg) + md.Write(C2Bytes[len(C2Bytes)/2:]) + u := md.Sum(nil) + + if _subtle.ConstantTimeCompare(u, c3) == 1 { + return msg, nil + } + return nil, ErrDecryption +} + +// parseCiphertext parses the given ciphertext according to the specified SM2 curve and decryption options. +// It returns the parsed SM2 point (C1), the decrypted message (C2), the message digest (C3), and an error if any. +func parseCiphertext(c *sm2Curve, ciphertext []byte, opts *DecrypterOpts) (*_sm2ec.SM2P256Point, []byte, []byte, error) { + bitSize := c.curve.Params().BitSize + byteLen := (bitSize + 7) / 8 + splicingOrder := C1C3C2 + if opts != nil { + splicingOrder = opts.ciphertextSplicingOrder + } + + var ciphertextFormat byte = 0xff // invalid + if len(ciphertext) > 0 { + ciphertextFormat = ciphertext[0] + } + var c1Len int + switch ciphertextFormat { + case byte(asn1.SEQUENCE): + return parseCiphertextASN1(c, ciphertext) + case uncompressed: + c1Len = 1 + 2*byteLen + case compressed02, compressed03: + c1Len = 1 + byteLen + default: + return nil, nil, nil, errors.New("sm2: invalid/unsupported ciphertext format") + } + if len(ciphertext) < c1Len+sm3.Size { + return nil, nil, nil, errCiphertextTooShort + } + C1, err := c.newPoint().SetBytes(ciphertext[:c1Len]) + if err != nil { + return nil, nil, nil, err + } + c2, c3 := splicingOrder.splitC2C3(ciphertext[c1Len:]) + return C1, c2, c3, nil +} + +func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []byte, error) { + var ( + x1, y1 = &big.Int{}, &big.Int{} + c2, c3 []byte + inner cryptobyte.String + ) + input := cryptobyte.String(ciphertext) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(x1) || + !inner.ReadASN1Integer(y1) || + !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || + !inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) || + !inner.Empty() { + return nil, nil, nil, nil, errors.New("sm2: invalid asn1 format ciphertext") + } + return x1, y1, c2, c3, nil +} + +func parseCiphertextASN1(c *sm2Curve, ciphertext []byte) (*_sm2ec.SM2P256Point, []byte, []byte, error) { + x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext) + if err != nil { + return nil, nil, nil, err + } + C1, err := c.pointFromAffine(x1, y1) + if err != nil { + return nil, nil, nil, err + } + return C1, c2, c3, nil +} + +// AdjustCiphertextSplicingOrder utility method to change c2 c3 order +func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) { + curve := p256() + if from == to { + return ciphertext, nil + } + C1, c2, c3, err := parseCiphertext(curve, ciphertext, NewPlainDecrypterOpts(from)) + if err != nil { + return nil, err + } + opts := NewPlainEncrypterOpts(MarshalUncompressed, to) + if ciphertext[0] == compressed02 || ciphertext[0] == compressed03 { + opts.pointMarshalMode = MarshalCompressed + } + return encodeCiphertext(opts, C1, c2, c3) +} + +// ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format +func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { + if opts == nil { + opts = defaultEncrypterOpts + } + C1, c2, c3, err := parseCiphertextASN1(p256(), ciphertext) + if err != nil { + return nil, err + } + return encodeCiphertext(opts, C1, c2, c3) +} + +// PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format +func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) { + C1, c2, c3, err := parseCiphertext(p256(), ciphertext, NewPlainDecrypterOpts(from)) + if err != nil { + return nil, err + } + return encodingCiphertextASN1(C1, c2, c3) +} diff --git a/sm2/sm2_pke_test.go b/sm2/sm2_pke_test.go new file mode 100644 index 0000000..4ee6d78 --- /dev/null +++ b/sm2/sm2_pke_test.go @@ -0,0 +1,374 @@ +package sm2 + +import ( + "bufio" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/hex" + "math/big" + "reflect" + "testing" +) + +func TestSplicingOrder(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + from ciphertextSplicingOrder + to ciphertextSplicingOrder + }{ + // TODO: Add test cases. + {"less than 32 1", "encryption standard", C1C2C3, C1C3C2}, + {"less than 32 2", "encryption standard", C1C3C2, C1C2C3}, + {"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2}, + {"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3}, + {"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2}, + {"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewPlainEncrypterOpts(MarshalUncompressed, tt.from)) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from)) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + + //Adjust splicing order + ciphertext, err = AdjustCiphertextSplicingOrder(ciphertext, tt.from, tt.to) + if err != nil { + t.Fatalf("adjust splicing order failed %v", err) + } + plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to)) + if err != nil { + t.Fatalf("decrypt failed after adjust splicing order %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + +func TestEncryptDecryptASN1(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + key2 := new(PrivateKey) + key2.PrivateKey = *priv2 + tests := []struct { + name string + plainText string + priv *PrivateKey + }{ + // TODO: Add test cases. + {"less than 32", "encryption standard", priv}, + {"equals 32", "encryption standard encryption ", priv}, + {"long than 32", "encryption standard encryption standard", priv}, + {"less than 32", "encryption standard", key2}, + {"equals 32", "encryption standard encryption ", key2}, + {"long than 32", "encryption standard encryption standard", key2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encrypterOpts := ASN1EncrypterOpts + ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts) + if err != nil { + t.Fatalf("%v encrypt failed %v", tt.priv.Curve.Params().Name, err) + } + plaintext, err := tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) + if err != nil { + t.Fatalf("%v decrypt 1 failed %v", tt.priv.Curve.Params().Name, err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + plaintext, err = tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) + if err != nil { + t.Fatalf("%v decrypt 2 failed %v", tt.priv.Curve.Params().Name, err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + +func TestPlainCiphertext2ASN1(t *testing.T) { + ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05") + _, err := PlainCiphertext2ASN1(append([]byte{0x30}, ciphertext...), C1C3C2) + if err == nil { + t.Fatalf("expected error") + } + _, err = PlainCiphertext2ASN1(ciphertext[:65], C1C3C2) + if err == nil { + t.Fatalf("expected error") + } + ciphertext[0] = 0x10 + _, err = PlainCiphertext2ASN1(ciphertext, C1C3C2) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestAdjustCiphertextSplicingOrder(t *testing.T) { + ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05") + res, err := AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C3C2) + if err != nil || &res[0] != &ciphertext[0] { + t.Fatalf("should be same one") + } + _, err = AdjustCiphertextSplicingOrder(ciphertext[:65], C1C3C2, C1C2C3) + if err == nil { + t.Fatalf("expected error") + } + ciphertext[0] = 0x10 + _, err = AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C2C3) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestCiphertext2ASN1(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + }{ + // TODO: Add test cases. + {"less than 32", "encryption standard"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext1, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + + ciphertext, err := PlainCiphertext2ASN1(ciphertext1, C1C3C2) + if err != nil { + t.Fatalf("convert to ASN.1 failed %v", err) + } + plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + + ciphertext2, err := AdjustCiphertextSplicingOrder(ciphertext1, C1C3C2, C1C2C3) + if err != nil { + t.Fatalf("adjust order failed %v", err) + } + ciphertext, err = PlainCiphertext2ASN1(ciphertext2, C1C2C3) + if err != nil { + t.Fatalf("convert to ASN.1 failed %v", err) + } + plaintext, err = priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + +func TestCiphertextASN12Plain(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + }{ + // TODO: Add test cases. + {"less than 32", "encryption standard"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := EncryptASN1(rand.Reader, &priv.PublicKey, []byte(tt.plainText)) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil) + if err != nil { + t.Fatalf("convert to plain failed %v", err) + } + plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + +func TestEncryptWithInfinitePublicKey(t *testing.T) { + pub := new(ecdsa.PublicKey) + pub.Curve = P256() + pub.X = big.NewInt(0) + pub.Y = big.NewInt(0) + + _, err := Encrypt(rand.Reader, pub, []byte("sm2 encryption standard"), nil) + if err == nil { + t.Fatalf("should be failed") + } +} + +func TestEncryptEmptyPlaintext(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, nil, nil) + if err != nil || ciphertext != nil { + t.Fatalf("nil plaintext should return nil") + } + ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte{}, nil) + if err != nil || ciphertext != nil { + t.Fatalf("empty plaintext should return nil") + } +} + +func TestEncryptDecrypt(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + key2 := new(PrivateKey) + key2.PrivateKey = *priv2 + tests := []struct { + name string + plainText string + priv *PrivateKey + }{ + // TODO: Add test cases. + {"less than 32", "encryption standard", priv}, + {"equals 32", "encryption standard encryption ", priv}, + {"long than 32", "encryption standard encryption standard", priv}, + {"less than 32", "encryption standard", key2}, + {"equals 32", "encryption standard encryption ", key2}, + {"long than 32", "encryption standard encryption standard", key2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), nil) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err := Decrypt(tt.priv, ciphertext) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + // compress mode + encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2) + ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err = Decrypt(tt.priv, ciphertext) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + + // hybrid mode + encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2) + ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err = Decrypt(tt.priv, ciphertext) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + plaintext, err = Decrypt(tt.priv, ciphertext) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + +func TestInvalidCiphertext(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + ciphertext []byte + }{ + // TODO: Add test cases. + {errCiphertextTooShort.Error(), nil}, + {errCiphertextTooShort.Error(), make([]byte, 65)}, + {ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 96)...)}, + {ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 97)...)}, + {ErrDecryption.Error(), append([]byte{0x02}, make([]byte, 65)...)}, + {ErrDecryption.Error(), append([]byte{0x30}, make([]byte, 97)...)}, + {ErrDecryption.Error(), make([]byte, 97)}, + } + for i, tt := range tests { + _, err := Decrypt(priv, tt.ciphertext) + if err.Error() != tt.name { + t.Fatalf("case %v, expected %v, got %v\n", i, tt.name, err.Error()) + } + } +} + +func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext []byte) { + r := bufio.NewReaderSize(rand.Reader, 1<<15) + priv, err := ecdsa.GenerateKey(curve, r) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(plaintext))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil) + } +} + +func BenchmarkEncryptNoMoreThan32_P256(b *testing.B) { + benchmarkEncrypt(b, elliptic.P256(), make([]byte, 31)) +} + +func BenchmarkEncryptNoMoreThan32_SM2(b *testing.B) { + benchmarkEncrypt(b, P256(), make([]byte, 31)) +} + +func BenchmarkEncrypt128_P256(b *testing.B) { + benchmarkEncrypt(b, elliptic.P256(), make([]byte, 128)) +} + +func BenchmarkEncrypt128_SM2(b *testing.B) { + benchmarkEncrypt(b, P256(), make([]byte, 128)) +} + +func BenchmarkEncrypt512_SM2(b *testing.B) { + benchmarkEncrypt(b, P256(), make([]byte, 512)) +} + +func BenchmarkEncrypt1K_SM2(b *testing.B) { + benchmarkEncrypt(b, P256(), make([]byte, 1024)) +} + +func BenchmarkEncrypt8K_SM2(b *testing.B) { + benchmarkEncrypt(b, P256(), make([]byte, 8*1024)) +}