diff --git a/ecdh/ecdh.go b/ecdh/ecdh.go index 3648ca3..d2c115e 100644 --- a/ecdh/ecdh.go +++ b/ecdh/ecdh.go @@ -5,29 +5,15 @@ package ecdh import ( "crypto" "crypto/subtle" + "hash" "io" "sync" + + "github.com/emmansun/gmsm/kdf" + "github.com/emmansun/gmsm/sm3" ) type Curve interface { - // ECDH performs a ECDH exchange and returns the shared secret. - // - // For NIST curves, this performs ECDH as specified in SEC 1, Version 2.0, - // Section 3.3.1, and returns the x-coordinate encoded according to SEC 1, - // Version 2.0, Section 2.3.5. In particular, if the result is the point at - // infinity, ECDH returns an error. (Note that for NIST curves, that's only - // possible if the private key is the all-zero value.) - // - // For X25519, this performs ECDH as specified in RFC 7748, Section 6.1. If - // the result is the all-zero value, ECDH returns an error. - ECDH(local *PrivateKey, remote *PublicKey) ([]byte, error) - - // SM2MQV performs a SM2 specific style ECMQV exchange and return the shared secret. - SM2MQV(sLocal, eLocal *PrivateKey, sRemote, eRemote *PublicKey) (*PublicKey, error) - - // SM2SharedKey performs SM2 key derivation to generate shared keying data, the uv was generated by SM2MQV. - SM2SharedKey(isResponder bool, kenLen int, uv, sPub, sRemote *PublicKey, uid []byte, remoteUID []byte) ([]byte, error) - // GenerateKey generates a new PrivateKey from rand. GenerateKey(rand io.Reader) (*PrivateKey, error) @@ -53,6 +39,20 @@ type Curve interface { // selected public keys can cause ECDH to return an error. NewPublicKey(key []byte) (*PublicKey, error) + // ecdh performs a ECDH exchange and returns the shared secret. It's exposed + // as the PrivateKey.ECDH method. + // + // The private method also allow us to expand the ECDH interface with more + // methods in the future without breaking backwards compatibility. + ecdh(local *PrivateKey, remote *PublicKey) ([]byte, error) + + // sm2mqv performs a SM2 specific style ECMQV exchange and return the shared secret. + sm2mqv(sLocal, eLocal *PrivateKey, sRemote, eRemote *PublicKey) (*PublicKey, error) + + // sm2za ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA). + // Compliance with GB/T 32918.2-2016 5.5 + sm2za(md hash.Hash, pub *PublicKey, uid []byte) ([]byte, error) + // privateKeyToPublicKey converts a PrivateKey to a PublicKey. It's exposed // as the PrivateKey.PublicKey method. // @@ -99,6 +99,35 @@ func (k *PublicKey) Curve() Curve { return k.curve } +// SM2ZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA). +// Compliance with GB/T 32918.2-2016 5.5 +func (k *PublicKey) SM2ZA(md hash.Hash, uid []byte) ([]byte, error) { + return k.curve.sm2za(md, k, uid) +} + +// SM2SharedKey performs SM2 key derivation to generate shared keying data, the uv was generated by SM2MQV. +func (uv *PublicKey) SM2SharedKey(isResponder bool, kenLen int, sPub, sRemote *PublicKey, uid []byte, remoteUID []byte) ([]byte, error) { + var buffer [128]byte + copy(buffer[:], uv.publicKey[1:]) + peerZ, err := sRemote.SM2ZA(sm3.New(), remoteUID) + if err != nil { + return nil, err + } + z, err := sPub.SM2ZA(sm3.New(), uid) + if err != nil { + return nil, err + } + if isResponder { + copy(buffer[64:], peerZ) + copy(buffer[96:], z) + } else { + copy(buffer[64:], z) + copy(buffer[96:], peerZ) + } + + return kdf.Kdf(sm3.New(), buffer[:], kenLen), nil +} + // PrivateKey is an ECDH private key, usually kept secret. type PrivateKey struct { curve Curve @@ -109,6 +138,23 @@ type PrivateKey struct { publicKeyOnce sync.Once } +// ECDH performs a ECDH exchange and returns the shared secret. +// +// For NIST curves, this performs ECDH as specified in SEC 1, Version 2.0, +// Section 3.3.1, and returns the x-coordinate encoded according to SEC 1, +// Version 2.0, Section 2.3.5. The result is never the point at infinity. +// +// For X25519, this performs ECDH as specified in RFC 7748, Section 6.1. If +// the result is the all-zero value, ECDH returns an error. +func (k *PrivateKey) ECDH(remote *PublicKey) ([]byte, error) { + return k.curve.ecdh(k, remote) +} + +// SM2MQV performs a SM2 specific style ECMQV exchange and return the shared secret. +func (k *PrivateKey) SM2MQV(eLocal *PrivateKey, sRemote, eRemote *PublicKey) (*PublicKey, error) { + return k.curve.sm2mqv(k, eLocal, sRemote, eRemote) +} + // Bytes returns a copy of the encoding of the private key. func (k *PrivateKey) Bytes() []byte { // Copy the private key to a fixed size buffer that can get allocated on the diff --git a/ecdh/ecdh_test.go b/ecdh/ecdh_test.go index 8b2c3c9..e8b436d 100644 --- a/ecdh/ecdh_test.go +++ b/ecdh/ecdh_test.go @@ -64,11 +64,11 @@ func TestECDH(t *testing.T) { t.Error("encoded and decoded private keys are different") } - bobSecret, err := ecdh.P256().ECDH(bobKey, aliceKey.PublicKey()) + bobSecret, err := bobKey.ECDH(aliceKey.PublicKey()) if err != nil { t.Fatal(err) } - aliceSecret, err := ecdh.P256().ECDH(aliceKey, bobKey.PublicKey()) + aliceSecret, err := aliceKey.ECDH(bobKey.PublicKey()) if err != nil { t.Fatal(err) } @@ -97,12 +97,12 @@ func TestSM2MQV(t *testing.T) { t.Fatal(err) } - bobSecret, err := ecdh.P256().SM2MQV(bobSKey, bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey()) + bobSecret, err := bobSKey.SM2MQV(bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey()) if err != nil { t.Fatal(err) } - aliceSecret, err := ecdh.P256().SM2MQV(aliceSKey, aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey()) + aliceSecret, err := aliceSKey.SM2MQV(aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey()) if err != nil { t.Fatal(err) } @@ -131,12 +131,12 @@ func TestSM2SharedKey(t *testing.T) { t.Fatal(err) } - bobSecret, err := ecdh.P256().SM2MQV(bobSKey, bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey()) + bobSecret, err := bobSKey.SM2MQV(bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey()) if err != nil { t.Fatal(err) } - aliceSecret, err := ecdh.P256().SM2MQV(aliceSKey, aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey()) + aliceSecret, err := aliceSKey.SM2MQV(aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey()) if err != nil { t.Fatal(err) } @@ -145,12 +145,12 @@ func TestSM2SharedKey(t *testing.T) { t.Error("two SM2MQV computations came out different") } - bobKey, err := ecdh.P256().SM2SharedKey(true, 48, bobSecret, bobSKey.PublicKey(), aliceSKey.PublicKey(), []byte("Bob"), []byte("Alice")) + bobKey, err := bobSecret.SM2SharedKey(true, 48, bobSKey.PublicKey(), aliceSKey.PublicKey(), []byte("Bob"), []byte("Alice")) if err != nil { t.Fatal(err) } - aliceKey, err := ecdh.P256().SM2SharedKey(false, 48, aliceSecret, aliceSKey.PublicKey(), bobSKey.PublicKey(), []byte("Alice"), []byte("Bob")) + aliceKey, err := aliceSecret.SM2SharedKey(false, 48, aliceSKey.PublicKey(), bobSKey.PublicKey(), []byte("Alice"), []byte("Bob")) if err != nil { t.Fatal(err) } @@ -214,12 +214,12 @@ func TestSM2SharedKeyVectors(t *testing.T) { t.Fatal(err) } - bobSecret, err := ecdh.P256().SM2MQV(bobSKey, bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey()) + bobSecret, err := bobSKey.SM2MQV(bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey()) if err != nil { t.Fatal(err) } - aliceSecret, err := ecdh.P256().SM2MQV(aliceSKey, aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey()) + aliceSecret, err := aliceSKey.SM2MQV(aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey()) if err != nil { t.Fatal(err) } @@ -232,12 +232,12 @@ func TestSM2SharedKeyVectors(t *testing.T) { t.Errorf("%v shared secret is not expected.", i) } - bobKey, err := ecdh.P256().SM2SharedKey(true, kenLen, bobSecret, bobSKey.PublicKey(), aliceSKey.PublicKey(), responder, initiator) + bobKey, err := bobSecret.SM2SharedKey(true, kenLen, bobSKey.PublicKey(), aliceSKey.PublicKey(), responder, initiator) if err != nil { t.Fatal(err) } - aliceKey, err := ecdh.P256().SM2SharedKey(false, kenLen, aliceSecret, aliceSKey.PublicKey(), bobSKey.PublicKey(), initiator, responder) + aliceKey, err := aliceSecret.SM2SharedKey(false, kenLen, aliceSKey.PublicKey(), bobSKey.PublicKey(), initiator, responder) if err != nil { t.Fatal(err) } @@ -317,7 +317,7 @@ func BenchmarkECDH(b *testing.B) { if err != nil { b.Fatal(err) } - secret, err := curve.ECDH(key, peerPubKey) + secret, err := key.ECDH(peerPubKey) if err != nil { b.Fatal(err) } diff --git a/ecdh/sm2ec.go b/ecdh/sm2ec.go index 51a3236..0648e3c 100644 --- a/ecdh/sm2ec.go +++ b/ecdh/sm2ec.go @@ -10,8 +10,6 @@ import ( "github.com/emmansun/gmsm/internal/randutil" sm2ec "github.com/emmansun/gmsm/internal/sm2ec" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/kdf" - "github.com/emmansun/gmsm/sm3" ) type sm2Curve struct { @@ -101,7 +99,7 @@ func (c *sm2Curve) NewPublicKey(key []byte) (*PublicKey, error) { }, nil } -func (c *sm2Curve) ECDH(local *PrivateKey, remote *PublicKey) ([]byte, error) { +func (c *sm2Curve) ecdh(local *PrivateKey, remote *PublicKey) ([]byte, error) { p, err := c.newPoint().SetBytes(remote.publicKey) if err != nil { return nil, err @@ -122,7 +120,7 @@ func (c *sm2Curve) sm2avf(secret *PublicKey) []byte { return result[:] } -func (c *sm2Curve) SM2MQV(sLocal, eLocal *PrivateKey, sRemote, eRemote *PublicKey) (*PublicKey, error) { +func (c *sm2Curve) sm2mqv(sLocal, eLocal *PrivateKey, sRemote, eRemote *PublicKey) (*PublicKey, error) { // implicitSig: (sLocal + avf(eLocal.Pub) * ePriv) mod N x2 := c.sm2avf(eLocal.PublicKey()) t, err := sm2ec.ImplicitSig(sLocal.privateKey, eLocal.privateKey, x2) @@ -151,28 +149,6 @@ func (c *sm2Curve) SM2MQV(sLocal, eLocal *PrivateKey, sRemote, eRemote *PublicKe return c.NewPublicKey(p2.Bytes()) } -func (c *sm2Curve) SM2SharedKey(isResponder bool, kenLen int, uv, sPub, sRemote *PublicKey, uid []byte, remoteUID []byte) ([]byte, error) { - var buffer [128]byte - copy(buffer[:], uv.publicKey[1:]) - peerZ, err := c.sm2za(sm3.New(), sRemote, remoteUID) - if err != nil { - return nil, err - } - z, err := c.sm2za(sm3.New(), sPub, uid) - if err != nil { - return nil, err - } - if isResponder { - copy(buffer[64:], peerZ) - copy(buffer[96:], z) - } else { - copy(buffer[64:], z) - copy(buffer[96:], peerZ) - } - - return kdf.Kdf(sm3.New(), buffer[:], kenLen), nil -} - var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} // CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA).