sm2: code review and refactor

This commit is contained in:
Sun Yimin 2024-12-19 08:17:21 +08:00 committed by GitHub
parent a71e806a2d
commit 89317b8f0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1076 additions and 1001 deletions

View File

@ -1,21 +1,12 @@
// Package sm2 implements ShangMi(SM) sm2 digital signature, public key encryption and key exchange algorithms.
package sm2
// Further references:
// [NSA]: Suite B implementer's guide to FIPS 186-3
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.182.4503&rep=rep1&type=pdf
// [SECG]: SECG, SEC1
// http://www.secg.org/sec1-v2.pdf
// [GM/T]: SM2 GB/T 32918.2-2016, GB/T 32918.4-2016
//
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
_subtle "crypto/subtle"
"errors"
"fmt"
"hash"
"io"
"math/big"
"sync"
@ -24,91 +15,12 @@ import (
"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/randutil"
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/sm2/sm2ec"
"github.com/emmansun/gmsm/sm3"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
const (
uncompressed byte = 0x04
compressed02 byte = 0x02
compressed03 byte = compressed02 | 0x01
hybrid06 byte = 0x06
hybrid07 byte = hybrid06 | 0x01
)
// PrivateKey represents an ECDSA SM2 private key.
// It implemented both crypto.Decrypter and crypto.Signer interfaces.
type PrivateKey struct {
ecdsa.PrivateKey
// inverseOfKeyPlus1 is set under inverseOfKeyPlus1Once
inverseOfKeyPlus1 *bigmod.Nat
inverseOfKeyPlus1Once sync.Once
}
type pointMarshalMode byte
const (
//MarshalUncompressed uncompressed mashal mode
MarshalUncompressed pointMarshalMode = iota
//MarshalCompressed compressed mashal mode
MarshalCompressed
//MarshalHybrid hybrid mashal mode
MarshalHybrid
)
type ciphertextSplicingOrder byte
const (
C1C3C2 ciphertextSplicingOrder = iota
C1C2C3
)
type ciphertextEncoding byte
const (
ENCODING_PLAIN ciphertextEncoding = iota
ENCODING_ASN1
)
// EncrypterOpts encryption options
type EncrypterOpts struct {
ciphertextEncoding ciphertextEncoding
pointMarshalMode pointMarshalMode
ciphertextSplicingOrder ciphertextSplicingOrder
}
// DecrypterOpts decryption options
type DecrypterOpts struct {
ciphertextEncoding ciphertextEncoding
cipherTextSplicingOrder ciphertextSplicingOrder
}
// NewPlainEncrypterOpts creates a SM2 non-ASN1 encrypter options.
func NewPlainEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder ciphertextSplicingOrder) *EncrypterOpts {
return &EncrypterOpts{ENCODING_PLAIN, marhsalMode, splicingOrder}
}
// NewPlainDecrypterOpts creates a SM2 non-ASN1 decrypter options.
func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts {
return &DecrypterOpts{ENCODING_PLAIN, splicingOrder}
}
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
byteLen := (curve.Params().BitSize + 7) >> 3
result := make([]byte, byteLen)
value.FillBytes(result)
return result
}
var defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2}
var ASN1EncrypterOpts = &EncrypterOpts{ENCODING_ASN1, MarshalUncompressed, C1C3C2}
var ASN1DecrypterOpts = &DecrypterOpts{ENCODING_ASN1, C1C3C2}
// directSigning is a standard Hash value that signals that no pre-hashing
// should be performed.
var directSigning crypto.Hash = 0
@ -118,7 +30,7 @@ type Signer interface {
SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, error)
}
// SM2SignerOption implements crypto.SignerOpts interface.
// SM2SignerOption implements crypto.SignerOpts interface and is used for SM2-specific signing options.
// It is specific for SM2, used in private key's Sign method.
type SM2SignerOption struct {
uid []byte
@ -146,11 +58,28 @@ func (*SM2SignerOption) HashFunc() crypto.Hash {
return directSigning
}
var (
errInvalidPrivateKey = errors.New("sm2: invalid private key")
errInvalidPublicKey = errors.New("sm2: invalid public key")
)
// PrivateKey represents an ECDSA SM2 private key.
// It embeds ecdsa.PrivateKey and includes additional fields for SM2-specific operations.
// It implements both crypto.Decrypter and crypto.Signer interfaces.
type PrivateKey struct {
ecdsa.PrivateKey
// inverseOfKeyPlus1 stores the modular inverse of (private key + 1) modulo the curve order.
// It is computed lazily and cached using sync.Once to ensure it is only calculated once.
inverseOfKeyPlus1 *bigmod.Nat
inverseOfKeyPlus1Once sync.Once
}
// FromECPrivateKey convert an ecdsa private key to SM2 private key.
func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) {
if key.Curve != sm2ec.P256() {
return nil, errors.New("sm2: it's NOT a sm2 curve private key")
return nil, errors.New("sm2: not an SM2 curve private key")
}
// Copy the ECDSA private key fields to the SM2 private key
priv.PrivateKey = *key
return priv, nil
}
@ -160,13 +89,7 @@ func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool {
if !ok {
return false
}
return priv.PublicKey.Equal(&xx.PublicKey) && bigIntEqual(priv.D, xx.D)
}
// bigIntEqual reports whether a and b are equal leaking only their bit length
// through timing side-channels.
func bigIntEqual(a, b *big.Int) bool {
return _subtle.ConstantTimeCompare(a.Bytes(), b.Bytes()) == 1
return priv.PublicKey.Equal(&xx.PublicKey) && _subtle.ConstantTimeCompare(priv.D.Bytes(), xx.D.Bytes()) == 1
}
// Sign signs digest with priv, reading randomness from rand. Compliance with GB/T 32918.2-2016.
@ -186,124 +109,6 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
return priv.Sign(rand, msg, NewSM2SignerOption(true, uid))
}
// Decrypt decrypts ciphertext msg to plaintext.
// The opts argument should be appropriate for the primitive used.
// Compliance with GB/T 32918.4-2016 chapter 7.
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
var sm2Opts *DecrypterOpts
sm2Opts, _ = opts.(*DecrypterOpts)
return decrypt(priv, msg, sm2Opts)
}
const maxRetryLimit = 100
var (
errCiphertextTooShort = errors.New("sm2: ciphertext too short")
)
// EncryptASN1 sm2 encrypt and output ASN.1 result, compliance with GB/T 32918.4-2016.
//
// The random parameter is used as a source of entropy to ensure that
// encrypting the same message twice doesn't result in the same ciphertext.
// Most applications should use [crypto/rand.Reader] as random.
func EncryptASN1(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
return Encrypt(random, pub, msg, ASN1EncrypterOpts)
}
// Encrypt sm2 encrypt implementation, compliance with GB/T 32918.4-2016.
//
// The random parameter is used as a source of entropy to ensure that
// encrypting the same message twice doesn't result in the same ciphertext.
// Most applications should use [crypto/rand.Reader] as random.
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) {
//A3, requirement is to check if h*P is infinite point, h is 1
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
return nil, errors.New("sm2: public key point is the infinity")
}
if len(msg) == 0 {
return nil, nil
}
if opts == nil {
opts = defaultEncrypterOpts
}
switch pub.Curve.Params() {
case P256().Params():
return encryptSM2EC(p256(), pub, random, msg, opts)
default:
return encryptLegacy(random, pub, msg, opts)
}
}
func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byte, opts *EncrypterOpts) ([]byte, error) {
Q, err := c.pointFromAffine(pub.X, pub.Y)
if err != nil {
return nil, err
}
var retryCount int = 0
for {
k, C1, err := randomPoint(c, random, false)
if err != nil {
return nil, err
}
C2, err := Q.ScalarMult(Q, k.Bytes(c.N))
if err != nil {
return nil, err
}
C2Bytes := C2.Bytes()[1:]
c2 := sm3.Kdf(C2Bytes, len(msg))
if subtle.ConstantTimeAllZero(c2) == 1 {
retryCount++
if retryCount > maxRetryLimit {
return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount)
}
continue
}
//A6, C2 = M + t;
subtle.XORBytes(c2, msg, c2)
//A7, C3 = hash(x2||M||y2)
md := sm3.New()
md.Write(C2Bytes[:len(C2Bytes)/2])
md.Write(msg)
md.Write(C2Bytes[len(C2Bytes)/2:])
c3 := md.Sum(nil)
if opts.ciphertextEncoding == ENCODING_PLAIN {
return encodingCiphertext(opts, C1, c2, c3)
}
return encodingCiphertextASN1(C1, c2, c3)
}
}
func encodingCiphertext(opts *EncrypterOpts, C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) {
var c1 []byte
switch opts.pointMarshalMode {
case MarshalCompressed:
c1 = C1.BytesCompressed()
default:
c1 = C1.Bytes()
}
if opts.ciphertextSplicingOrder == C1C3C2 {
// c1 || c3 || c2
return append(append(c1, c3...), c2...), nil
}
// c1 || c2 || c3
return append(append(c1, c2...), c3...), nil
}
func encodingCiphertextASN1(C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) {
c1 := C1.Bytes()
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
addASN1IntBytes(b, c1[1:len(c1)/2+1])
addASN1IntBytes(b, c1[len(c1)/2+1:])
b.AddASN1OctetString(c3)
b.AddASN1OctetString(c2)
})
return b.Bytes()
}
// GenerateKey generates a new SM2 private key.
//
// Most applications should use [crypto/rand.Reader] as rand. Note that the
@ -358,23 +163,27 @@ func NewPrivateKey(key []byte) (*PrivateKey, error) {
return priv, nil
}
// NewPrivateKeyFromInt checks that key is valid and returns a SM2 PrivateKey.
// NewPrivateKeyFromInt creates a new SM2 private key from a given big integer.
// It returns an error if the provided key is nil.
func NewPrivateKeyFromInt(key *big.Int) (*PrivateKey, error) {
if key == nil {
return nil, errors.New("sm2: invalid private key size")
return nil, errors.New("sm2: private key is nil")
}
keyBytes := make([]byte, p256().N.Size())
return NewPrivateKey(key.FillBytes(keyBytes))
}
// NewPublicKey checks that key is valid and returns a PublicKey.
// NewPublicKey checks that the provided key is valid and returns an SM2 PublicKey.
//
// According GB/T 32918.1-2016, the private key must be in [1, n-2].
// The key parameter is a byte slice representing the public key in uncompressed format.
// According to GB/T 32918.1-2016, the public key must be in the correct format and on the curve.
func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) {
c := p256()
// Reject the point at infinity and compressed encodings.
// Points at infinity are invalid because they do not represent a valid point on the curve.
// Compressed encodings are not supported by this implementation, so they are also rejected.
if len(key) == 0 || key[0] != 4 {
return nil, errors.New("sm2: invalid public key")
return nil, errInvalidPublicKey
}
// SetBytes also checks that the point is on the curve.
p, err := c.newPoint().SetBytes(key)
@ -390,138 +199,6 @@ func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) {
return k, nil
}
// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}.
// Compliance with GB/T 32918.4-2016.
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
return decrypt(priv, ciphertext, nil)
}
// ErrDecryption represents a failure to decrypt a message.
// It is deliberately vague to avoid adaptive attacks.
var ErrDecryption = errors.New("sm2: decryption error")
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
return nil, errCiphertextTooShort
}
switch priv.Curve.Params() {
case P256().Params():
return decryptSM2EC(p256(), priv, ciphertext, opts)
default:
return decryptLegacy(priv, ciphertext, opts)
}
}
func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
C1, c2, c3, err := parseCiphertext(c, ciphertext, opts)
if err != nil {
return nil, ErrDecryption
}
d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
if err != nil {
return nil, ErrDecryption
}
C2, err := C1.ScalarMult(C1, d.Bytes(c.N))
if err != nil {
return nil, ErrDecryption
}
C2Bytes := C2.Bytes()[1:]
msgLen := len(c2)
msg := sm3.Kdf(C2Bytes, msgLen)
if subtle.ConstantTimeAllZero(c2) == 1 {
return nil, ErrDecryption
}
//B5, calculate msg = c2 ^ t
subtle.XORBytes(msg, c2, msg)
md := sm3.New()
md.Write(C2Bytes[:len(C2Bytes)/2])
md.Write(msg)
md.Write(C2Bytes[len(C2Bytes)/2:])
u := md.Sum(nil)
if _subtle.ConstantTimeCompare(u, c3) == 1 {
return msg, nil
}
return nil, ErrDecryption
}
func parseCiphertext(c *sm2Curve, ciphertext []byte, opts *DecrypterOpts) (*_sm2ec.SM2P256Point, []byte, []byte, error) {
bitSize := c.curve.Params().BitSize
// Encode the coordinates and let SetBytes reject invalid points.
byteLen := (bitSize + 7) / 8
splicingOrder := C1C3C2
if opts != nil {
splicingOrder = opts.cipherTextSplicingOrder
}
b := ciphertext[0]
switch b {
case uncompressed:
if len(ciphertext) <= 1+2*byteLen+sm3.Size {
return nil, nil, nil, errCiphertextTooShort
}
C1, err := c.newPoint().SetBytes(ciphertext[:1+2*byteLen])
if err != nil {
return nil, nil, nil, err
}
c2, c3 := parseCiphertextC2C3(ciphertext[1+2*byteLen:], splicingOrder)
return C1, c2, c3, nil
case compressed02, compressed03:
C1, err := c.newPoint().SetBytes(ciphertext[:1+byteLen])
if err != nil {
return nil, nil, nil, err
}
c2, c3 := parseCiphertextC2C3(ciphertext[1+byteLen:], splicingOrder)
return C1, c2, c3, nil
case byte(0x30):
return parseCiphertextASN1(c, ciphertext)
default:
return nil, nil, nil, errors.New("sm2: invalid/unsupport ciphertext format")
}
}
func parseCiphertextC2C3(ciphertext []byte, order ciphertextSplicingOrder) ([]byte, []byte) {
if order == C1C3C2 {
return ciphertext[sm3.Size:], ciphertext[:sm3.Size]
}
return ciphertext[:len(ciphertext)-sm3.Size], ciphertext[len(ciphertext)-sm3.Size:]
}
func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []byte, error) {
var (
x1, y1 = &big.Int{}, &big.Int{}
c2, c3 []byte
inner cryptobyte.String
)
input := cryptobyte.String(ciphertext)
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
!input.Empty() ||
!inner.ReadASN1Integer(x1) ||
!inner.ReadASN1Integer(y1) ||
!inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) ||
!inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) ||
!inner.Empty() {
return nil, nil, nil, nil, errors.New("sm2: invalid asn1 format ciphertext")
}
return x1, y1, c2, c3, nil
}
func parseCiphertextASN1(c *sm2Curve, ciphertext []byte) (*_sm2ec.SM2P256Point, []byte, []byte, error) {
x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
if err != nil {
return nil, nil, nil, err
}
C1, err := c.pointFromAffine(x1, y1)
if err != nil {
return nil, nil, nil, err
}
return C1, c2, c3, nil
}
var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
// CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA).
@ -530,29 +207,53 @@ var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x
// This function will not use default UID even the uid argument is empty.
func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
uidLen := len(uid)
if uidLen >= 0x2000 {
if uidLen > 0x1fff {
return nil, errors.New("sm2: the uid is too long")
}
entla := uint16(uidLen) << 3
uidBitLength := uint16(uidLen) << 3
md := sm3.New()
md.Write([]byte{byte(entla >> 8), byte(entla)})
md.Write([]byte{byte(uidBitLength >> 8), byte(uidBitLength)})
if uidLen > 0 {
md.Write(uid)
}
a := new(big.Int).Sub(pub.Params().P, big.NewInt(3))
md.Write(toBytes(pub.Curve, a))
md.Write(toBytes(pub.Curve, pub.Params().B))
md.Write(toBytes(pub.Curve, pub.Params().Gx))
md.Write(toBytes(pub.Curve, pub.Params().Gy))
md.Write(toBytes(pub.Curve, pub.X))
md.Write(toBytes(pub.Curve, pub.Y))
writeCurveParams(md, pub.Curve)
md.Write(bigIntToBytes(pub.Curve, pub.X))
md.Write(bigIntToBytes(pub.Curve, pub.Y))
// Return the calculated ZA value
return md.Sum(nil), nil
}
// CalculateSM2Hash calculates hash value for data including uid and public key parameters
// according standards.
// writeCurveParams writes the parameters of the given elliptic curve to the provided hash.Hash.
// It writes the following parameters in order:
// - a: P - 3 (where P is the prime specifying the base field of the curve)
// - B: the coefficient B of the curve equation
// - Gx: the x-coordinate of the base point G
// - Gy: the y-coordinate of the base point G
//
// uid can be nil, then it will use default uid (1234567812345678)
// Parameters:
// - md: the hash.Hash to write the curve parameters to
// - curve: the elliptic.Curve whose parameters are to be written
func writeCurveParams(md hash.Hash, curve elliptic.Curve) {
a := new(big.Int).Sub(curve.Params().P, big.NewInt(3))
md.Write(bigIntToBytes(curve, a))
md.Write(bigIntToBytes(curve, curve.Params().B))
md.Write(bigIntToBytes(curve, curve.Params().Gx))
md.Write(bigIntToBytes(curve, curve.Params().Gy))
}
// bigIntToBytes converts a big integer value to a byte slice of the appropriate length for the given elliptic curve.
// The byte slice is zero-padded to the left if necessary to match the curve's byte length.
func bigIntToBytes(curve elliptic.Curve, value *big.Int) []byte {
byteLen := (curve.Params().BitSize + 7) >> 3
byteArray := make([]byte, byteLen)
value.FillBytes(byteArray)
return byteArray
}
// CalculateSM2Hash calculates the SM2 hash for the given public key, data, and user ID (UID).
// If the UID is not provided, a default UID (1234567812345678) is used.
// The public key must be valid, otherwise will be panic.
// This function is used to calculate the hash value for SM2 signature.
func CalculateSM2Hash(pub *ecdsa.PublicKey, data, uid []byte) ([]byte, error) {
if len(uid) == 0 {
uid = defaultUID
@ -597,21 +298,24 @@ func SignASN1(rand io.Reader, priv *PrivateKey, hash []byte, opts crypto.SignerO
}
}
// inverseOfPrivateKeyPlus1 calculates and returns the modular inverse of (private key + 1) modulo the curve order.
// It uses lazy initialization and caching to ensure the calculation is performed only once.
// If the private key is invalid, it returns an error.
func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, error) {
var (
err error
dp1Inv, oneNat *bigmod.Nat
dp1Bytes []byte
err error
oneNat = bigmod.NewNat().SetUint(1, c.N)
inverseDPlus1 *bigmod.Nat
dp1Bytes []byte
)
priv.inverseOfKeyPlus1Once.Do(func() {
oneNat = bigmod.NewNat().SetUint(1, c.N)
dp1Inv, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
inverseDPlus1, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
if err == nil {
dp1Inv.Add(oneNat, c.N)
if dp1Inv.IsZero() == 1 { // make sure private key is NOT N-1
inverseDPlus1.Add(oneNat, c.N)
if inverseDPlus1.IsZero() == 1 { // make sure private key is NOT N-1
err = errInvalidPrivateKey
} else {
dp1Bytes, err = _sm2ec.P256OrdInverse(dp1Inv.Bytes(c.N))
dp1Bytes, err = _sm2ec.P256OrdInverse(inverseDPlus1.Bytes(c.N))
if err == nil {
priv.inverseOfKeyPlus1, err = bigmod.NewNat().SetBytes(dp1Bytes, c.N)
}
@ -624,9 +328,27 @@ func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, erro
return priv.inverseOfKeyPlus1, nil
}
// signSM2EC generates an SM2 digital signature using the provided private key and hash.
// It follows the SM2 signature algorithm as specified in the Chinese cryptographic standards.
//
// Parameters:
// - c: A pointer to the sm2Curve structure representing the elliptic curve parameters.
// - priv: A pointer to the PrivateKey structure containing the private key for signing.
// - rand: An io.Reader instance used to generate random values.
// - hash: A byte slice containing the hash of the message to be signed.
//
// Returns:
// - sig: A byte slice containing the generated signature.
// - err: An error value indicating any issues encountered during the signing process.
//
// The function performs the following steps:
// 1. Computes the inverse of (d + 1) where d is the private key.
// 2. Converts the hash to an integer.
// 3. Generates a random point on the elliptic curve and computes the signature components (r, s).
// 4. Ensures that the signature components are non-zero and valid.
// 5. Encodes the signature components into a byte slice and returns it.
func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) {
// dp1Inv = (d+1)⁻¹
dp1Inv, err := priv.inverseOfPrivateKeyPlus1(c)
inverseDPlus1, err := priv.inverseOfPrivateKeyPlus1(c)
if err != nil {
return nil, err
}
@ -675,7 +397,7 @@ func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig
// k = [k - s]
k.Sub(s, c.N)
// k = [(d+1)⁻¹ * (k - r * d)]
k.Mul(dp1Inv, c.N)
k.Mul(inverseDPlus1, c.N)
if k.IsZero() == 0 {
break
}
@ -713,91 +435,6 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
var ErrInvalidSignature = errors.New("sm2: invalid signature")
// RecoverPublicKeysFromSM2Signature recovers two or four SM2 public keys from a given signature and hash.
// It takes the hash and signature as input and returns the recovered public keys as []*ecdsa.PublicKey.
// If the signature or hash is invalid, it returns an error.
// The function follows the SM2 algorithm to recover the public keys.
func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, error) {
c := p256()
rBytes, sBytes, err := parseSignature(sig)
if err != nil {
return nil, err
}
r, err := bigmod.NewNat().SetBytes(rBytes, c.N)
if err != nil || r.IsZero() == 1 {
return nil, ErrInvalidSignature
}
s, err := bigmod.NewNat().SetBytes(sBytes, c.N)
if err != nil || s.IsZero() == 1 {
return nil, ErrInvalidSignature
}
e := bigmod.NewNat()
hashToNat(c, e, hash)
// p₁ = [-s]G
negS := bigmod.NewNat().ExpandFor(c.N).Sub(s, c.N)
p1, err := c.newPoint().ScalarBaseMult(negS.Bytes(c.N))
if err != nil {
return nil, err
}
// s = [r + s]
s.Add(r, c.N)
if s.IsZero() == 1 {
return nil, ErrInvalidSignature
}
// sBytes = (r+s)⁻¹
sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N))
if err != nil {
return nil, err
}
// r = (Rx + e) mod N
// Rx = r - e
r.Sub(e, c.N)
if r.IsZero() == 1 {
return nil, ErrInvalidSignature
}
pointRx := make([]*bigmod.Nat, 0, 2)
pointRx = append(pointRx, r)
// check if Rx in (N, P), small probability event
s.Set(r)
s = s.Add(c.N.Nat(), c.P)
if s.CmpGeq(c.N.Nat()) == 1 {
pointRx = append(pointRx, s)
}
pubs := make([]*ecdsa.PublicKey, 0, 4)
bytes := make([]byte, 32+1)
compressFlags := []byte{compressed02, compressed03}
// Rx has one or two possible values, so point R has two or four possible values
for _, x := range pointRx {
rBytes = x.Bytes(c.N)
copy(bytes[1:], rBytes)
for _, flag := range compressFlags {
bytes[0] = flag
// p0 = R
p0, err := c.newPoint().SetBytes(bytes)
if err != nil {
return nil, err
}
// p0 = R - [s]G
p0.Add(p0, p1)
// Pub = [(r + s)⁻¹](R - [s]G)
p0.ScalarMult(p0, sBytes)
pub := new(ecdsa.PublicKey)
pub.Curve = c.curve
pub.X, pub.Y, err = c.pointToAffine(p0)
if err != nil {
return nil, err
}
pubs = append(pubs, pub)
}
}
return pubs, nil
}
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the
// public key, pub. Its return value records whether the signature is valid.
//
@ -919,7 +556,9 @@ func hashToNat(c *sm2Curve, e *bigmod.Nat, hash []byte) {
}
}
// IsSM2PublicKey check if given public key is a SM2 public key or not
// IsSM2PublicKey checks if the provided public key is an SM2 public key.
// It takes an interface{} as input and attempts to assert it to an *ecdsa.PublicKey.
// The function returns true if the assertion is successful and the public key's curve is SM2 P-256.
func IsSM2PublicKey(publicKey any) bool {
pub, ok := publicKey.(*ecdsa.PublicKey)
return ok && pub.Curve == sm2ec.P256()
@ -939,7 +578,7 @@ func PublicKeyToECDH(k *ecdsa.PublicKey) (*ecdh.PublicKey, error) {
return nil, errors.New("sm2: unsupported curve by ecdh")
}
if !k.Curve.IsOnCurve(k.X, k.Y) {
return nil, errors.New("sm2: invalid public key")
return nil, errInvalidPublicKey
}
return c.NewPublicKey(elliptic.Marshal(k.Curve, k.X, k.Y))
}
@ -954,7 +593,7 @@ func (k *PrivateKey) ECDH() (*ecdh.PrivateKey, error) {
}
size := (k.Curve.Params().N.BitLen() + 7) / 8
if k.D.BitLen() > size*8 {
return nil, errors.New("sm2: invalid private key")
return nil, errInvalidPrivateKey
}
return c.NewPrivateKey(k.D.FillBytes(make([]byte, size)))
}
@ -1011,6 +650,110 @@ func randomPoint(c *sm2Curve, rand io.Reader, checkOrderMinus1 bool) (k *bigmod.
// randomPoint rejects a candidate for being higher than the modulus.
var testingOnlyRejectionSamplingLooped func()
// RecoverPublicKeysFromSM2Signature attempts to recover the public keys from an SM2 signature.
// This function takes a hash and a signature as input and returns a slice of possible public keys
// that could have generated the given signature.
//
// Parameters:
// - hash: The hash of the message that was signed.
// - sig: The SM2 signature.
//
// Returns:
// - A slice of pointers to ecdsa.PublicKey, representing the possible public keys.
// - An error if the signature is invalid or if any other error occurs during the recovery process.
//
// The function performs the following steps:
// 1. Parses the signature to extract the r and s values.
// 2. Converts the hash to a big integer (Nat).
// 3. Computes the point p₁ = [-s]G.
// 4. Computes s = [r + s] and its modular inverse.
// 5. Computes the possible x-coordinates (Rx) for the point R.
// 6. For each possible Rx, computes the corresponding point R and derives the public key.
//
// Note: The function handles the case where there are one or two possible values for Rx,
// resulting in two or four possible public keys.
func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, error) {
c := p256()
rBytes, sBytes, err := parseSignature(sig)
if err != nil {
return nil, err
}
r, err := bigmod.NewNat().SetBytes(rBytes, c.N)
if err != nil || r.IsZero() == 1 {
return nil, ErrInvalidSignature
}
s, err := bigmod.NewNat().SetBytes(sBytes, c.N)
if err != nil || s.IsZero() == 1 {
return nil, ErrInvalidSignature
}
e := bigmod.NewNat()
hashToNat(c, e, hash)
// p₁ = [-s]G
negS := bigmod.NewNat().ExpandFor(c.N).Sub(s, c.N)
p1, err := c.newPoint().ScalarBaseMult(negS.Bytes(c.N))
if err != nil {
return nil, err
}
// s = [r + s]
s.Add(r, c.N)
if s.IsZero() == 1 {
return nil, ErrInvalidSignature
}
// sBytes = (r+s)⁻¹
sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N))
if err != nil {
return nil, err
}
// r = (Rx + e) mod N
// Rx = r - e
r.Sub(e, c.N)
if r.IsZero() == 1 {
return nil, ErrInvalidSignature
}
pointRx := make([]*bigmod.Nat, 0, 2)
pointRx = append(pointRx, r)
// check if Rx in (N, P), small probability event
s.Set(r)
s = s.Add(c.N.Nat(), c.P)
if s.CmpGeq(c.N.Nat()) == 1 {
pointRx = append(pointRx, s)
}
pubs := make([]*ecdsa.PublicKey, 0, 4)
bytes := make([]byte, 32+1)
compressFlags := []byte{compressed02, compressed03}
// Rx has one or two possible values, so point R has two or four possible values
for _, x := range pointRx {
rBytes = x.Bytes(c.N)
copy(bytes[1:], rBytes)
for _, flag := range compressFlags {
bytes[0] = flag
// p0 = R
p0, err := c.newPoint().SetBytes(bytes)
if err != nil {
return nil, err
}
// p0 = R - [s]G
p0.Add(p0, p1)
// Pub = [(r + s)⁻¹](R - [s]G)
p0.ScalarMult(p0, sBytes)
pub := new(ecdsa.PublicKey)
pub.Curve = c.curve
pub.X, pub.Y, err = c.pointToAffine(p0)
if err != nil {
return nil, err
}
pubs = append(pubs, pub)
}
}
return pubs, nil
}
type sm2Curve struct {
newPoint func() *_sm2ec.SM2P256Point
curve elliptic.Curve
@ -1073,5 +816,3 @@ func precomputeParams(c *sm2Curve, curve elliptic.Curve) {
c.nMinus1 = c.N.Nat().SubOne(c.N)
c.nMinus2 = new(bigmod.Nat).Set(c.nMinus1).SubOne(c.N).Bytes(c.N)
}
var errInvalidPrivateKey = errors.New("sm2: invalid private key")

View File

@ -10,7 +10,6 @@ import (
"encoding/hex"
"io"
"math/big"
"reflect"
"testing"
"github.com/emmansun/gmsm/sm3"
@ -109,367 +108,6 @@ func TestNewPublicKey(t *testing.T) {
}
}
func TestSplicingOrder(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
from ciphertextSplicingOrder
to ciphertextSplicingOrder
}{
// TODO: Add test cases.
{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewPlainEncrypterOpts(MarshalUncompressed, tt.from))
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from))
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
//Adjust splicing order
ciphertext, err = AdjustCiphertextSplicingOrder(ciphertext, tt.from, tt.to)
if err != nil {
t.Fatalf("adjust splicing order failed %v", err)
}
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to))
if err != nil {
t.Fatalf("decrypt failed after adjust splicing order %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestEncryptDecryptASN1(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
key2 := new(PrivateKey)
key2.PrivateKey = *priv2
tests := []struct {
name string
plainText string
priv *PrivateKey
}{
// TODO: Add test cases.
{"less than 32", "encryption standard", priv},
{"equals 32", "encryption standard encryption ", priv},
{"long than 32", "encryption standard encryption standard", priv},
{"less than 32", "encryption standard", key2},
{"equals 32", "encryption standard encryption ", key2},
{"long than 32", "encryption standard encryption standard", key2},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
encrypterOpts := ASN1EncrypterOpts
ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil {
t.Fatalf("%v encrypt failed %v", tt.priv.Curve.Params().Name, err)
}
plaintext, err := tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
if err != nil {
t.Fatalf("%v decrypt 1 failed %v", tt.priv.Curve.Params().Name, err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
plaintext, err = tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
if err != nil {
t.Fatalf("%v decrypt 2 failed %v", tt.priv.Curve.Params().Name, err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestPlainCiphertext2ASN1(t *testing.T) {
ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
_, err := PlainCiphertext2ASN1(append([]byte{0x30}, ciphertext...), C1C3C2)
if err == nil {
t.Fatalf("expected error")
}
_, err = PlainCiphertext2ASN1(ciphertext[:65], C1C3C2)
if err == nil {
t.Fatalf("expected error")
}
ciphertext[0] = 0x10
_, err = PlainCiphertext2ASN1(ciphertext, C1C3C2)
if err == nil {
t.Fatalf("expected error")
}
}
func TestAdjustCiphertextSplicingOrder(t *testing.T) {
ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
res, err := AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C3C2)
if err != nil || &res[0] != &ciphertext[0] {
t.Fatalf("should be same one")
}
_, err = AdjustCiphertextSplicingOrder(ciphertext[:65], C1C3C2, C1C2C3)
if err == nil {
t.Fatalf("expected error")
}
ciphertext[0] = 0x10
_, err = AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C2C3)
if err == nil {
t.Fatalf("expected error")
}
}
func TestCiphertext2ASN1(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
}{
// TODO: Add test cases.
{"less than 32", "encryption standard"},
{"equals 32", "encryption standard encryption "},
{"long than 32", "encryption standard encryption standard"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext1, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
ciphertext, err := PlainCiphertext2ASN1(ciphertext1, C1C3C2)
if err != nil {
t.Fatalf("convert to ASN.1 failed %v", err)
}
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
ciphertext2, err := AdjustCiphertextSplicingOrder(ciphertext1, C1C3C2, C1C2C3)
if err != nil {
t.Fatalf("adjust order failed %v", err)
}
ciphertext, err = PlainCiphertext2ASN1(ciphertext2, C1C2C3)
if err != nil {
t.Fatalf("convert to ASN.1 failed %v", err)
}
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestCiphertextASN12Plain(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
}{
// TODO: Add test cases.
{"less than 32", "encryption standard"},
{"equals 32", "encryption standard encryption "},
{"long than 32", "encryption standard encryption standard"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := EncryptASN1(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil)
if err != nil {
t.Fatalf("convert to plain failed %v", err)
}
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestEncryptWithInfinitePublicKey(t *testing.T) {
pub := new(ecdsa.PublicKey)
pub.Curve = P256()
pub.X = big.NewInt(0)
pub.Y = big.NewInt(0)
_, err := Encrypt(rand.Reader, pub, []byte("sm2 encryption standard"), nil)
if err == nil {
t.Fatalf("should be failed")
}
}
func TestEncryptEmptyPlaintext(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, nil, nil)
if err != nil || ciphertext != nil {
t.Fatalf("nil plaintext should return nil")
}
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte{}, nil)
if err != nil || ciphertext != nil {
t.Fatalf("empty plaintext should return nil")
}
}
func TestEncryptDecrypt(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
key2 := new(PrivateKey)
key2.PrivateKey = *priv2
tests := []struct {
name string
plainText string
priv *PrivateKey
}{
// TODO: Add test cases.
{"less than 32", "encryption standard", priv},
{"equals 32", "encryption standard encryption ", priv},
{"long than 32", "encryption standard encryption standard", priv},
{"less than 32", "encryption standard", key2},
{"equals 32", "encryption standard encryption ", key2},
{"long than 32", "encryption standard encryption standard", key2},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), nil)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err := Decrypt(tt.priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
// compress mode
encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err = Decrypt(tt.priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
// hybrid mode
encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err = Decrypt(tt.priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
plaintext, err = Decrypt(tt.priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestInvalidCiphertext(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
ciphertext []byte
}{
// TODO: Add test cases.
{errCiphertextTooShort.Error(), make([]byte, 65)},
{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 96)...)},
{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 97)...)},
{ErrDecryption.Error(), append([]byte{0x02}, make([]byte, 65)...)},
{ErrDecryption.Error(), append([]byte{0x30}, make([]byte, 97)...)},
{ErrDecryption.Error(), make([]byte, 97)},
}
for i, tt := range tests {
_, err := Decrypt(priv, tt.ciphertext)
if err.Error() != tt.name {
t.Fatalf("case %v, expected %v, got %v\n", i, tt.name, err.Error())
}
}
}
func TestPrivateKeyPlus1WithOrderMinus1(t *testing.T) {
priv := new(PrivateKey)
priv.D = new(big.Int).Sub(P256().Params().N, big.NewInt(1))
priv.Curve = P256()
priv.PublicKey.X, priv.PublicKey.Y = P256().ScalarBaseMult(priv.D.Bytes())
_, err := priv.inverseOfPrivateKeyPlus1(p256())
if err == nil || err != errInvalidPrivateKey {
t.Errorf("expected invalid private key error")
}
}
func TestSignVerify(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
}{
// TODO: Add test cases.
{"less than 32", "encryption standard"},
{"equals 32", "encryption standard encryption "},
{"long than 32", "encryption standard encryption standard"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hashed := sm3.Sum([]byte(tt.plainText))
signature, err := priv.Sign(rand.Reader, hashed[:], nil)
if err != nil {
t.Fatalf("sign failed %v", err)
}
result := VerifyASN1(&priv.PublicKey, hashed[:], signature)
if !result {
t.Fatal("verify failed")
}
hashed[0] ^= 0xff
if VerifyASN1(&priv.PublicKey, hashed[:], signature) {
t.Errorf("VerifyASN1 always works!")
}
})
}
}
func testRecoverPublicKeysFromSM2Signature(t *testing.T, priv *PrivateKey) {
tests := []struct {
name string
@ -774,6 +412,48 @@ func TestRandomPoint(t *testing.T) {
}
}
func TestPrivateKeyPlus1WithOrderMinus1(t *testing.T) {
priv := new(PrivateKey)
priv.D = new(big.Int).Sub(P256().Params().N, big.NewInt(1))
priv.Curve = P256()
priv.PublicKey.X, priv.PublicKey.Y = P256().ScalarBaseMult(priv.D.Bytes())
_, err := priv.inverseOfPrivateKeyPlus1(p256())
if err == nil || err != errInvalidPrivateKey {
t.Errorf("expected invalid private key error")
}
}
func TestSignVerify(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
}{
// TODO: Add test cases.
{"less than 32", "encryption standard"},
{"equals 32", "encryption standard encryption "},
{"long than 32", "encryption standard encryption standard"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hashed := sm3.Sum([]byte(tt.plainText))
signature, err := priv.Sign(rand.Reader, hashed[:], nil)
if err != nil {
t.Fatalf("sign failed %v", err)
}
result := VerifyASN1(&priv.PublicKey, hashed[:], signature)
if !result {
t.Fatal("verify failed")
}
hashed[0] ^= 0xff
if VerifyASN1(&priv.PublicKey, hashed[:], signature) {
t.Errorf("VerifyASN1 always works!")
}
})
}
}
func BenchmarkGenerateKey_SM2(b *testing.B) {
r := bufio.NewReaderSize(rand.Reader, 1<<15)
b.ReportAllocs()
@ -894,45 +574,3 @@ func BenchmarkVerify_SM2(b *testing.B) {
}
}
}
func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext []byte) {
r := bufio.NewReaderSize(rand.Reader, 1<<15)
priv, err := ecdsa.GenerateKey(curve, r)
if err != nil {
b.Fatal(err)
}
b.SetBytes(int64(len(plaintext)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil)
}
}
func BenchmarkEncryptNoMoreThan32_P256(b *testing.B) {
benchmarkEncrypt(b, elliptic.P256(), make([]byte, 31))
}
func BenchmarkEncryptNoMoreThan32_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 31))
}
func BenchmarkEncrypt128_P256(b *testing.B) {
benchmarkEncrypt(b, elliptic.P256(), make([]byte, 128))
}
func BenchmarkEncrypt128_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 128))
}
func BenchmarkEncrypt512_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 512))
}
func BenchmarkEncrypt1K_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 1024))
}
func BenchmarkEncrypt8K_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 8*1024))
}

View File

@ -79,8 +79,26 @@ func MarshalEnvelopedPrivateKey(rand io.Reader, pub *ecdsa.PublicKey, tobeEnvelo
return b.Bytes()
}
// ParseEnvelopedPrivateKey, parses and decrypts the enveloped SM2 private key.
// This methed just supports SM4 cipher now.
// ParseEnvelopedPrivateKey parses an enveloped private key using the provided private key.
// The enveloped key is expected to be in ASN.1 format and encrypted with a symmetric cipher.
//
// Parameters:
// - priv: The private key used to decrypt the symmetric key.
// - enveloped: The ASN.1 encoded and encrypted enveloped private key.
//
// Returns:
// - A pointer to the decrypted PrivateKey.
// - An error if the parsing or decryption fails.
//
// The function performs the following steps:
// 1. Unmarshals the ASN.1 data to extract the symmetric algorithm identifier, encrypted symmetric key, public key, and encrypted private key.
// 2. Verifies that the symmetric algorithm is supported (SM4 or SM4ECB).
// 3. Parses the public key from the ASN.1 data.
// 4. Decrypts the symmetric key using the provided private key.
// 5. Decrypts the SM2 private key using the decrypted symmetric key.
// 6. Verifies that the decrypted private key matches the public key.
//
// Errors are returned if any of the steps fail, including invalid ASN.1 format, unsupported symmetric cipher, decryption failures, or key mismatches.
func ParseEnvelopedPrivateKey(priv *PrivateKey, enveloped []byte) (*PrivateKey, error) {
// unmarshal the asn.1 data
var (

View File

@ -149,34 +149,34 @@ func (ke *KeyExchange) InitKeyExchange(rand io.Reader) (*ecdsa.PublicKey, error)
func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte {
var buffer []byte
hash := sm3.New()
hash.Write(toBytes(ke.privateKey, ke.v.X))
hash.Write(bigIntToBytes(ke.privateKey, ke.v.X))
if isResponder {
hash.Write(ke.peerZ)
hash.Write(ke.z)
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
hash.Write(toBytes(ke.privateKey, ke.secret.X))
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.Y))
hash.Write(bigIntToBytes(ke.privateKey, ke.secret.X))
hash.Write(bigIntToBytes(ke.privateKey, ke.secret.Y))
} else {
hash.Write(ke.z)
hash.Write(ke.peerZ)
hash.Write(toBytes(ke.privateKey, ke.secret.X))
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
hash.Write(bigIntToBytes(ke.privateKey, ke.secret.X))
hash.Write(bigIntToBytes(ke.privateKey, ke.secret.Y))
hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.Y))
}
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{prefix})
hash.Write(toBytes(ke.privateKey, ke.v.Y))
hash.Write(bigIntToBytes(ke.privateKey, ke.v.Y))
hash.Write(buffer)
return hash.Sum(nil)
}
func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
var buffer []byte
buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...)
buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...)
buffer = append(buffer, bigIntToBytes(ke.privateKey, ke.v.X)...)
buffer = append(buffer, bigIntToBytes(ke.privateKey, ke.v.Y)...)
if isResponder {
buffer = append(buffer, ke.peerZ...)
buffer = append(buffer, ke.z...)

View File

@ -301,12 +301,9 @@ func calculateSampleZA(pub *ecdsa.PublicKey, a *big.Int, uid []byte) ([]byte, er
if uidLen > 0 {
md.Write(uid)
}
md.Write(toBytes(pub.Curve, a))
md.Write(toBytes(pub.Curve, pub.Params().B))
md.Write(toBytes(pub.Curve, pub.Params().Gx))
md.Write(toBytes(pub.Curve, pub.Params().Gy))
md.Write(toBytes(pub.Curve, pub.X))
md.Write(toBytes(pub.Curve, pub.Y))
writeCurveParams(md, pub.Curve.Params())
md.Write(bigIntToBytes(pub.Curve, pub.X))
md.Write(bigIntToBytes(pub.Curve, pub.Y))
return md.Sum(nil), nil
}

View File

@ -259,7 +259,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
//A5, calculate t=KDF(x2||y2, klen)
c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
c2 := sm3.Kdf(append(bigIntToBytes(curve, x2), bigIntToBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) == 1 {
retryCount++
if retryCount > maxRetryLimit {
@ -289,9 +289,9 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
md := sm3.New()
md.Write(toBytes(curve, x2))
md.Write(bigIntToBytes(curve, x2))
md.Write(msg)
md.Write(toBytes(curve, y2))
md.Write(bigIntToBytes(curve, y2))
return md.Sum(nil)
}
@ -306,95 +306,6 @@ func mashalASN1Ciphertext(x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) {
return b.Bytes()
}
// ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format
func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) {
if opts == nil {
opts = defaultEncrypterOpts
}
x1, y1, c2, c3, err := unmarshalASN1Ciphertext((ciphertext))
if err != nil {
return nil, err
}
curve := sm2ec.P256()
c1 := opts.pointMarshalMode.mashal(curve, x1, y1)
if opts.ciphertextSplicingOrder == C1C3C2 {
// c1 || c3 || c2
return append(append(c1, c3...), c2...), nil
}
// c1 || c2 || c3
return append(append(c1, c2...), c3...), nil
}
// PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format
func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) {
if ciphertext[0] == 0x30 {
return nil, errors.New("sm2: invalid plain encoding ciphertext")
}
curve := sm2ec.P256()
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size {
return nil, errCiphertextTooShort
}
// get C1, and check C1
x1, y1, c3Start, err := bytes2Point(curve, ciphertext)
if err != nil {
return nil, err
}
var c2, c3 []byte
if from == C1C3C2 {
c2 = ciphertext[c3Start+sm3.Size:]
c3 = ciphertext[c3Start : c3Start+sm3.Size]
} else {
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
c3 = ciphertext[ciphertextLen-sm3.Size:]
}
return mashalASN1Ciphertext(x1, y1, c2, c3)
}
// AdjustCiphertextSplicingOrder utility method to change c2 c3 order
func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) {
curve := sm2ec.P256()
if from == to {
return ciphertext, nil
}
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size {
return nil, errCiphertextTooShort
}
// get C1, and check C1
_, _, c3Start, err := bytes2Point(curve, ciphertext)
if err != nil {
return nil, err
}
var c1, c2, c3 []byte
c1 = ciphertext[:c3Start]
if from == C1C3C2 {
c2 = ciphertext[c3Start+sm3.Size:]
c3 = ciphertext[c3Start : c3Start+sm3.Size]
} else {
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
c3 = ciphertext[ciphertextLen-sm3.Size:]
}
result := make([]byte, ciphertextLen)
copy(result, c1)
if to == C1C3C2 {
// c1 || c3 || c2
copy(result[c3Start:], c3)
copy(result[c3Start+sm3.Size:], c2)
} else {
// c1 || c2 || c3
copy(result[c3Start:], c2)
copy(result[ciphertextLen-sm3.Size:], c3)
}
return result, nil
}
func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
if err != nil {
@ -407,7 +318,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
curve := priv.Curve
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
msgLen := len(c2)
msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
msg := sm3.Kdf(append(bigIntToBytes(curve, x2), bigIntToBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) == 1 {
return nil, ErrDecryption
}
@ -428,7 +339,7 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]
if opts.ciphertextEncoding == ENCODING_ASN1 {
return decryptASN1(priv, ciphertext)
}
splicingOrder = opts.cipherTextSplicingOrder
splicingOrder = opts.ciphertextSplicingOrder
}
if ciphertext[0] == 0x30 {
return decryptASN1(priv, ciphertext)
@ -436,7 +347,7 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]
ciphertextLen := len(ciphertext)
curve := priv.Curve
// B1, get C1, and check C1
x1, y1, c3Start, err := bytes2Point(curve, ciphertext)
x1, y1, c3Start, err := bytesToPoint(curve, ciphertext)
if err != nil {
return nil, ErrDecryption
}
@ -454,7 +365,20 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]
return rawDecrypt(priv, x1, y1, c2, c3)
}
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
switch mode {
case MarshalCompressed:
return elliptic.MarshalCompressed(curve, x, y)
case MarshalHybrid:
buffer := elliptic.Marshal(curve, x, y)
buffer[0] = byte(y.Bit(0)) | hybrid06
return buffer
default:
return elliptic.Marshal(curve, x, y)
}
}
func bytesToPoint(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
if len(bytes) < 1+(curve.Params().BitSize/8) {
return nil, nil, 0, fmt.Errorf("sm2: invalid bytes length %d", len(bytes))
}
@ -486,20 +410,7 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e
}
return x, y, 1 + byteLen, nil
}
return nil, nil, 0, fmt.Errorf("sm2: unsupport point form %d, curve %s", format, curve.Params().Name)
return nil, nil, 0, fmt.Errorf("sm2: unsupported point form %d, curve %s", format, curve.Params().Name)
}
return nil, nil, 0, fmt.Errorf("sm2: unknown point form %d", format)
}
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
switch mode {
case MarshalCompressed:
return elliptic.MarshalCompressed(curve, x, y)
case MarshalHybrid:
buffer := elliptic.Marshal(curve, x, y)
buffer[0] = byte(y.Bit(0)) | hybrid06
return buffer
default:
return elliptic.Marshal(curve, x, y)
}
}

396
sm2/sm2_pke.go Normal file
View File

@ -0,0 +1,396 @@
// Package sm2 implements ShangMi(SM) sm2 digital signature, public key encryption and key exchange algorithms.
package sm2
// Further references:
// [NSA]: Suite B implementer's guide to FIPS 186-3
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.182.4503&rep=rep1&type=pdf
// [SECG]: SECG, SEC1
// http://www.secg.org/sec1-v2.pdf
// [GM/T]: SM2 GB/T 32918.2-2016, GB/T 32918.4-2016
//
import (
"crypto"
"crypto/ecdsa"
_subtle "crypto/subtle"
"errors"
"fmt"
"io"
"math/big"
"github.com/emmansun/gmsm/internal/bigmod"
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/sm3"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
const (
uncompressed byte = 0x04
compressed02 byte = 0x02
compressed03 byte = compressed02 | 0x01
hybrid06 byte = 0x06
hybrid07 byte = hybrid06 | 0x01
)
type pointMarshalMode byte
const (
//MarshalUncompressed uncompressed marshal mode
MarshalUncompressed pointMarshalMode = iota
//MarshalCompressed compressed marshal mode
MarshalCompressed
//MarshalHybrid hybrid marshal mode
MarshalHybrid
)
type ciphertextSplicingOrder byte
const (
C1C3C2 ciphertextSplicingOrder = iota
C1C2C3
)
// splitC2C3 splits the given ciphertext into two parts, C2 and C3, based on the splicing order.
// If the order is C1C3C2, it returns the first sm3.Size bytes as C3 and the rest as C2.
// Otherwise, it returns the first part as C2 and the last sm3.Size bytes as C3.
func (order ciphertextSplicingOrder) splitC2C3(ciphertext []byte) ([]byte, []byte) {
if order == C1C3C2 {
return ciphertext[sm3.Size:], ciphertext[:sm3.Size]
}
return ciphertext[:len(ciphertext)-sm3.Size], ciphertext[len(ciphertext)-sm3.Size:]
}
// spliceCiphertext splices the given ciphertext components together based on the splicing order.
func (order ciphertextSplicingOrder) spliceCiphertext(c1, c2, c3 []byte) ([]byte, error) {
switch order {
case C1C3C2:
return append(append(c1, c3...), c2...), nil
case C1C2C3:
return append(append(c1, c2...), c3...), nil
default:
return nil, errors.New("sm2: invalid ciphertext splicing order")
}
}
type ciphertextEncoding byte
const (
ENCODING_PLAIN ciphertextEncoding = iota
ENCODING_ASN1
)
// EncrypterOpts represents the options for the SM2 encryption process.
// It includes settings for ciphertext encoding, point marshaling mode,
// and the order in which the ciphertext components are spliced together.
type EncrypterOpts struct {
ciphertextEncoding ciphertextEncoding
pointMarshalMode pointMarshalMode
ciphertextSplicingOrder ciphertextSplicingOrder
}
// DecrypterOpts represents the options for the decryption process.
// It includes settings for how the ciphertext is encoded and how the
// components of the ciphertext are spliced together.
//
// Fields:
// - ciphertextEncoding: Specifies the encoding format of the ciphertext.
// - ciphertextSplicingOrder: Defines the order in which the components
// of the ciphertext are spliced together.
type DecrypterOpts struct {
ciphertextEncoding ciphertextEncoding
ciphertextSplicingOrder ciphertextSplicingOrder
}
// NewPlainEncrypterOpts creates a SM2 non-ASN1 encrypter options.
func NewPlainEncrypterOpts(marshalMode pointMarshalMode, splicingOrder ciphertextSplicingOrder) *EncrypterOpts {
return &EncrypterOpts{ENCODING_PLAIN, marshalMode, splicingOrder}
}
// NewPlainDecrypterOpts creates a SM2 non-ASN1 decrypter options.
func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts {
return &DecrypterOpts{ENCODING_PLAIN, splicingOrder}
}
var (
defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2}
ASN1EncrypterOpts = &EncrypterOpts{ENCODING_ASN1, MarshalUncompressed, C1C3C2}
ASN1DecrypterOpts = &DecrypterOpts{ENCODING_ASN1, C1C3C2}
)
const maxRetryLimit = 100
var errCiphertextTooShort = errors.New("sm2: ciphertext too short")
// EncryptASN1 sm2 encrypt and output ASN.1 result, compliance with GB/T 32918.4-2016.
//
// The random parameter is used as a source of entropy to ensure that
// encrypting the same message twice doesn't result in the same ciphertext.
// Most applications should use [crypto/rand.Reader] as random.
func EncryptASN1(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
return Encrypt(random, pub, msg, ASN1EncrypterOpts)
}
// Encrypt sm2 encrypt implementation, compliance with GB/T 32918.4-2016.
//
// The random parameter is used as a source of entropy to ensure that
// encrypting the same message twice doesn't result in the same ciphertext.
// Most applications should use [crypto/rand.Reader] as random.
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) {
//A3, requirement is to check if h*P is infinite point, h is 1
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
return nil, errors.New("sm2: public key point is the infinity")
}
if len(msg) == 0 {
return nil, nil
}
if opts == nil {
opts = defaultEncrypterOpts
}
switch pub.Curve.Params() {
case P256().Params():
return encryptSM2EC(p256(), pub, random, msg, opts)
default:
return encryptLegacy(random, pub, msg, opts)
}
}
func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byte, opts *EncrypterOpts) ([]byte, error) {
Q, err := c.pointFromAffine(pub.X, pub.Y)
if err != nil {
return nil, err
}
retryCount := 0
for {
k, C1, err := randomPoint(c, random, false)
if err != nil {
return nil, err
}
C2, err := Q.ScalarMult(Q, k.Bytes(c.N))
if err != nil {
return nil, err
}
C2Bytes := C2.Bytes()[1:]
c2 := sm3.Kdf(C2Bytes, len(msg))
if subtle.ConstantTimeAllZero(c2) == 1 {
retryCount++
if retryCount > maxRetryLimit {
return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount)
}
continue
}
//A6, C2 = M + t;
subtle.XORBytes(c2, msg, c2)
//A7, C3 = hash(x2||M||y2)
md := sm3.New()
md.Write(C2Bytes[:len(C2Bytes)/2])
md.Write(msg)
md.Write(C2Bytes[len(C2Bytes)/2:])
c3 := md.Sum(nil)
if opts.ciphertextEncoding == ENCODING_PLAIN {
return encodeCiphertext(opts, C1, c2, c3)
}
return encodingCiphertextASN1(C1, c2, c3)
}
}
func encodeCiphertext(opts *EncrypterOpts, C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) {
var c1 []byte
switch opts.pointMarshalMode {
case MarshalCompressed:
c1 = C1.BytesCompressed()
default:
c1 = C1.Bytes()
}
return opts.ciphertextSplicingOrder.spliceCiphertext(c1, c2, c3)
}
func encodingCiphertextASN1(C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) {
c1 := C1.Bytes()
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
addASN1IntBytes(b, c1[1:len(c1)/2+1])
addASN1IntBytes(b, c1[len(c1)/2+1:])
b.AddASN1OctetString(c3)
b.AddASN1OctetString(c2)
})
return b.Bytes()
}
// Decrypt decrypts ciphertext msg to plaintext.
// The opts argument should be appropriate for the primitive used.
// Compliance with GB/T 32918.4-2016 chapter 7.
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
var sm2Opts *DecrypterOpts
sm2Opts, _ = opts.(*DecrypterOpts)
return decrypt(priv, msg, sm2Opts)
}
// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}.
// Compliance with GB/T 32918.4-2016.
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
return decrypt(priv, ciphertext, nil)
}
// ErrDecryption represents a failure to decrypt a message.
// It is deliberately vague to avoid adaptive attacks.
var ErrDecryption = errors.New("sm2: decryption error")
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
return nil, errCiphertextTooShort
}
switch priv.Curve.Params() {
case P256().Params():
return decryptSM2EC(p256(), priv, ciphertext, opts)
default:
return decryptLegacy(priv, ciphertext, opts)
}
}
func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
C1, c2, c3, err := parseCiphertext(c, ciphertext, opts)
if err != nil {
return nil, ErrDecryption
}
d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
if err != nil {
return nil, ErrDecryption
}
C2, err := C1.ScalarMult(C1, d.Bytes(c.N))
if err != nil {
return nil, ErrDecryption
}
C2Bytes := C2.Bytes()[1:]
msgLen := len(c2)
msg := sm3.Kdf(C2Bytes, msgLen)
if subtle.ConstantTimeAllZero(c2) == 1 {
return nil, ErrDecryption
}
//B5, calculate msg = c2 ^ t
subtle.XORBytes(msg, c2, msg)
md := sm3.New()
md.Write(C2Bytes[:len(C2Bytes)/2])
md.Write(msg)
md.Write(C2Bytes[len(C2Bytes)/2:])
u := md.Sum(nil)
if _subtle.ConstantTimeCompare(u, c3) == 1 {
return msg, nil
}
return nil, ErrDecryption
}
// parseCiphertext parses the given ciphertext according to the specified SM2 curve and decryption options.
// It returns the parsed SM2 point (C1), the decrypted message (C2), the message digest (C3), and an error if any.
func parseCiphertext(c *sm2Curve, ciphertext []byte, opts *DecrypterOpts) (*_sm2ec.SM2P256Point, []byte, []byte, error) {
bitSize := c.curve.Params().BitSize
byteLen := (bitSize + 7) / 8
splicingOrder := C1C3C2
if opts != nil {
splicingOrder = opts.ciphertextSplicingOrder
}
var ciphertextFormat byte = 0xff // invalid
if len(ciphertext) > 0 {
ciphertextFormat = ciphertext[0]
}
var c1Len int
switch ciphertextFormat {
case byte(asn1.SEQUENCE):
return parseCiphertextASN1(c, ciphertext)
case uncompressed:
c1Len = 1 + 2*byteLen
case compressed02, compressed03:
c1Len = 1 + byteLen
default:
return nil, nil, nil, errors.New("sm2: invalid/unsupported ciphertext format")
}
if len(ciphertext) < c1Len+sm3.Size {
return nil, nil, nil, errCiphertextTooShort
}
C1, err := c.newPoint().SetBytes(ciphertext[:c1Len])
if err != nil {
return nil, nil, nil, err
}
c2, c3 := splicingOrder.splitC2C3(ciphertext[c1Len:])
return C1, c2, c3, nil
}
func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []byte, error) {
var (
x1, y1 = &big.Int{}, &big.Int{}
c2, c3 []byte
inner cryptobyte.String
)
input := cryptobyte.String(ciphertext)
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
!input.Empty() ||
!inner.ReadASN1Integer(x1) ||
!inner.ReadASN1Integer(y1) ||
!inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) ||
!inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) ||
!inner.Empty() {
return nil, nil, nil, nil, errors.New("sm2: invalid asn1 format ciphertext")
}
return x1, y1, c2, c3, nil
}
func parseCiphertextASN1(c *sm2Curve, ciphertext []byte) (*_sm2ec.SM2P256Point, []byte, []byte, error) {
x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
if err != nil {
return nil, nil, nil, err
}
C1, err := c.pointFromAffine(x1, y1)
if err != nil {
return nil, nil, nil, err
}
return C1, c2, c3, nil
}
// AdjustCiphertextSplicingOrder utility method to change c2 c3 order
func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) {
curve := p256()
if from == to {
return ciphertext, nil
}
C1, c2, c3, err := parseCiphertext(curve, ciphertext, NewPlainDecrypterOpts(from))
if err != nil {
return nil, err
}
opts := NewPlainEncrypterOpts(MarshalUncompressed, to)
if ciphertext[0] == compressed02 || ciphertext[0] == compressed03 {
opts.pointMarshalMode = MarshalCompressed
}
return encodeCiphertext(opts, C1, c2, c3)
}
// ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format
func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) {
if opts == nil {
opts = defaultEncrypterOpts
}
C1, c2, c3, err := parseCiphertextASN1(p256(), ciphertext)
if err != nil {
return nil, err
}
return encodeCiphertext(opts, C1, c2, c3)
}
// PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format
func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) {
C1, c2, c3, err := parseCiphertext(p256(), ciphertext, NewPlainDecrypterOpts(from))
if err != nil {
return nil, err
}
return encodingCiphertextASN1(C1, c2, c3)
}

374
sm2/sm2_pke_test.go Normal file
View File

@ -0,0 +1,374 @@
package sm2
import (
"bufio"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/hex"
"math/big"
"reflect"
"testing"
)
func TestSplicingOrder(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
from ciphertextSplicingOrder
to ciphertextSplicingOrder
}{
// TODO: Add test cases.
{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewPlainEncrypterOpts(MarshalUncompressed, tt.from))
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from))
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
//Adjust splicing order
ciphertext, err = AdjustCiphertextSplicingOrder(ciphertext, tt.from, tt.to)
if err != nil {
t.Fatalf("adjust splicing order failed %v", err)
}
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to))
if err != nil {
t.Fatalf("decrypt failed after adjust splicing order %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestEncryptDecryptASN1(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
key2 := new(PrivateKey)
key2.PrivateKey = *priv2
tests := []struct {
name string
plainText string
priv *PrivateKey
}{
// TODO: Add test cases.
{"less than 32", "encryption standard", priv},
{"equals 32", "encryption standard encryption ", priv},
{"long than 32", "encryption standard encryption standard", priv},
{"less than 32", "encryption standard", key2},
{"equals 32", "encryption standard encryption ", key2},
{"long than 32", "encryption standard encryption standard", key2},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
encrypterOpts := ASN1EncrypterOpts
ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil {
t.Fatalf("%v encrypt failed %v", tt.priv.Curve.Params().Name, err)
}
plaintext, err := tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
if err != nil {
t.Fatalf("%v decrypt 1 failed %v", tt.priv.Curve.Params().Name, err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
plaintext, err = tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
if err != nil {
t.Fatalf("%v decrypt 2 failed %v", tt.priv.Curve.Params().Name, err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestPlainCiphertext2ASN1(t *testing.T) {
ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
_, err := PlainCiphertext2ASN1(append([]byte{0x30}, ciphertext...), C1C3C2)
if err == nil {
t.Fatalf("expected error")
}
_, err = PlainCiphertext2ASN1(ciphertext[:65], C1C3C2)
if err == nil {
t.Fatalf("expected error")
}
ciphertext[0] = 0x10
_, err = PlainCiphertext2ASN1(ciphertext, C1C3C2)
if err == nil {
t.Fatalf("expected error")
}
}
func TestAdjustCiphertextSplicingOrder(t *testing.T) {
ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
res, err := AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C3C2)
if err != nil || &res[0] != &ciphertext[0] {
t.Fatalf("should be same one")
}
_, err = AdjustCiphertextSplicingOrder(ciphertext[:65], C1C3C2, C1C2C3)
if err == nil {
t.Fatalf("expected error")
}
ciphertext[0] = 0x10
_, err = AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C2C3)
if err == nil {
t.Fatalf("expected error")
}
}
func TestCiphertext2ASN1(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
}{
// TODO: Add test cases.
{"less than 32", "encryption standard"},
{"equals 32", "encryption standard encryption "},
{"long than 32", "encryption standard encryption standard"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext1, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
ciphertext, err := PlainCiphertext2ASN1(ciphertext1, C1C3C2)
if err != nil {
t.Fatalf("convert to ASN.1 failed %v", err)
}
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
ciphertext2, err := AdjustCiphertextSplicingOrder(ciphertext1, C1C3C2, C1C2C3)
if err != nil {
t.Fatalf("adjust order failed %v", err)
}
ciphertext, err = PlainCiphertext2ASN1(ciphertext2, C1C2C3)
if err != nil {
t.Fatalf("convert to ASN.1 failed %v", err)
}
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestCiphertextASN12Plain(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
}{
// TODO: Add test cases.
{"less than 32", "encryption standard"},
{"equals 32", "encryption standard encryption "},
{"long than 32", "encryption standard encryption standard"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := EncryptASN1(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil)
if err != nil {
t.Fatalf("convert to plain failed %v", err)
}
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestEncryptWithInfinitePublicKey(t *testing.T) {
pub := new(ecdsa.PublicKey)
pub.Curve = P256()
pub.X = big.NewInt(0)
pub.Y = big.NewInt(0)
_, err := Encrypt(rand.Reader, pub, []byte("sm2 encryption standard"), nil)
if err == nil {
t.Fatalf("should be failed")
}
}
func TestEncryptEmptyPlaintext(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, nil, nil)
if err != nil || ciphertext != nil {
t.Fatalf("nil plaintext should return nil")
}
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte{}, nil)
if err != nil || ciphertext != nil {
t.Fatalf("empty plaintext should return nil")
}
}
func TestEncryptDecrypt(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
key2 := new(PrivateKey)
key2.PrivateKey = *priv2
tests := []struct {
name string
plainText string
priv *PrivateKey
}{
// TODO: Add test cases.
{"less than 32", "encryption standard", priv},
{"equals 32", "encryption standard encryption ", priv},
{"long than 32", "encryption standard encryption standard", priv},
{"less than 32", "encryption standard", key2},
{"equals 32", "encryption standard encryption ", key2},
{"long than 32", "encryption standard encryption standard", key2},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), nil)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err := Decrypt(tt.priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
// compress mode
encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err = Decrypt(tt.priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
// hybrid mode
encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err = Decrypt(tt.priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
plaintext, err = Decrypt(tt.priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func TestInvalidCiphertext(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
ciphertext []byte
}{
// TODO: Add test cases.
{errCiphertextTooShort.Error(), nil},
{errCiphertextTooShort.Error(), make([]byte, 65)},
{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 96)...)},
{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 97)...)},
{ErrDecryption.Error(), append([]byte{0x02}, make([]byte, 65)...)},
{ErrDecryption.Error(), append([]byte{0x30}, make([]byte, 97)...)},
{ErrDecryption.Error(), make([]byte, 97)},
}
for i, tt := range tests {
_, err := Decrypt(priv, tt.ciphertext)
if err.Error() != tt.name {
t.Fatalf("case %v, expected %v, got %v\n", i, tt.name, err.Error())
}
}
}
func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext []byte) {
r := bufio.NewReaderSize(rand.Reader, 1<<15)
priv, err := ecdsa.GenerateKey(curve, r)
if err != nil {
b.Fatal(err)
}
b.SetBytes(int64(len(plaintext)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil)
}
}
func BenchmarkEncryptNoMoreThan32_P256(b *testing.B) {
benchmarkEncrypt(b, elliptic.P256(), make([]byte, 31))
}
func BenchmarkEncryptNoMoreThan32_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 31))
}
func BenchmarkEncrypt128_P256(b *testing.B) {
benchmarkEncrypt(b, elliptic.P256(), make([]byte, 128))
}
func BenchmarkEncrypt128_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 128))
}
func BenchmarkEncrypt512_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 512))
}
func BenchmarkEncrypt1K_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 1024))
}
func BenchmarkEncrypt8K_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), make([]byte, 8*1024))
}