From c99ad27ce180a201d86eba52bf92d96cc8ea89c0 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Wed, 15 May 2024 08:28:47 +0800 Subject: [PATCH] kdf: share Z hash state #220 --- cfca/pkcs12_sm2.go | 5 ++- kdf/kdf.go | 41 ++++++++++++++++++------ kdf/kdf_64bit_test.go | 2 +- kdf/kdf_test.go | 11 ++++--- sm2/sm2.go | 5 ++- sm2/sm2_keyexchange.go | 3 +- sm2/sm2_legacy.go | 5 ++- sm2/sm2_test.go | 8 +++++ sm3/sm3.go | 23 ++++++++++++++ sm3/sm3_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++ sm9/sm9.go | 7 ++--- sm9/sm9_test.go | 7 ++--- 12 files changed, 153 insertions(+), 35 deletions(-) diff --git a/cfca/pkcs12_sm2.go b/cfca/pkcs12_sm2.go index 1d894ed..cc738fb 100644 --- a/cfca/pkcs12_sm2.go +++ b/cfca/pkcs12_sm2.go @@ -8,7 +8,6 @@ import ( "fmt" "math/big" - "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/padding" "github.com/emmansun/gmsm/pkcs" "github.com/emmansun/gmsm/sm2" @@ -59,7 +58,7 @@ func ParseSM2(password, data []byte) (*sm2.PrivateKey, *smx509.Certificate, erro if !keys.EncryptedKey.Algorithm.Equal(oidSM4) && !keys.EncryptedKey.Algorithm.Equal(oidSM4CBC) { return nil, nil, fmt.Errorf("cfca: unsupported algorithm <%v>", keys.EncryptedKey.Algorithm) } - ivkey := kdf.Kdf(sm3.New(), password, 32) + ivkey := sm3.Kdf(password, 32) marshalledIV, err := asn1.Marshal(ivkey[:16]) if err != nil { return nil, nil, err @@ -91,7 +90,7 @@ func MarshalSM2(password []byte, key *sm2.PrivateKey, cert *smx509.Certificate) if len(password) == 0 { return nil, errors.New("cfca: invalid password") } - ivkey := kdf.Kdf(sm3.New(), password, 32) + ivkey := sm3.Kdf(password, 32) block, err := sm4.NewCipher(ivkey[16:]) if err != nil { return nil, err diff --git a/kdf/kdf.go b/kdf/kdf.go index 0cd2cee..cb12023 100644 --- a/kdf/kdf.go +++ b/kdf/kdf.go @@ -2,27 +2,48 @@ package kdf import ( + "encoding" "encoding/binary" "hash" ) // Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3. // ANSI-X9.63-KDF -func Kdf(md hash.Hash, z []byte, len int) []byte { - limit := uint64(len+md.Size()-1) / uint64(md.Size()) +func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte { + baseMD := newHash() + limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size()) if limit >= uint64(1<<32)-1 { panic("kdf: key length too long") } var countBytes [4]byte var ct uint32 = 1 var k []byte - for i := 0; i < int(limit); i++ { - binary.BigEndian.PutUint32(countBytes[:], ct) - md.Write(z) - md.Write(countBytes[:]) - k = md.Sum(k) - ct++ - md.Reset() + + marshaler, ok := baseMD.(encoding.BinaryMarshaler) + if limit == 1 || len(z) < baseMD.BlockSize() || !ok { + for i := 0; i < int(limit); i++ { + binary.BigEndian.PutUint32(countBytes[:], ct) + baseMD.Write(z) + baseMD.Write(countBytes[:]) + k = baseMD.Sum(k) + ct++ + baseMD.Reset() + } + } else { + baseMD.Write(z) + zstate, _ := marshaler.MarshalBinary() + for i := 0; i < int(limit); i++ { + md := newHash() + err := md.(encoding.BinaryUnmarshaler).UnmarshalBinary(zstate) + if err != nil { + panic(err) + } + binary.BigEndian.PutUint32(countBytes[:], ct) + md.Write(countBytes[:]) + k = md.Sum(k) + ct++ + } } - return k[:len] + + return k[:keyLen] } diff --git a/kdf/kdf_64bit_test.go b/kdf/kdf_64bit_test.go index 80372a8..baa1ac0 100644 --- a/kdf/kdf_64bit_test.go +++ b/kdf/kdf_64bit_test.go @@ -11,6 +11,6 @@ import ( // This case should be failed on 32bits system. func TestKdfPanic(t *testing.T) { shouldPanic(t, func() { - Kdf(sm3.New(), []byte("123456"), 1<<37) + Kdf(sm3.New, []byte("123456"), 1<<37) }) } diff --git a/kdf/kdf_test.go b/kdf/kdf_test.go index 294ee44..ae96f19 100644 --- a/kdf/kdf_test.go +++ b/kdf/kdf_test.go @@ -31,7 +31,7 @@ func TestKdf(t *testing.T) { for _, tt := range tests { wantBytes, _ := hex.DecodeString(tt.want) t.Run(tt.name, func(t *testing.T) { - if got := Kdf(tt.args.md, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) { + if got := Kdf(sm3.New, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) { t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want) } }) @@ -44,7 +44,7 @@ func TestKdfOldCase(t *testing.T) { expected := "006e30dae231b071dfad8aa379e90264491603" - result := Kdf(sm3.New(), append(x2.Bytes(), y2.Bytes()...), 19) + result := Kdf(sm3.New, append(x2.Bytes(), y2.Bytes()...), 19) resultStr := hex.EncodeToString(result) @@ -71,16 +71,17 @@ func BenchmarkKdf(b *testing.B) { {64, 32}, {64, 64}, {64, 128}, - {440, 32}, + {64, 256}, + {64, 512}, + {64, 1024}, } - sm3Hash := sm3.New() z := make([]byte, 512) for _, tt := range tests { b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - Kdf(sm3Hash, z[:tt.zLen], tt.kLen) + Kdf(sm3.New, z[:tt.zLen], tt.kLen) } }) } diff --git a/sm2/sm2.go b/sm2/sm2.go index c43af1a..15b5a47 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -25,7 +25,6 @@ import ( "github.com/emmansun/gmsm/internal/randutil" _sm2ec "github.com/emmansun/gmsm/internal/sm2ec" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm2/sm2ec" "github.com/emmansun/gmsm/sm3" "golang.org/x/crypto/cryptobyte" @@ -251,7 +250,7 @@ func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byt return nil, err } C2Bytes := C2.Bytes()[1:] - c2 := kdf.Kdf(sm3.New(), C2Bytes, len(msg)) + c2 := sm3.Kdf(C2Bytes, len(msg)) if subtle.ConstantTimeAllZero(c2) { retryCount++ if retryCount > maxRetryLimit { @@ -424,7 +423,7 @@ func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *Decryp } C2Bytes := C2.Bytes()[1:] msgLen := len(c2) - msg := kdf.Kdf(sm3.New(), C2Bytes, msgLen) + msg := sm3.Kdf(C2Bytes, msgLen) if subtle.ConstantTimeAllZero(c2) { return nil, ErrDecryption } diff --git a/sm2/sm2_keyexchange.go b/sm2/sm2_keyexchange.go index 7e84dd4..7831d26 100644 --- a/sm2/sm2_keyexchange.go +++ b/sm2/sm2_keyexchange.go @@ -7,7 +7,6 @@ import ( "io" "math/big" - "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm3" ) @@ -185,7 +184,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { buffer = append(buffer, ke.z...) buffer = append(buffer, ke.peerZ...) } - return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil + return sm3.Kdf(buffer, ke.keyLength), nil } // avf is the associative value function. diff --git a/sm2/sm2_legacy.go b/sm2/sm2_legacy.go index 504345e..2cd8a98 100644 --- a/sm2/sm2_legacy.go +++ b/sm2/sm2_legacy.go @@ -11,7 +11,6 @@ import ( "strings" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm2/sm2ec" "github.com/emmansun/gmsm/sm3" "golang.org/x/crypto/cryptobyte" @@ -260,7 +259,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) //A5, calculate t=KDF(x2||y2, klen) - c2 := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if subtle.ConstantTimeAllZero(c2) { retryCount++ if retryCount > maxRetryLimit { @@ -408,7 +407,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error curve := priv.Curve x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) msgLen := len(c2) - msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if subtle.ConstantTimeAllZero(c2) { return nil, ErrDecryption } diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index b3071db..448fe1f 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -851,3 +851,11 @@ func BenchmarkMoreThan32_P256(b *testing.B) { func BenchmarkMoreThan32_SM2(b *testing.B) { benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard") } + +func BenchmarkEncrypt512_SM2(b *testing.B) { + benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s") +} + +func BenchmarkEncrypt1024_SM2(b *testing.B) { + benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption sencryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s") +} diff --git a/sm3/sm3.go b/sm3/sm3.go index 10d52e8..450e507 100644 --- a/sm3/sm3.go +++ b/sm3/sm3.go @@ -211,3 +211,26 @@ func Sum(data []byte) [Size]byte { d.Write(data) return d.checkSum() } + +// Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3. +func Kdf(z []byte, keyLen int) []byte { + limit := uint64(keyLen+Size-1) / uint64(Size) + if limit >= uint64(1<<32)-1 { + panic("sm3: key length too long") + } + var countBytes [4]byte + var ct uint32 = 1 + var k []byte + baseMD := new(digest) + baseMD.Reset() + baseMD.Write(z) + for i := 0; i < int(limit); i++ { + binary.BigEndian.PutUint32(countBytes[:], ct) + md := *baseMD + md.Write(countBytes[:]) + h := md.checkSum() + k = append(k, h[:]...) + ct++ + } + return k[:keyLen] +} diff --git a/sm3/sm3_test.go b/sm3/sm3_test.go index 782e6b3..f442735 100644 --- a/sm3/sm3_test.go +++ b/sm3/sm3_test.go @@ -9,6 +9,8 @@ import ( "fmt" "hash" "io" + "math/big" + "reflect" "testing" "golang.org/x/sys/cpu" @@ -403,6 +405,75 @@ func BenchmarkHash8K_SH256(b *testing.B) { benchmarkSize(benchSH256, b, 8192) } +func TestKdf(t *testing.T) { + type args struct { + z []byte + len int + } + tests := []struct { + name string + args args + want string + }{ + {"sm3 case 1", args{[]byte("emmansun"), 16}, "708993ef1388a0ae4245a19bb6c02554"}, + {"sm3 case 2", args{[]byte("emmansun"), 32}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd4"}, + {"sm3 case 3", args{[]byte("emmansun"), 48}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"}, + {"sm3 case 4", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 48}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f"}, + {"sm3 case 5", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 128}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f30277f3179baebd795e7853fa643fdf280d8d7b81a2ab7829f615e132ab376d32194cd315908d27090e1180ce442d9be99322523db5bfac40ac5acb03550f5c93e5b01b1d71f2630868909a6a1250edb"}, + } + for _, tt := range tests { + wantBytes, _ := hex.DecodeString(tt.want) + t.Run(tt.name, func(t *testing.T) { + if got := Kdf(tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) { + t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want) + } + }) + } +} + +func TestKdfOldCase(t *testing.T) { + x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16) + y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16) + + expected := "006e30dae231b071dfad8aa379e90264491603" + + result := Kdf(append(x2.Bytes(), y2.Bytes()...), 19) + + resultStr := hex.EncodeToString(result) + + if expected != resultStr { + t.Fatalf("expected %s, real value %s", expected, resultStr) + } +} + +func BenchmarkKdfWithSM3(b *testing.B) { + tests := []struct { + zLen int + kLen int + }{ + {32, 32}, + {32, 64}, + {32, 128}, + {64, 32}, + {64, 64}, + {64, 128}, + {64, 256}, + {64, 512}, + {64, 1024}, + {64, 1024*8}, + } + z := make([]byte, 512) + for _, tt := range tests { + b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Kdf(z[:tt.zLen], tt.kLen) + } + }) + } +} + /* func round1(a, b, c, d, e, f, g, h string, i int) { fmt.Printf("//Round %d\n", i+1) diff --git a/sm9/sm9.go b/sm9/sm9.go index 18b87e1..bd0aaa1 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -12,7 +12,6 @@ import ( "github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/internal/randutil" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm9/bn256" "golang.org/x/crypto/cryptobyte" @@ -317,7 +316,7 @@ func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key = kdf.Kdf(sm3.New(), buffer, kLen) + key = sm3.Kdf(buffer, kLen) if !subtle.ConstantTimeAllZero(key) { break } @@ -403,7 +402,7 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int) buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key := kdf.Kdf(sm3.New(), buffer, kLen) + key := sm3.Kdf(buffer, kLen) if subtle.ConstantTimeAllZero(key) { return nil, ErrDecryption } @@ -685,7 +684,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { buffer = append(buffer, ke.g2.Marshal()...) buffer = append(buffer, ke.g3.Marshal()...) - return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil + return sm3.Kdf(buffer, ke.keyLength), nil } func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) { diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index b72558b..8698d43 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -8,7 +8,6 @@ import ( "github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm9/bn256" "golang.org/x/crypto/cryptobyte" @@ -563,7 +562,7 @@ func TestWrapKeySM9Sample(t *testing.T) { buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key := kdf.Kdf(sm3.New(), buffer, 32) + key := sm3.Kdf(buffer, 32) if hex.EncodeToString(key) != expectedKey { t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key)) @@ -629,7 +628,7 @@ func TestEncryptSM9Sample(t *testing.T) { buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key := kdf.Kdf(sm3.New(), buffer, len(plaintext)+32) + key := sm3.Kdf(buffer, len(plaintext)+32) if hex.EncodeToString(key) != expectedKey { t.Errorf("not expected key") @@ -697,7 +696,7 @@ func TestEncryptSM9SampleBlockMode(t *testing.T) { buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key := kdf.Kdf(sm3.New(), buffer, 16+32) + key := sm3.Kdf(buffer, 16+32) if hex.EncodeToString(key) != expectedKey { t.Errorf("not expected key, expected %v, got %x\n", expectedKey, key)