diff --git a/sm2/p256.go b/sm2/p256.go new file mode 100644 index 0000000..c581398 --- /dev/null +++ b/sm2/p256.go @@ -0,0 +1,32 @@ +package sm2 + +import ( + "crypto/elliptic" + "math/big" + "sync" +) + +type p256Curve struct { + *elliptic.CurveParams +} + +var ( + p256 p256Curve + initonce sync.Once +) + +func initP256() { + p256.CurveParams = &elliptic.CurveParams{Name: "P-256/SM2"} + p256.P, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16) + p256.N, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16) + p256.B, _ = new(big.Int).SetString("28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 16) + p256.Gx, _ = new(big.Int).SetString("32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 16) + p256.Gy, _ = new(big.Int).SetString("BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 16) + p256.BitSize = 256 +} + +// P256 init and return the singleton +func P256() elliptic.Curve { + initonce.Do(initP256) + return p256 +} diff --git a/sm2/sm2.go b/sm2/sm2.go index b6027d3..ad2be72 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -12,11 +12,11 @@ import ( ) const ( - Uncompressed byte = 0x04 - Compressed_02 byte = 0x02 - Compressed_03 byte = 0x03 - Mixed_06 byte = 0x06 - Mixed_07 byte = 0x07 + uncompressed byte = 0x04 + compressed02 byte = 0x02 + compressed03 byte = 0x03 + mixed06 byte = 0x06 + mixed07 byte = 0x07 ) ///////////////// below code ship from golan crypto/ecdsa //////////////////// @@ -41,7 +41,7 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) /////////////////////////////////////////////////////////////////////////////////// func kdf(z []byte, len int) ([]byte, bool) { - limit := (len + sm3.Size - 1) / sm3.Size + limit := (len + sm3.Size - 1) >> sm3.SizeBitSize sm3Hasher := sm3.New() var countBytes [4]byte var ct uint32 = 1 @@ -121,9 +121,6 @@ func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) { if err != nil { return nil, err } - if !curve.IsOnCurve(x1, y1) { - return nil, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name) - } //B2 is ignored //B3, calculate x2, y2 diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index 5e5c8ae..20b43f1 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -2,7 +2,6 @@ package sm2 import ( "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" "encoding/hex" "math/big" @@ -29,7 +28,7 @@ func Test_kdf(t *testing.T) { } func Test_encryptDecrypt(t *testing.T) { - priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + priv, _ := ecdsa.GenerateKey(P256(), rand.Reader) tests := []struct { name string plainText string @@ -55,3 +54,18 @@ func Test_encryptDecrypt(t *testing.T) { }) } } + +func benchmarkEncrypt(b *testing.B, plaintext string) { + for i := 0; i < b.N; i++ { + priv, _ := ecdsa.GenerateKey(P256(), rand.Reader) + Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext)) + } +} + +func BenchmarkLessThan32(b *testing.B) { + benchmarkEncrypt(b, "encryption standard") +} + +func BenchmarkMoreThan32(b *testing.B) { + benchmarkEncrypt(b, "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard") +} diff --git a/sm2/util.go b/sm2/util.go index c8e6244..4fc0291 100644 --- a/sm2/util.go +++ b/sm2/util.go @@ -29,9 +29,9 @@ func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { buffer := make([]byte, (curve.Params().BitSize+7)>>3+1) copy(buffer[1:], toBytes(curve, x)) if getLastBitOfY(x, y) > 0 { - buffer[0] = Compressed_03 + buffer[0] = compressed03 } else { - buffer[0] = Compressed_02 + buffer[0] = compressed02 } return buffer } @@ -39,9 +39,9 @@ func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte { buffer := elliptic.Marshal(curve, x, y) if getLastBitOfY(x, y) > 0 { - buffer[0] = Mixed_07 + buffer[0] = mixed07 } else { - buffer[0] = Mixed_06 + buffer[0] = mixed06 } return buffer } @@ -82,14 +82,17 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e format := bytes[0] byteLen := (curve.Params().BitSize + 7) >> 3 switch format { - case Uncompressed: + case uncompressed, mixed06, mixed07: // what's the mixed format purpose? if len(bytes) < 1+byteLen*2 { return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes)) } x := toPointXY(bytes[1 : 1+byteLen]) y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) + if !curve.IsOnCurve(x, y) { + return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name) + } return x, y, 1 + byteLen*2, nil - case Compressed_02, Compressed_03: + case compressed02, compressed03: if len(bytes) < 1+byteLen { return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) } @@ -101,20 +104,12 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e return nil, nil, 0, err } - if (getLastBitOfY(x, y) > 0 && format == Compressed_02) || (getLastBitOfY(x, y) == 0 && format == Compressed_03) { + if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) { y.Sub(curve.Params().P, y) } return x, y, 1 + byteLen, nil } return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name) - case Mixed_06, Mixed_07: - // what's the mixed format purpose? - if len(bytes) < 1+byteLen*2 { - return nil, nil, 0, fmt.Errorf("invalid mixed bytes length %d", len(bytes)) - } - x := toPointXY(bytes[1 : 1+byteLen]) - y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) - return x, y, 1 + byteLen*2, nil } return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format) } diff --git a/sm3/sm3.go b/sm3/sm3.go index d2f8409..91b507d 100644 --- a/sm3/sm3.go +++ b/sm3/sm3.go @@ -9,6 +9,9 @@ import ( // Size the size of a SM3 checksum in bytes. const Size int = 32 +// SizeBitSize the bit size of Size. +const SizeBitSize = 5 + // BlockSize the blocksize of SM3 in bytes. const BlockSize int = 64