From a1cb0a26169cf78e8813b8d935d01aa311fb3f6a Mon Sep 17 00:00:00 2001 From: emmansun Date: Mon, 15 Feb 2021 10:36:28 +0800 Subject: [PATCH] MAGIC - refactor --- sm2/sm2.go | 106 ++++++++++++++++++++++++++++++-------------- sm2/sm2_test.go | 31 ++++++++++++- smx509/x509_test.go | 2 +- 3 files changed, 102 insertions(+), 37 deletions(-) diff --git a/sm2/sm2.go b/sm2/sm2.go index 4a2cd49..7175463 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -47,6 +47,35 @@ type ecdsaSignature struct { R, S *big.Int } +type pointMarshalMode byte + +const ( + //MarshalUncompressed uncompressed mashal mode + MarshalUncompressed pointMarshalMode = iota + //MarshalCompressed compressed mashal mode + MarshalCompressed + //MarshalMixed mixed mashal mode + MarshalMixed +) + +// EncrypterOpts encryption options +type EncrypterOpts struct { + PointMarshalMode pointMarshalMode +} + +func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { + switch mode { + case MarshalCompressed: + return point2CompressedBytes(curve, x, y) + case MarshalMixed: + return point2MixedBytes(curve, x, y) + default: + return point2UncompressedBytes(curve, x, y) + } +} + +var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed} + // Sign signs digest with priv, reading randomness from rand. The opts argument // is not currently used but, in keeping with the crypto.Signer interface, // should be the hash function used to digest the message. @@ -73,6 +102,12 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er return asn1.Marshal(ecdsaSignature{r, s}) } +// Decrypt decrypts msg. The opts argument should be appropriate for +// the primitive used. +func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) { + return Decrypt(priv, msg) +} + var ( one = new(big.Int).SetInt64(1) initonce sync.Once @@ -106,17 +141,17 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) /////////////////////////////////////////////////////////////////////////////////// func kdf(z []byte, len int) ([]byte, bool) { limit := (len + sm3.Size - 1) >> sm3.SizeBitSize - sm3Hasher := sm3.New() + md := sm3.New() var countBytes [4]byte var ct uint32 = 1 k := make([]byte, len+sm3.Size-1) for i := 0; i < limit; i++ { binary.BigEndian.PutUint32(countBytes[:], ct) - sm3Hasher.Write(z) - sm3Hasher.Write(countBytes[:]) - copy(k[i*sm3.Size:], sm3Hasher.Sum(nil)) + md.Write(z) + md.Write(countBytes[:]) + copy(k[i*sm3.Size:], md.Sum(nil)) ct++ - sm3Hasher.Reset() + md.Reset() } for i := 0; i < len; i++ { if k[i] != 0 { @@ -127,17 +162,20 @@ func kdf(z []byte, len int) ([]byte, bool) { } func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte { - hasher := sm3.New() - hasher.Write(toBytes(curve, x2)) - hasher.Write(msg) - hasher.Write(toBytes(curve, y2)) - return hasher.Sum(nil) + md := sm3.New() + md.Write(toBytes(curve, x2)) + md.Write(msg) + md.Write(toBytes(curve, y2)) + return md.Sum(nil) } // Encrypt sm2 encrypt implementation -func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) { +func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) { curve := pub.Curve msgLen := len(msg) + if opts == nil { + opts = &defaultEncrypterOpts + } for { //A1, generate random k k, err := randFieldElement(curve, random) @@ -147,7 +185,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) //A2, calculate C1 = k * G x1, y1 := curve.ScalarBaseMult(k.Bytes()) - c1 := point2UncompressedBytes(curve, x1, y1) + c1 := opts.PointMarshalMode.mashal(curve, x1, y1) //A3, skipped //A4, calculate k * P (point of Public Key) @@ -362,26 +400,26 @@ func Sign(rand io.Reader, priv *ecdsa.PrivateKey, hash []byte) (r, s *big.Int, e var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} -// CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA) -func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { +// calculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA) +func calculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { uidLen := len(uid) if uidLen >= 0x2000 { return nil, errors.New("the uid is too long") } entla := uint16(uidLen) << 3 - hasher := sm3.New() - hasher.Write([]byte{byte(entla >> 8), byte(entla)}) + md := sm3.New() + md.Write([]byte{byte(entla >> 8), byte(entla)}) if uidLen > 0 { - hasher.Write(uid) + md.Write(uid) } a := new(big.Int).Sub(pub.Params().P, big.NewInt(3)) - hasher.Write(toBytes(pub.Curve, a)) - hasher.Write(toBytes(pub.Curve, pub.Params().B)) - hasher.Write(toBytes(pub.Curve, pub.Params().Gx)) - hasher.Write(toBytes(pub.Curve, pub.Params().Gy)) - hasher.Write(toBytes(pub.Curve, pub.X)) - hasher.Write(toBytes(pub.Curve, pub.Y)) - return hasher.Sum(nil), nil + md.Write(toBytes(pub.Curve, a)) + md.Write(toBytes(pub.Curve, pub.Params().B)) + md.Write(toBytes(pub.Curve, pub.Params().Gx)) + md.Write(toBytes(pub.Curve, pub.Params().Gy)) + md.Write(toBytes(pub.Curve, pub.X)) + md.Write(toBytes(pub.Curve, pub.Y)) + return md.Sum(nil), nil } // SignWithSM2 follow sm2 dsa standards for hash part @@ -389,15 +427,15 @@ func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s if len(uid) == 0 { uid = defaultUID } - za, err := CalculateZA(&priv.PublicKey, uid) + za, err := calculateZA(&priv.PublicKey, uid) if err != nil { return nil, nil, err } - hasher := sm3.New() - hasher.Write(za) - hasher.Write(msg) + md := sm3.New() + md.Write(za) + md.Write(msg) - return Sign(rand, priv, hasher.Sum(nil)) + return Sign(rand, priv, md.Sum(nil)) } // Verify verifies the signature in r, s of hash using the public key, pub. Its @@ -442,14 +480,14 @@ func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool { if len(uid) == 0 { uid = defaultUID } - za, err := CalculateZA(pub, uid) + za, err := calculateZA(pub, uid) if err != nil { return false } - hasher := sm3.New() - hasher.Write(za) - hasher.Write(msg) - return Verify(pub, hasher.Sum(nil), r, s) + md := sm3.New() + md.Write(za) + md.Write(msg) + return Verify(pub, md.Sum(nil), r, s) } type zr struct { diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index 4806153..e69d51c 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -43,7 +43,7 @@ func Test_encryptDecrypt(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText)) + ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil) if err != nil { t.Fatalf("encrypt failed %v", err) } @@ -54,6 +54,33 @@ func Test_encryptDecrypt(t *testing.T) { if !reflect.DeepEqual(string(plaintext), tt.plainText) { t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) } + // compress mode + encrypterOpts := EncrypterOpts{MarshalCompressed} + ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err = Decrypt(priv, ciphertext) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + + // mixed mode + encrypterOpts = EncrypterOpts{MarshalMixed} + ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err = Decrypt(priv, ciphertext) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } }) } } @@ -87,7 +114,7 @@ func Test_signVerify(t *testing.T) { func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext string) { for i := 0; i < b.N; i++ { priv, _ := ecdsa.GenerateKey(curve, rand.Reader) - Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext)) + Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil) } } diff --git a/smx509/x509_test.go b/smx509/x509_test.go index 6434333..ab6d381 100644 --- a/smx509/x509_test.go +++ b/smx509/x509_test.go @@ -118,7 +118,7 @@ func TestParsePKIXPublicKey(t *testing.T) { t.Fatal(err) } pub1 := pub.(*ecdsa.PublicKey) - encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile")) + encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile"), nil) if err != nil { t.Fatal(err) }