kdf: share Z hash state #220

This commit is contained in:
Sun Yimin 2024-05-15 08:28:47 +08:00 committed by GitHub
parent 57318eaf5b
commit c99ad27ce1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 153 additions and 35 deletions

View File

@ -8,7 +8,6 @@ import (
"fmt"
"math/big"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/padding"
"github.com/emmansun/gmsm/pkcs"
"github.com/emmansun/gmsm/sm2"
@ -59,7 +58,7 @@ func ParseSM2(password, data []byte) (*sm2.PrivateKey, *smx509.Certificate, erro
if !keys.EncryptedKey.Algorithm.Equal(oidSM4) && !keys.EncryptedKey.Algorithm.Equal(oidSM4CBC) {
return nil, nil, fmt.Errorf("cfca: unsupported algorithm <%v>", keys.EncryptedKey.Algorithm)
}
ivkey := kdf.Kdf(sm3.New(), password, 32)
ivkey := sm3.Kdf(password, 32)
marshalledIV, err := asn1.Marshal(ivkey[:16])
if err != nil {
return nil, nil, err
@ -91,7 +90,7 @@ func MarshalSM2(password []byte, key *sm2.PrivateKey, cert *smx509.Certificate)
if len(password) == 0 {
return nil, errors.New("cfca: invalid password")
}
ivkey := kdf.Kdf(sm3.New(), password, 32)
ivkey := sm3.Kdf(password, 32)
block, err := sm4.NewCipher(ivkey[16:])
if err != nil {
return nil, err

View File

@ -2,27 +2,48 @@
package kdf
import (
"encoding"
"encoding/binary"
"hash"
)
// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
// ANSI-X9.63-KDF
func Kdf(md hash.Hash, z []byte, len int) []byte {
limit := uint64(len+md.Size()-1) / uint64(md.Size())
func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte {
baseMD := newHash()
limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size())
if limit >= uint64(1<<32)-1 {
panic("kdf: key length too long")
}
var countBytes [4]byte
var ct uint32 = 1
var k []byte
for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
md.Write(z)
md.Write(countBytes[:])
k = md.Sum(k)
ct++
md.Reset()
marshaler, ok := baseMD.(encoding.BinaryMarshaler)
if limit == 1 || len(z) < baseMD.BlockSize() || !ok {
for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
baseMD.Write(z)
baseMD.Write(countBytes[:])
k = baseMD.Sum(k)
ct++
baseMD.Reset()
}
} else {
baseMD.Write(z)
zstate, _ := marshaler.MarshalBinary()
for i := 0; i < int(limit); i++ {
md := newHash()
err := md.(encoding.BinaryUnmarshaler).UnmarshalBinary(zstate)
if err != nil {
panic(err)
}
binary.BigEndian.PutUint32(countBytes[:], ct)
md.Write(countBytes[:])
k = md.Sum(k)
ct++
}
}
return k[:len]
return k[:keyLen]
}

View File

@ -11,6 +11,6 @@ import (
// This case should be failed on 32bits system.
func TestKdfPanic(t *testing.T) {
shouldPanic(t, func() {
Kdf(sm3.New(), []byte("123456"), 1<<37)
Kdf(sm3.New, []byte("123456"), 1<<37)
})
}

View File

@ -31,7 +31,7 @@ func TestKdf(t *testing.T) {
for _, tt := range tests {
wantBytes, _ := hex.DecodeString(tt.want)
t.Run(tt.name, func(t *testing.T) {
if got := Kdf(tt.args.md, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
if got := Kdf(sm3.New, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
}
})
@ -44,7 +44,7 @@ func TestKdfOldCase(t *testing.T) {
expected := "006e30dae231b071dfad8aa379e90264491603"
result := Kdf(sm3.New(), append(x2.Bytes(), y2.Bytes()...), 19)
result := Kdf(sm3.New, append(x2.Bytes(), y2.Bytes()...), 19)
resultStr := hex.EncodeToString(result)
@ -71,16 +71,17 @@ func BenchmarkKdf(b *testing.B) {
{64, 32},
{64, 64},
{64, 128},
{440, 32},
{64, 256},
{64, 512},
{64, 1024},
}
sm3Hash := sm3.New()
z := make([]byte, 512)
for _, tt := range tests {
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Kdf(sm3Hash, z[:tt.zLen], tt.kLen)
Kdf(sm3.New, z[:tt.zLen], tt.kLen)
}
})
}

View File

@ -25,7 +25,6 @@ import (
"github.com/emmansun/gmsm/internal/randutil"
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm2/sm2ec"
"github.com/emmansun/gmsm/sm3"
"golang.org/x/crypto/cryptobyte"
@ -251,7 +250,7 @@ func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byt
return nil, err
}
C2Bytes := C2.Bytes()[1:]
c2 := kdf.Kdf(sm3.New(), C2Bytes, len(msg))
c2 := sm3.Kdf(C2Bytes, len(msg))
if subtle.ConstantTimeAllZero(c2) {
retryCount++
if retryCount > maxRetryLimit {
@ -424,7 +423,7 @@ func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *Decryp
}
C2Bytes := C2.Bytes()[1:]
msgLen := len(c2)
msg := kdf.Kdf(sm3.New(), C2Bytes, msgLen)
msg := sm3.Kdf(C2Bytes, msgLen)
if subtle.ConstantTimeAllZero(c2) {
return nil, ErrDecryption
}

View File

@ -7,7 +7,6 @@ import (
"io"
"math/big"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3"
)
@ -185,7 +184,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
buffer = append(buffer, ke.z...)
buffer = append(buffer, ke.peerZ...)
}
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
return sm3.Kdf(buffer, ke.keyLength), nil
}
// avf is the associative value function.

View File

@ -11,7 +11,6 @@ import (
"strings"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm2/sm2ec"
"github.com/emmansun/gmsm/sm3"
"golang.org/x/crypto/cryptobyte"
@ -260,7 +259,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
//A5, calculate t=KDF(x2||y2, klen)
c2 := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) {
retryCount++
if retryCount > maxRetryLimit {
@ -408,7 +407,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
curve := priv.Curve
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
msgLen := len(c2)
msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) {
return nil, ErrDecryption
}

View File

@ -851,3 +851,11 @@ func BenchmarkMoreThan32_P256(b *testing.B) {
func BenchmarkMoreThan32_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard")
}
func BenchmarkEncrypt512_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s")
}
func BenchmarkEncrypt1024_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption sencryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s")
}

View File

@ -211,3 +211,26 @@ func Sum(data []byte) [Size]byte {
d.Write(data)
return d.checkSum()
}
// Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3.
func Kdf(z []byte, keyLen int) []byte {
limit := uint64(keyLen+Size-1) / uint64(Size)
if limit >= uint64(1<<32)-1 {
panic("sm3: key length too long")
}
var countBytes [4]byte
var ct uint32 = 1
var k []byte
baseMD := new(digest)
baseMD.Reset()
baseMD.Write(z)
for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
md := *baseMD
md.Write(countBytes[:])
h := md.checkSum()
k = append(k, h[:]...)
ct++
}
return k[:keyLen]
}

View File

@ -9,6 +9,8 @@ import (
"fmt"
"hash"
"io"
"math/big"
"reflect"
"testing"
"golang.org/x/sys/cpu"
@ -403,6 +405,75 @@ func BenchmarkHash8K_SH256(b *testing.B) {
benchmarkSize(benchSH256, b, 8192)
}
func TestKdf(t *testing.T) {
type args struct {
z []byte
len int
}
tests := []struct {
name string
args args
want string
}{
{"sm3 case 1", args{[]byte("emmansun"), 16}, "708993ef1388a0ae4245a19bb6c02554"},
{"sm3 case 2", args{[]byte("emmansun"), 32}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd4"},
{"sm3 case 3", args{[]byte("emmansun"), 48}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"},
{"sm3 case 4", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 48}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f"},
{"sm3 case 5", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 128}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f30277f3179baebd795e7853fa643fdf280d8d7b81a2ab7829f615e132ab376d32194cd315908d27090e1180ce442d9be99322523db5bfac40ac5acb03550f5c93e5b01b1d71f2630868909a6a1250edb"},
}
for _, tt := range tests {
wantBytes, _ := hex.DecodeString(tt.want)
t.Run(tt.name, func(t *testing.T) {
if got := Kdf(tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
}
})
}
}
func TestKdfOldCase(t *testing.T) {
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
expected := "006e30dae231b071dfad8aa379e90264491603"
result := Kdf(append(x2.Bytes(), y2.Bytes()...), 19)
resultStr := hex.EncodeToString(result)
if expected != resultStr {
t.Fatalf("expected %s, real value %s", expected, resultStr)
}
}
func BenchmarkKdfWithSM3(b *testing.B) {
tests := []struct {
zLen int
kLen int
}{
{32, 32},
{32, 64},
{32, 128},
{64, 32},
{64, 64},
{64, 128},
{64, 256},
{64, 512},
{64, 1024},
{64, 1024*8},
}
z := make([]byte, 512)
for _, tt := range tests {
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Kdf(z[:tt.zLen], tt.kLen)
}
})
}
}
/*
func round1(a, b, c, d, e, f, g, h string, i int) {
fmt.Printf("//Round %d\n", i+1)

View File

@ -12,7 +12,6 @@ import (
"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/randutil"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3"
"github.com/emmansun/gmsm/sm9/bn256"
"golang.org/x/crypto/cryptobyte"
@ -317,7 +316,7 @@ func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte,
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)
key = kdf.Kdf(sm3.New(), buffer, kLen)
key = sm3.Kdf(buffer, kLen)
if !subtle.ConstantTimeAllZero(key) {
break
}
@ -403,7 +402,7 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int)
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)
key := kdf.Kdf(sm3.New(), buffer, kLen)
key := sm3.Kdf(buffer, kLen)
if subtle.ConstantTimeAllZero(key) {
return nil, ErrDecryption
}
@ -685,7 +684,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
buffer = append(buffer, ke.g2.Marshal()...)
buffer = append(buffer, ke.g3.Marshal()...)
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
return sm3.Kdf(buffer, ke.keyLength), nil
}
func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) {

View File

@ -8,7 +8,6 @@ import (
"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3"
"github.com/emmansun/gmsm/sm9/bn256"
"golang.org/x/crypto/cryptobyte"
@ -563,7 +562,7 @@ func TestWrapKeySM9Sample(t *testing.T) {
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)
key := kdf.Kdf(sm3.New(), buffer, 32)
key := sm3.Kdf(buffer, 32)
if hex.EncodeToString(key) != expectedKey {
t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key))
@ -629,7 +628,7 @@ func TestEncryptSM9Sample(t *testing.T) {
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)
key := kdf.Kdf(sm3.New(), buffer, len(plaintext)+32)
key := sm3.Kdf(buffer, len(plaintext)+32)
if hex.EncodeToString(key) != expectedKey {
t.Errorf("not expected key")
@ -697,7 +696,7 @@ func TestEncryptSM9SampleBlockMode(t *testing.T) {
buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...)
key := kdf.Kdf(sm3.New(), buffer, 16+32)
key := sm3.Kdf(buffer, 16+32)
if hex.EncodeToString(key) != expectedKey {
t.Errorf("not expected key, expected %v, got %x\n", expectedKey, key)