diff --git a/docs/sm2.md b/docs/sm2.md index fa8a671..e73633d 100644 --- a/docs/sm2.md +++ b/docs/sm2.md @@ -193,6 +193,9 @@ func ExampleVerifyASN1WithSM2() { #### 验签 调用`sm2.VerifyASN1`方法,同样,你自己负责预先计算杂凑值,确保杂凑算法和签名时使用的杂凑算法保持一致。 +### 如何对大文件签名、验签? +解决方案就是对杂凑值进行签名、验签。`sm2.CalculateSM2Hash`并不适合对大文件进行杂凑计算,请使用专门的`hash.Hash`接口实现。 + ## 密钥交换协议 这里有两个实现,一个是传统实现,位于sm2包中;另外一个参考最新go语言的实现在ecdh包中。在这里不详细介绍使用方法,一般只有tls/tlcp才会用到,普通应用通常不会涉及这一块,感兴趣的话可以参考github.com/Trisia/gotlcp中的应用。 diff --git a/sm2/example_test.go b/sm2/example_test.go index 53e5eb0..0f8b9e9 100644 --- a/sm2/example_test.go +++ b/sm2/example_test.go @@ -112,6 +112,33 @@ func ExamplePrivateKey_Sign_forceSM2() { fmt.Printf("%x\n", sig) } +func ExamplePrivateKey_Sign_withHash() { + toSign := []byte("ShangMi SM2 Sign Standard") + // real private key should be from secret storage + privKey, _ := hex.DecodeString("6c5a0a0b2eed3cbec3e4f1252bfe0e28c504a1c6bf1999eebb0af9ef0f8e6c85") + testkey, err := sm2.NewPrivateKey(privKey) + if err != nil { + log.Fatalf("fail to new private key %v", err) + } + + // caluclate hash value + h, err := sm2.NewHash(&testkey.PublicKey) + if err != nil { + log.Fatalf("fail to new hash %v", err) + } + h.Write(toSign) + hashed := h.Sum(nil) + + sig, err := testkey.Sign(rand.Reader, hashed, nil) + if err != nil { + fmt.Fprintf(os.Stderr, "Error from sign: %s\n", err) + return + } + // Since sign is a randomized function, signature will be + // different each time. + fmt.Printf("%x\n", sig) +} + func ExampleVerifyASN1WithSM2() { // real public key should be from cert or public key pem file keypoints, _ := hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") @@ -120,10 +147,35 @@ func ExampleVerifyASN1WithSM2() { log.Fatalf("fail to new public key %v", err) } - toSign := []byte("ShangMi SM2 Sign Standard") + data := []byte("ShangMi SM2 Sign Standard") signature, _ := hex.DecodeString("304402205b3a799bd94c9063120d7286769220af6b0fa127009af3e873c0e8742edc5f890220097968a4c8b040fd548d1456b33f470cabd8456bfea53e8a828f92f6d4bdcd77") - ok := sm2.VerifyASN1WithSM2(testkey, nil, toSign, signature) + ok := sm2.VerifyASN1WithSM2(testkey, nil, data, signature) + + fmt.Printf("%v\n", ok) + // Output: true +} + +func ExampleVerifyASN1() { + // real public key should be from cert or public key pem file + keypoints, _ := hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") + testkey, err := sm2.NewPublicKey(keypoints) + if err != nil { + log.Fatalf("fail to new public key %v", err) + } + + // caluclate hash value + data := []byte("ShangMi SM2 Sign Standard") + h, err := sm2.NewHash(testkey) + if err != nil { + log.Fatalf("fail to new hash %v", err) + } + h.Write(data) + hashed := h.Sum(nil) + + signature, _ := hex.DecodeString("304402205b3a799bd94c9063120d7286769220af6b0fa127009af3e873c0e8742edc5f890220097968a4c8b040fd548d1456b33f470cabd8456bfea53e8a828f92f6d4bdcd77") + + ok := sm2.VerifyASN1(testkey, hashed, signature) fmt.Printf("%v\n", ok) // Output: true diff --git a/sm2/sm2_dsa.go b/sm2/sm2_dsa.go index e96242e..f550f30 100644 --- a/sm2/sm2_dsa.go +++ b/sm2/sm2_dsa.go @@ -826,3 +826,67 @@ func precomputeParams(c *sm2Curve, curve elliptic.Curve) { c.nMinus1 = c.N.Nat().SubOne(c.N) c.nMinus2 = new(bigmod.Nat).Set(c.nMinus1).SubOne(c.N).Bytes(c.N) } + +// sm2Hasher is a wrapper around a hash.Hash that includes the ZA value for SM2 hashing. +// It is used to perform SM2-specific hashing operations with the provided public key and user ID. +type sm2Hasher struct { + inner hash.Hash + za []byte +} + +// NewHash creates a new hash.Hash instance using the provided SM2 public key. +// It uses the default SM3 hash function and default user ID. +func NewHash(pub *ecdsa.PublicKey) (hash.Hash, error) { + return NewHashWithUserID(pub, defaultUID) +} + +// NewHashWithUserID creates a new hash.Hash instance using the provided SM2 public key and user ID. +// It internally uses the SM3 hash function. +func NewHashWithUserID(pub *ecdsa.PublicKey, userID []byte) (hash.Hash, error) { + return NewHashWithHashAndUserID(pub, sm3.New, userID) +} + +// NewHashWithHashAndUserID creates a new hash.Hash instance that incorporates SM2-specific +// hashing with the provided public key, inner hash and user ID. +// The returned hasher is reset before being returned. +func NewHashWithHashAndUserID(pub *ecdsa.PublicKey, h func() hash.Hash, userID []byte) (hash.Hash, error) { + inner := h() + za, err := CalculateZA(pub, userID) + if err != nil { + return nil, err + } + hasher := &sm2Hasher{inner: inner, za: za} + hasher.Write(za) + return hasher, nil +} + +// Write writes the contents of p into the underlying hash function. +// It returns the number of bytes written from p (n) and any error encountered (err). +// This method satisfies the io.Writer interface. +func (s *sm2Hasher) Write(p []byte) (n int, err error) { + return s.inner.Write(p) +} + +// Sum appends the current hash to b and returns the resulting slice. +// It does not change the underlying hash state. +func (s *sm2Hasher) Sum(b []byte) []byte { + return s.inner.Sum(b) +} + +// Reset clears the current state of the sm2Hasher and reinitializes it. +// It first resets the inner hash state and then writes the ZA value to it. +func (s *sm2Hasher) Reset() { + s.inner.Reset() + s.inner.Write(s.za) +} + +// Size returns the size of the hash in bytes. +func (s *sm2Hasher) Size() int { + return s.inner.Size() +} + +// BlockSize returns the block size of the hash function in bytes. +// It delegates the call to the inner hash function's BlockSize method. +func (s *sm2Hasher) BlockSize() int { + return s.inner.BlockSize() +} diff --git a/sm2/sm2_dsa_test.go b/sm2/sm2_dsa_test.go index 1faa41a..953f6f8 100644 --- a/sm2/sm2_dsa_test.go +++ b/sm2/sm2_dsa_test.go @@ -454,6 +454,74 @@ func TestSignVerify(t *testing.T) { } } +func TestSM2Hasher(t *testing.T) { + tobeHashed := []byte("hello world") + keypoints, _ := hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") + pub, err := NewPublicKey(keypoints) + if err != nil { + t.Fatal(err) + } + md := sm3.New() + hasher1, err := NewHash(pub) + if err != nil { + t.Fatal(err) + } + if hasher1.BlockSize() != md.BlockSize() { + t.Errorf("expected %d, got %d", md.BlockSize(), hasher1.BlockSize()) + } + if hasher1.Size() != md.Size() { + t.Errorf("expected %d, got %d", md.Size(), hasher1.Size()) + } + hasher1.Write(tobeHashed) + hash1 := hasher1.Sum(nil) + expected, err := CalculateSM2Hash(pub, tobeHashed, nil) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(hash1, expected) { + t.Errorf("expected %x, got %x", expected, hash1) + } + + hasher2, err := NewHashWithUserID(pub, []byte("john snow")) + if err != nil { + t.Fatal(err) + } + hasher2.Write(tobeHashed) + hash2 := hasher2.Sum(nil) + expected, err = CalculateSM2Hash(pub, tobeHashed, []byte("john snow")) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(hash2, expected) { + t.Errorf("expected %x, got %x", expected, hash2) + } +} + +func TestSM2HasherReset(t *testing.T) { + tobeHashed := []byte("hello world") + keypoints, _ := hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") + pub, err := NewPublicKey(keypoints) + if err != nil { + t.Fatal(err) + } + + hasher, err := NewHash(pub) + if err != nil { + t.Fatal(err) + } + + hasher.Write(tobeHashed) + hashBeforeReset := hasher.Sum(nil) + + hasher.Reset() + hasher.Write(tobeHashed) + hashAfterReset := hasher.Sum(nil) + + if !bytes.Equal(hashBeforeReset, hashAfterReset) { + t.Errorf("expected %x, got %x", hashBeforeReset, hashAfterReset) + } +} + func BenchmarkGenerateKey_SM2(b *testing.B) { r := bufio.NewReaderSize(rand.Reader, 1<<15) b.ReportAllocs()