diff --git a/sm9/example_test.go b/sm9/example_test.go index 9573cf6..14d92dd 100644 --- a/sm9/example_test.go +++ b/sm9/example_test.go @@ -130,7 +130,7 @@ func ExampleEncryptPrivateKey_Decrypt() { } uid := []byte("Bob") cipherDer, _ := hex.DecodeString("307f020100034200042cb3e90b0977211597652f26ee4abbe275ccb18dd7f431876ab5d40cc2fc563d9417791c75bc8909336a4e6562450836cc863f51002e31ecf0c4aae8d98641070420638ca5bfb35d25cff7cbd684f3ed75f2d919da86a921a2e3e2e2f4cbcf583f240414b7e776811774722a8720752fb1355ce45dc3d0df") - plaintext, err := userKey.DecryptASN1(uid, cipherDer) + plaintext, err := userKey.Decrypt(rand.Reader, cipherDer, uid) if err != nil { fmt.Fprintf(os.Stderr, "Error from Decrypt: %s\n", err) return diff --git a/sm9/sm9.go b/sm9/sm9.go index f2f5b96..9e66d53 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -277,27 +277,27 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts c3c2 := ciphertext[64:] c3 := c3c2[:sm3.Size] c2 := c3c2[sm3.Size:] - key1Len := opts.GetKeySize(c2) + return decrypt(priv, uid, c1, c2, c3, opts) +} +func decrypt(priv *EncryptPrivateKey, uid, c1, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) { + key1Len := opts.GetKeySize(c2) key, err := UnwrapKey(priv, uid, c1, key1Len+sm3.Size) if err != nil { return nil, err } _ = key[key1Len] // bounds check elimination hint - return decrypt(key[:key1Len], key[key1Len:], c2, c3, opts) -} -func decrypt(key1, key2, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) { hash := sm3.New() hash.Write(c2) - hash.Write(key2) + hash.Write(key[key1Len:]) c32 := hash.Sum(nil) if goSubtle.ConstantTimeCompare(c3, c32) != 1 { return nil, ErrDecryption } - return opts.Decrypt(key1, c2) + return opts.Decrypt(key[:key1Len], c2) } // DecryptASN1 decrypts chipher, the ciphertext should be with ASN.1 format according @@ -328,23 +328,55 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error if opts == nil { return nil, ErrDecryption } - key1Len := opts.GetKeySize(c2Bytes) - key, err := UnwrapKey(priv, uid, c1Bytes, key1Len+sm3.Size) - if err != nil { - return nil, err - } - - _ = key[key1Len] // bounds check elimination hint - return decrypt(key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts) + return decrypt(priv, uid, c1Bytes, c2Bytes, c3Bytes, opts) } -// Decrypt decrypts chipher, the ciphertext should be with format C1||C3||C2 -func (priv *EncryptPrivateKey) Decrypt(uid, ciphertext []byte, opts EncrypterOpts) ([]byte, error) { - return Decrypt(priv, uid, ciphertext, opts) +type DecrypterOptsWithUID struct { + EncrypterOpts + UID []byte +} + +// NewDecrypterOptsWithUID creates a new DecrypterOptsWithUID instance with the provided +// EncrypterOpts and UID. The UID must not be empty, otherwise an error is returned. +func NewDecrypterOptsWithUID(opts EncrypterOpts, uid []byte) (*DecrypterOptsWithUID, error) { + if len(uid) == 0 { + return nil, errors.New("sm9: invalid uid") + } + return &DecrypterOptsWithUID{EncrypterOpts: opts, UID: uid}, nil +} + +// Decrypt decrypts the given ciphertext using the provided EncryptPrivateKey. +// The decryption process depends on the type of the opts parameter: +// - If opts is of type []byte, it uses DecryptASN1 to decrypt the ciphertext. +// - If opts is of type *DecrypterOptsWithUID, it first checks if the ciphertext +// is a valid ASN.1 sequence. If it is not, and EncrypterOpts is nil, it returns +// an error indicating invalid ASN.1 data. Otherwise, it uses the Decrypt function +// with the provided UID and EncrypterOpts to decrypt the ciphertext. If the +// ciphertext is a valid ASN.1 sequence, it uses DecryptASN1 with the UID to +// decrypt the ciphertext. +// If opts is of an unsupported type, it returns an error indicating invalid decrypter options. +func (priv *EncryptPrivateKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) ([]byte, error) { + switch xx := opts.(type) { + case []byte: + return DecryptASN1(priv, xx, ciphertext) + case *DecrypterOptsWithUID: + var inner cryptobyte.String + input := cryptobyte.String(ciphertext) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || !input.Empty() { + if xx.EncrypterOpts == nil { + return nil, errors.New("sm9: invalid ciphertext asn.1 data") + } + return Decrypt(priv, xx.UID, ciphertext, xx.EncrypterOpts) + } else { + return DecryptASN1(priv, xx.UID, ciphertext) + } + } + return nil, errors.New("sm9: invalid decrypter options") } // DecryptASN1 decrypts chipher, the ciphertext should be with ASN.1 format according // SM9 cryptographic algorithm application specification, SM9Cipher definition. +// @Deprecated: Use Decrypt instead. func (priv *EncryptPrivateKey) DecryptASN1(uid, ciphertext []byte) ([]byte, error) { return DecryptASN1(priv, uid, ciphertext) } @@ -401,10 +433,13 @@ type keyExchange struct { ke *sm9.KeyExchange } +// NewKeyExchange initializes a new key exchange process using the provided user IDs and key length. +// It returns a pointer to a keyExchange struct which contains the key exchange instance. func (priv *EncryptPrivateKey) NewKeyExchange(uid, peerUID []byte, keyLen int, genSignature bool) *keyExchange { return &keyExchange{ke: priv.internal.NewKeyExchange(uid, peerUID, keyLen, genSignature)} } +// Destroy securely wipes the key exchange data from memory. func (ke *keyExchange) Destroy() { ke.ke.Destroy() } diff --git a/sm9/sm9_key.go b/sm9/sm9_key.go index b3cc735..a5b6de5 100644 --- a/sm9/sm9_key.go +++ b/sm9/sm9_key.go @@ -212,6 +212,12 @@ func (priv *SignPrivateKey) Equal(x crypto.PrivateKey) bool { return subtle.ConstantTimeCompare(priv.privateKey, xx.privateKey) == 1 } +// Public returns the public key corresponding to the private key. +// Just to satisfy [crypto.Signer] interface. +func (priv *SignPrivateKey) Public() crypto.PublicKey { + return nil +} + func (priv *SignPrivateKey) Bytes() []byte { var buf [65]byte return append(buf[:0], priv.privateKey...) @@ -517,6 +523,12 @@ func (priv *EncryptPrivateKey) Equal(x crypto.PrivateKey) bool { return subtle.ConstantTimeCompare(priv.privateKey, xx.privateKey) == 1 } +// Public returns the public key corresponding to the private key. +// Just to satisfy [crypto.Decrypter] interface. +func (priv *EncryptPrivateKey) Public() crypto.PublicKey { + return nil +} + // Bytes returns the byte representation of the EncryptPrivateKey. // It delegates the call to the Bytes method of the underlying privateKey. func (priv *EncryptPrivateKey) Bytes() []byte { diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index 28d60d7..efa1c22 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -120,7 +120,11 @@ func TestEncryptDecrypt(t *testing.T) { t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) } - got, err = userKey.Decrypt(uid, cipher, opts) + opts1, err := sm9.NewDecrypterOptsWithUID(opts, uid) + if err != nil { + t.Fatal(err) + } + got, err = userKey.Decrypt(rand.Reader, cipher, opts1) if err != nil { t.Fatalf("encType %v, first byte %x, %v", opts.GetEncryptType(), cipher[0], err) } @@ -128,6 +132,12 @@ func TestEncryptDecrypt(t *testing.T) { if string(got) != string(plaintext) { t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) } + + opts1.EncrypterOpts = nil + _, err = userKey.Decrypt(rand.Reader, cipher, opts1) + if err == nil || err.Error() != "sm9: invalid ciphertext asn.1 data" { + t.Fatalf("sm9: invalid ciphertext asn.1 data") + } } } @@ -187,6 +197,26 @@ func TestEncryptDecryptASN1(t *testing.T) { if string(got) != string(plaintext) { t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) } + + got, err = userKey.Decrypt(rand.Reader, cipher, uid) + if err != nil { + t.Fatal(err) + } + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } + + opts, err := sm9.NewDecrypterOptsWithUID(nil, uid) + if err != nil { + t.Fatal(err) + } + got, err = userKey.Decrypt(rand.Reader, cipher, opts) + if err != nil { + t.Fatal(err) + } + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } } }