diff --git a/sm2/sm2.go b/sm2/sm2.go index 2bfbdbd..c27d9f3 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -1,13 +1,19 @@ package sm2 import ( + "crypto" + "crypto/aes" + "crypto/cipher" "crypto/ecdsa" "crypto/elliptic" + "crypto/sha512" + "encoding/asn1" "encoding/binary" "errors" "fmt" "io" "math/big" + "strings" "github.com/emmansun/gmsm/sm3" ) @@ -20,6 +26,27 @@ const ( mixed07 byte = 0x07 ) +// PrivateKey represents an ECDSA private key. +type PrivateKey struct { + ecdsa.PrivateKey +} + +// Sign signs digest with priv, reading randomness from rand. The opts argument +// is not currently used but, in keeping with the crypto.Signer interface, +// should be the hash function used to digest the message. +// +// This method implements crypto.Signer, which is an interface to support keys +// where the private part is kept in, for example, a hardware module. Common +// uses should use the Sign function in this package directly. +func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + r, s, err := Sign(rand, &priv.PrivateKey, digest) + if err != nil { + return nil, err + } + + return asn1.Marshal(ecdsaSignature{r, s}) +} + ///////////////// below code ship from golan crypto/ecdsa //////////////////// var one = new(big.Int).SetInt64(1) @@ -111,8 +138,23 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) } } +// GenerateKey generates a public and private key pair. +func GenerateKey(rand io.Reader) (*PrivateKey, error) { + c := P256() + k, err := randFieldElement(c, rand) + if err != nil { + return nil, err + } + + priv := new(PrivateKey) + priv.PublicKey.Curve = c + priv.D = k + priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) + return priv, nil +} + // Decrypt sm2 decrypt implementation -func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) { +func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { ciphertextLen := len(ciphertext) if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { return nil, errors.New("invalid ciphertext length") @@ -153,3 +195,209 @@ func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) { return msg, nil } + +// hashToInt converts a hash value to an integer. There is some disagreement +// about how this is done. [NSA] suggests that this is done in the obvious +// manner, but [SECG] truncates the hash to the bit-length of the curve order +// first. We follow [SECG] because that's what OpenSSL does. Additionally, +// OpenSSL right shifts excess bits from the number if the hash is too large +// and we mirror that too. +func hashToInt(hash []byte, c elliptic.Curve) *big.Int { + orderBits := c.Params().N.BitLen() + orderBytes := (orderBits + 7) / 8 + if len(hash) > orderBytes { + hash = hash[:orderBytes] + } + + ret := new(big.Int).SetBytes(hash) + excess := len(hash)*8 - orderBits + if excess > 0 { + ret.Rsh(ret, uint(excess)) + } + return ret +} + +const ( + aesIV = "IV for ECDSA CTR" +) + +var errZeroParam = errors.New("zero parameter") + +// Sign signs a hash (which should be the result of hashing a larger message) +// using the private key, priv. If the hash is longer than the bit-length of the +// private key's curve order, the hash will be truncated to that length. It +// returns the signature as a pair of integers. The security of the private key +// depends on the entropy of rand. +func Sign(rand io.Reader, priv *ecdsa.PrivateKey, hash []byte) (r, s *big.Int, err error) { + if !strings.EqualFold(priv.Params().Name, P256().Params().Name) { + return ecdsa.Sign(rand, priv, hash) + } + maybeReadByte(rand) + + // Get min(log2(q) / 2, 256) bits of entropy from rand. + entropylen := (priv.Curve.Params().BitSize + 7) / 16 + if entropylen > 32 { + entropylen = 32 + } + entropy := make([]byte, entropylen) + _, err = io.ReadFull(rand, entropy) + if err != nil { + return + } + + // Initialize an SHA-512 hash context; digest ... + md := sha512.New() + md.Write(priv.D.Bytes()) // the private key, + md.Write(entropy) // the entropy, + md.Write(hash) // and the input hash; + key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512), + // which is an indifferentiable MAC. + + // Create an AES-CTR instance to use as a CSPRNG. + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + + // Create a CSPRNG that xors a stream of zeros with + // the output of the AES-CTR instance. + csprng := cipher.StreamReader{ + R: zeroReader, + S: cipher.NewCTR(block, []byte(aesIV)), + } + + // See [NSA] 3.4.1 + c := priv.PublicKey.Curve + N := c.Params().N + if N.Sign() == 0 { + return nil, nil, errZeroParam + } + var k *big.Int + e := hashToInt(hash, c) + for { + for { + k, err = randFieldElement(c, csprng) + if err != nil { + r = nil + return + } + + r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) // (x, y) = k*G + r.Add(r, e) // r = x + e + r.Mod(r, N) // r = (x + e) mod N + if r.Sign() != 0 { + t := new(big.Int).Add(r, k) + if t.Cmp(N) != 0 { // if r != 0 && (r + k) != N then ok + break + } + } + } + s = new(big.Int).Mul(priv.D, r) + s = new(big.Int).Sub(k, s) + dp1 := new(big.Int).Add(priv.D, one) + dp1Inv := new(big.Int).ModInverse(dp1, N) + s.Mul(s, dp1Inv) + s.Mod(s, N) // N != 0 + if s.Sign() != 0 { + break + } + } + + return +} + +var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} + +// CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA) +func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { + uidLen := len(uid) + if uidLen >= 0x2000 { + return nil, errors.New("the uid is too long") + } + entla := uint16(uidLen) << 3 + hasher := sm3.New() + hasher.Write([]byte{byte(entla >> 8), byte(entla)}) + if uidLen > 0 { + hasher.Write(uid) + } + a := new(big.Int).Sub(pub.Params().P, big.NewInt(3)) + hasher.Write(toBytes(pub.Curve, a)) + hasher.Write(toBytes(pub.Curve, pub.Params().B)) + hasher.Write(toBytes(pub.Curve, pub.Params().Gx)) + hasher.Write(toBytes(pub.Curve, pub.Params().Gy)) + hasher.Write(toBytes(pub.Curve, pub.X)) + hasher.Write(toBytes(pub.Curve, pub.Y)) + return hasher.Sum(nil), nil +} + +// SignWithSM2 follow sm2 dsa standards for hash part +func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s *big.Int, err error) { + za, err := CalculateZA(&priv.PublicKey, uid) + if err != nil { + return nil, nil, err + } + hasher := sm3.New() + hasher.Write(za) + hasher.Write(msg) + + return Sign(rand, priv, hasher.Sum(nil)) +} + +// Verify verifies the signature in r, s of hash using the public key, pub. Its +// return value records whether the signature is valid. +func Verify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { + if strings.EqualFold(pub.Params().Name, P256().Params().Name) { + c := pub.Curve + N := c.Params().N + + if r.Sign() <= 0 || s.Sign() <= 0 { + return false + } + if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { + return false + } + e := hashToInt(hash, c) + t := new(big.Int).Add(r, s) + t.Mod(t, N) + if t.Sign() == 0 { + return false + } + + var x *big.Int + x1, y1 := c.ScalarBaseMult(s.Bytes()) + x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) + x, _ = c.Add(x1, y1, x2, y2) + + x.Add(x, e) + x.Mod(x, N) + return x.Cmp(r) == 0 + } + return ecdsa.Verify(pub, hash, r, s) +} + +// VerifyWithSM2 verifies the signature in r, s of hash using the public key, pub. Its +// return value records whether the signature is valid. +func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool { + za, err := CalculateZA(pub, uid) + if err != nil { + return false + } + hasher := sm3.New() + hasher.Write(za) + hasher.Write(msg) + return Verify(pub, hasher.Sum(nil), r, s) +} + +type zr struct { + io.Reader +} + +// Read replaces the contents of dst with zeros. +func (z *zr) Read(dst []byte) (n int, err error) { + for i := range dst { + dst[i] = 0 + } + return len(dst), nil +} + +var zeroReader = &zr{} diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index eafa350..4806153 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -8,6 +8,8 @@ import ( "math/big" "reflect" "testing" + + "github.com/emmansun/gmsm/sm3" ) func Test_kdf(t *testing.T) { @@ -29,7 +31,7 @@ func Test_kdf(t *testing.T) { } func Test_encryptDecrypt(t *testing.T) { - priv, _ := ecdsa.GenerateKey(P256(), rand.Reader) + priv, _ := GenerateKey(rand.Reader) tests := []struct { name string plainText string @@ -56,6 +58,32 @@ func Test_encryptDecrypt(t *testing.T) { } } +func Test_signVerify(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + }{ + // TODO: Add test cases. + {"less than 32", "encryption standard"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash := sm3.Sum([]byte(tt.plainText)) + r, s, err := Sign(rand.Reader, &priv.PrivateKey, hash[:]) + if err != nil { + t.Fatalf("sign failed %v", err) + } + result := Verify(&priv.PublicKey, hash[:], r, s) + if !result { + t.Fatal("verify failed") + } + }) + } +} + func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext string) { for i := 0; i < b.N; i++ { priv, _ := ecdsa.GenerateKey(curve, rand.Reader) diff --git a/sm2/util.go b/sm2/util.go index 5e54bc1..c233113 100644 --- a/sm2/util.go +++ b/sm2/util.go @@ -4,8 +4,10 @@ import ( "crypto/elliptic" "errors" "fmt" + "io" "math/big" "strings" + "sync" ) var zero = big.NewInt(0) @@ -113,3 +115,29 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e } return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format) } + +var ( + closedChanOnce sync.Once + closedChan chan struct{} +) + +// maybeReadByte reads a single byte from r with ~50% probability. This is used +// to ensure that callers do not depend on non-guaranteed behaviour, e.g. +// assuming that rsa.GenerateKey is deterministic w.r.t. a given random stream. +// +// This does not affect tests that pass a stream of fixed bytes as the random +// source (e.g. a zeroReader). +func maybeReadByte(r io.Reader) { + closedChanOnce.Do(func() { + closedChan = make(chan struct{}) + close(closedChan) + }) + + select { + case <-closedChan: + return + case <-closedChan: + var buf [1]byte + r.Read(buf[:]) + } +} diff --git a/sm2/x509.go b/sm2/x509.go index ea2347e..4a884cd 100644 --- a/sm2/x509.go +++ b/sm2/x509.go @@ -29,6 +29,12 @@ type pkcs1PublicKey struct { E int } +type dsaSignature struct { + R, S *big.Int +} + +type ecdsaSignature dsaSignature + // http://gmssl.org/docs/oid.html var ( oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} diff --git a/sm2/x509_test.go b/sm2/x509_test.go index f433b8d..a7dc2ee 100644 --- a/sm2/x509_test.go +++ b/sm2/x509_test.go @@ -3,8 +3,12 @@ package sm2 import ( "crypto/ecdsa" "crypto/rand" + "encoding/asn1" + "encoding/base64" + "encoding/hex" "encoding/pem" "errors" + "fmt" "strings" "testing" ) @@ -15,6 +19,14 @@ MFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAELfjZP28bYfGSvbODYlXiB5bcoXE+ -----END PUBLIC KEY----- ` +const publicKeyPemFromAliKmsForSign = `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAERrsLH25zLm2LIo6tivZM9afLprSX +6TCKAmQJArAO7VOtZyW4PQwfaTsUIF7IXEFG4iI8bNuTQwMykUzLu2ypEA== +-----END PUBLIC KEY----- +` +const hashBase64 = `Zsfw9GLu7dnR8tRr3BDk4kFnxIdc8veiKX2gK49LqOA=` +const signature = `MEUCIHV5hOCgYzlO4HkrUhct1Cc8BeKmbXNP+ASje5rGOcCYAiEA2XOajXo3/IihtCEJmNpImtWw3uHIy5CX5TIxit7V0gQ=` + func getPublicKey(pemContent []byte) (interface{}, error) { block, _ := pem.Decode(pemContent) if block == nil { @@ -23,16 +35,41 @@ func getPublicKey(pemContent []byte) (interface{}, error) { return ParsePKIXPublicKey(block.Bytes) } +func TestSignByAliVerifyAtLocal(t *testing.T) { + var rs = &ecdsaSignature{} + dig, err := base64.StdEncoding.DecodeString(signature) + if err != nil { + t.Fatal(err) + } + rest, err := asn1.Unmarshal(dig, rs) + if err != nil { + t.Fatal(err) + } + if len(rest) != 0 { + t.Errorf("rest len=%d", len(rest)) + } + + fmt.Printf("r=%s, s=%s\n", hex.EncodeToString(rs.R.Bytes()), hex.EncodeToString(rs.S.Bytes())) + pub, err := getPublicKey([]byte(publicKeyPemFromAliKmsForSign)) + pub1 := pub.(*ecdsa.PublicKey) + hashValue, _ := base64.StdEncoding.DecodeString(hashBase64) + result := Verify(pub1, hashValue, rs.R, rs.S) + if !result { + t.Error("Verify fail") + } +} + func TestParsePKIXPublicKey(t *testing.T) { pub, err := getPublicKey([]byte(publicKeyPemFromAliKms)) if err != nil { t.Fatal(err) } pub1 := pub.(*ecdsa.PublicKey) - _, err = Encrypt(rand.Reader, pub1, []byte("testfile")) + encrypted, err := Encrypt(rand.Reader, pub1, []byte("testfile")) if err != nil { t.Fatal(err) } + fmt.Printf("encrypted=%s\n", base64.StdEncoding.EncodeToString(encrypted)) } func TestMarshalPKIXPublicKey(t *testing.T) { diff --git a/sm3/sm3_test.go b/sm3/sm3_test.go index 6b7afcf..3f81bfb 100644 --- a/sm3/sm3_test.go +++ b/sm3/sm3_test.go @@ -3,6 +3,7 @@ package sm3 import ( "bytes" "encoding" + "encoding/base64" "fmt" "hash" "io" @@ -23,7 +24,9 @@ var golden = []sm3Test{ func TestGolden(t *testing.T) { for i := 0; i < len(golden); i++ { g := golden[i] - s := fmt.Sprintf("%x", Sum([]byte(g.in))) + h := Sum([]byte(g.in)) + s := fmt.Sprintf("%x", h) + fmt.Printf("%s\n", base64.StdEncoding.EncodeToString(h[:])) if s != g.out { t.Fatalf("SM3 function: sm3(%s) = %s want %s", g.in, s, g.out) }