sm9: implement crypto.Singer crypto.Decrypter interface

This commit is contained in:
Sun Yimin 2025-03-25 14:58:16 +08:00 committed by GitHub
parent 3eea15b3b8
commit 88df15c64c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 96 additions and 19 deletions

View File

@ -130,7 +130,7 @@ func ExampleEncryptPrivateKey_Decrypt() {
} }
uid := []byte("Bob") uid := []byte("Bob")
cipherDer, _ := hex.DecodeString("307f020100034200042cb3e90b0977211597652f26ee4abbe275ccb18dd7f431876ab5d40cc2fc563d9417791c75bc8909336a4e6562450836cc863f51002e31ecf0c4aae8d98641070420638ca5bfb35d25cff7cbd684f3ed75f2d919da86a921a2e3e2e2f4cbcf583f240414b7e776811774722a8720752fb1355ce45dc3d0df") cipherDer, _ := hex.DecodeString("307f020100034200042cb3e90b0977211597652f26ee4abbe275ccb18dd7f431876ab5d40cc2fc563d9417791c75bc8909336a4e6562450836cc863f51002e31ecf0c4aae8d98641070420638ca5bfb35d25cff7cbd684f3ed75f2d919da86a921a2e3e2e2f4cbcf583f240414b7e776811774722a8720752fb1355ce45dc3d0df")
plaintext, err := userKey.DecryptASN1(uid, cipherDer) plaintext, err := userKey.Decrypt(rand.Reader, cipherDer, uid)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error from Decrypt: %s\n", err) fmt.Fprintf(os.Stderr, "Error from Decrypt: %s\n", err)
return return

View File

@ -277,27 +277,27 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts
c3c2 := ciphertext[64:] c3c2 := ciphertext[64:]
c3 := c3c2[:sm3.Size] c3 := c3c2[:sm3.Size]
c2 := c3c2[sm3.Size:] c2 := c3c2[sm3.Size:]
key1Len := opts.GetKeySize(c2) return decrypt(priv, uid, c1, c2, c3, opts)
}
func decrypt(priv *EncryptPrivateKey, uid, c1, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) {
key1Len := opts.GetKeySize(c2)
key, err := UnwrapKey(priv, uid, c1, key1Len+sm3.Size) key, err := UnwrapKey(priv, uid, c1, key1Len+sm3.Size)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_ = key[key1Len] // bounds check elimination hint _ = key[key1Len] // bounds check elimination hint
return decrypt(key[:key1Len], key[key1Len:], c2, c3, opts)
}
func decrypt(key1, key2, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) {
hash := sm3.New() hash := sm3.New()
hash.Write(c2) hash.Write(c2)
hash.Write(key2) hash.Write(key[key1Len:])
c32 := hash.Sum(nil) c32 := hash.Sum(nil)
if goSubtle.ConstantTimeCompare(c3, c32) != 1 { if goSubtle.ConstantTimeCompare(c3, c32) != 1 {
return nil, ErrDecryption return nil, ErrDecryption
} }
return opts.Decrypt(key1, c2) return opts.Decrypt(key[:key1Len], c2)
} }
// DecryptASN1 decrypts chipher, the ciphertext should be with ASN.1 format according // DecryptASN1 decrypts chipher, the ciphertext should be with ASN.1 format according
@ -328,23 +328,55 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error
if opts == nil { if opts == nil {
return nil, ErrDecryption return nil, ErrDecryption
} }
key1Len := opts.GetKeySize(c2Bytes) return decrypt(priv, uid, c1Bytes, c2Bytes, c3Bytes, opts)
key, err := UnwrapKey(priv, uid, c1Bytes, key1Len+sm3.Size)
if err != nil {
return nil, err
}
_ = key[key1Len] // bounds check elimination hint
return decrypt(key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts)
} }
// Decrypt decrypts chipher, the ciphertext should be with format C1||C3||C2 type DecrypterOptsWithUID struct {
func (priv *EncryptPrivateKey) Decrypt(uid, ciphertext []byte, opts EncrypterOpts) ([]byte, error) { EncrypterOpts
return Decrypt(priv, uid, ciphertext, opts) UID []byte
}
// NewDecrypterOptsWithUID creates a new DecrypterOptsWithUID instance with the provided
// EncrypterOpts and UID. The UID must not be empty, otherwise an error is returned.
func NewDecrypterOptsWithUID(opts EncrypterOpts, uid []byte) (*DecrypterOptsWithUID, error) {
if len(uid) == 0 {
return nil, errors.New("sm9: invalid uid")
}
return &DecrypterOptsWithUID{EncrypterOpts: opts, UID: uid}, nil
}
// Decrypt decrypts the given ciphertext using the provided EncryptPrivateKey.
// The decryption process depends on the type of the opts parameter:
// - If opts is of type []byte, it uses DecryptASN1 to decrypt the ciphertext.
// - If opts is of type *DecrypterOptsWithUID, it first checks if the ciphertext
// is a valid ASN.1 sequence. If it is not, and EncrypterOpts is nil, it returns
// an error indicating invalid ASN.1 data. Otherwise, it uses the Decrypt function
// with the provided UID and EncrypterOpts to decrypt the ciphertext. If the
// ciphertext is a valid ASN.1 sequence, it uses DecryptASN1 with the UID to
// decrypt the ciphertext.
// If opts is of an unsupported type, it returns an error indicating invalid decrypter options.
func (priv *EncryptPrivateKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) ([]byte, error) {
switch xx := opts.(type) {
case []byte:
return DecryptASN1(priv, xx, ciphertext)
case *DecrypterOptsWithUID:
var inner cryptobyte.String
input := cryptobyte.String(ciphertext)
if !input.ReadASN1(&inner, asn1.SEQUENCE) || !input.Empty() {
if xx.EncrypterOpts == nil {
return nil, errors.New("sm9: invalid ciphertext asn.1 data")
}
return Decrypt(priv, xx.UID, ciphertext, xx.EncrypterOpts)
} else {
return DecryptASN1(priv, xx.UID, ciphertext)
}
}
return nil, errors.New("sm9: invalid decrypter options")
} }
// DecryptASN1 decrypts chipher, the ciphertext should be with ASN.1 format according // DecryptASN1 decrypts chipher, the ciphertext should be with ASN.1 format according
// SM9 cryptographic algorithm application specification, SM9Cipher definition. // SM9 cryptographic algorithm application specification, SM9Cipher definition.
// @Deprecated: Use Decrypt instead.
func (priv *EncryptPrivateKey) DecryptASN1(uid, ciphertext []byte) ([]byte, error) { func (priv *EncryptPrivateKey) DecryptASN1(uid, ciphertext []byte) ([]byte, error) {
return DecryptASN1(priv, uid, ciphertext) return DecryptASN1(priv, uid, ciphertext)
} }
@ -401,10 +433,13 @@ type keyExchange struct {
ke *sm9.KeyExchange ke *sm9.KeyExchange
} }
// NewKeyExchange initializes a new key exchange process using the provided user IDs and key length.
// It returns a pointer to a keyExchange struct which contains the key exchange instance.
func (priv *EncryptPrivateKey) NewKeyExchange(uid, peerUID []byte, keyLen int, genSignature bool) *keyExchange { func (priv *EncryptPrivateKey) NewKeyExchange(uid, peerUID []byte, keyLen int, genSignature bool) *keyExchange {
return &keyExchange{ke: priv.internal.NewKeyExchange(uid, peerUID, keyLen, genSignature)} return &keyExchange{ke: priv.internal.NewKeyExchange(uid, peerUID, keyLen, genSignature)}
} }
// Destroy securely wipes the key exchange data from memory.
func (ke *keyExchange) Destroy() { func (ke *keyExchange) Destroy() {
ke.ke.Destroy() ke.ke.Destroy()
} }

View File

@ -212,6 +212,12 @@ func (priv *SignPrivateKey) Equal(x crypto.PrivateKey) bool {
return subtle.ConstantTimeCompare(priv.privateKey, xx.privateKey) == 1 return subtle.ConstantTimeCompare(priv.privateKey, xx.privateKey) == 1
} }
// Public returns the public key corresponding to the private key.
// Just to satisfy [crypto.Signer] interface.
func (priv *SignPrivateKey) Public() crypto.PublicKey {
return nil
}
func (priv *SignPrivateKey) Bytes() []byte { func (priv *SignPrivateKey) Bytes() []byte {
var buf [65]byte var buf [65]byte
return append(buf[:0], priv.privateKey...) return append(buf[:0], priv.privateKey...)
@ -517,6 +523,12 @@ func (priv *EncryptPrivateKey) Equal(x crypto.PrivateKey) bool {
return subtle.ConstantTimeCompare(priv.privateKey, xx.privateKey) == 1 return subtle.ConstantTimeCompare(priv.privateKey, xx.privateKey) == 1
} }
// Public returns the public key corresponding to the private key.
// Just to satisfy [crypto.Decrypter] interface.
func (priv *EncryptPrivateKey) Public() crypto.PublicKey {
return nil
}
// Bytes returns the byte representation of the EncryptPrivateKey. // Bytes returns the byte representation of the EncryptPrivateKey.
// It delegates the call to the Bytes method of the underlying privateKey. // It delegates the call to the Bytes method of the underlying privateKey.
func (priv *EncryptPrivateKey) Bytes() []byte { func (priv *EncryptPrivateKey) Bytes() []byte {

View File

@ -120,7 +120,11 @@ func TestEncryptDecrypt(t *testing.T) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
} }
got, err = userKey.Decrypt(uid, cipher, opts) opts1, err := sm9.NewDecrypterOptsWithUID(opts, uid)
if err != nil {
t.Fatal(err)
}
got, err = userKey.Decrypt(rand.Reader, cipher, opts1)
if err != nil { if err != nil {
t.Fatalf("encType %v, first byte %x, %v", opts.GetEncryptType(), cipher[0], err) t.Fatalf("encType %v, first byte %x, %v", opts.GetEncryptType(), cipher[0], err)
} }
@ -128,6 +132,12 @@ func TestEncryptDecrypt(t *testing.T) {
if string(got) != string(plaintext) { if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
} }
opts1.EncrypterOpts = nil
_, err = userKey.Decrypt(rand.Reader, cipher, opts1)
if err == nil || err.Error() != "sm9: invalid ciphertext asn.1 data" {
t.Fatalf("sm9: invalid ciphertext asn.1 data")
}
} }
} }
@ -187,6 +197,26 @@ func TestEncryptDecryptASN1(t *testing.T) {
if string(got) != string(plaintext) { if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
} }
got, err = userKey.Decrypt(rand.Reader, cipher, uid)
if err != nil {
t.Fatal(err)
}
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
}
opts, err := sm9.NewDecrypterOptsWithUID(nil, uid)
if err != nil {
t.Fatal(err)
}
got, err = userKey.Decrypt(rand.Reader, cipher, opts)
if err != nil {
t.Fatal(err)
}
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
}
} }
} }