From dd8b2f61dd02cb943fbc7e3cb5df2cc9ab8c2e48 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Tue, 14 Feb 2023 16:03:05 +0800 Subject: [PATCH] smx509: ParseTypedECPrivateKey, return real privae key type --- sm9/enc_mode.go | 15 +++++++++++---- smx509/name_constraints_test.go | 2 +- smx509/sec1.go | 18 +++++++++++++++++- smx509/sec1_test.go | 28 ++++++++++++++++++++++++++++ 4 files changed, 57 insertions(+), 6 deletions(-) diff --git a/sm9/enc_mode.go b/sm9/enc_mode.go index 1924a6f..0b50c13 100644 --- a/sm9/enc_mode.go +++ b/sm9/enc_mode.go @@ -22,7 +22,7 @@ type EncrypterOpts interface { Decrypt(key, ciphertext []byte) ([]byte, error) } -// XOREncrypterOpts represents XOR encrypt type/mode. +// XOREncrypterOpts represents XOR mode. type XOREncrypterOpts struct{} func (opts *XOREncrypterOpts) GetEncryptType() encryptType { @@ -62,6 +62,7 @@ func (opts *baseBlockEncrypterOpts) GetKeySize(plaintext []byte) int { return opts.cipherKeySize } +// CBCEncrypterOpts represents CBC (Cipher block chaining) mode. type CBCEncrypterOpts struct { baseBlockEncrypterOpts padding padding.Padding @@ -76,6 +77,7 @@ func NewCBCEncrypterOpts(padding padding.Padding, cipherFactory CipherFactory, k return opts } +// Encrypt encrypts the plaintext with the key, includes generated IV at the beginning of the ciphertext. func (opts *CBCEncrypterOpts) Encrypt(rand io.Reader, key, plaintext []byte) ([]byte, error) { block, err := opts.cipherFactory(key) if err != nil { @@ -99,7 +101,7 @@ func (opts *CBCEncrypterOpts) Decrypt(key, ciphertext []byte) ([]byte, error) { return nil, err } blockSize := block.BlockSize() - if len(ciphertext) < blockSize { + if len(ciphertext) <= blockSize { return nil, ErrDecryption } iv := ciphertext[:blockSize] @@ -110,6 +112,7 @@ func (opts *CBCEncrypterOpts) Decrypt(key, ciphertext []byte) ([]byte, error) { return opts.padding.Unpad(plaintext) } +// ECBEncrypterOpts represents ECB (Electronic Code Book) mode. type ECBEncrypterOpts struct { baseBlockEncrypterOpts padding padding.Padding @@ -150,6 +153,7 @@ func (opts *ECBEncrypterOpts) Decrypt(key, ciphertext []byte) ([]byte, error) { return opts.padding.Unpad(plaintext) } +// CFBEncrypterOpts represents CFB (Cipher Feedback) mode. type CFBEncrypterOpts struct { baseBlockEncrypterOpts } @@ -162,6 +166,7 @@ func NewCFBEncrypterOpts(cipherFactory CipherFactory, keySize int) EncrypterOpts return opts } +// Encrypt encrypts the plaintext with the key, includes generated IV at the beginning of the ciphertext. func (opts *CFBEncrypterOpts) Encrypt(rand io.Reader, key, plaintext []byte) ([]byte, error) { block, err := opts.cipherFactory(key) if err != nil { @@ -184,7 +189,7 @@ func (opts *CFBEncrypterOpts) Decrypt(key, ciphertext []byte) ([]byte, error) { return nil, err } blockSize := block.BlockSize() - if len(ciphertext) < blockSize { + if len(ciphertext) <= blockSize { return nil, ErrDecryption } iv := ciphertext[:blockSize] @@ -195,6 +200,7 @@ func (opts *CFBEncrypterOpts) Decrypt(key, ciphertext []byte) ([]byte, error) { return plaintext, nil } +// OFBEncrypterOpts represents OFB (Output Feedback) mode. type OFBEncrypterOpts struct { baseBlockEncrypterOpts } @@ -207,6 +213,7 @@ func NewOFBEncrypterOpts(cipherFactory CipherFactory, keySize int) EncrypterOpts return opts } +// Encrypt encrypts the plaintext with the key, includes generated IV at the beginning of the ciphertext. func (opts *OFBEncrypterOpts) Encrypt(rand io.Reader, key, plaintext []byte) ([]byte, error) { block, err := opts.cipherFactory(key) if err != nil { @@ -229,7 +236,7 @@ func (opts *OFBEncrypterOpts) Decrypt(key, ciphertext []byte) ([]byte, error) { return nil, err } blockSize := block.BlockSize() - if len(ciphertext) < blockSize { + if len(ciphertext) <= blockSize { return nil, ErrDecryption } iv := ciphertext[:blockSize] diff --git a/smx509/name_constraints_test.go b/smx509/name_constraints_test.go index 92a7255..64cd4ec 100644 --- a/smx509/name_constraints_test.go +++ b/smx509/name_constraints_test.go @@ -1716,7 +1716,7 @@ func makeConstraintsLeafCert(leaf leafSpec, key *ecdsa.PrivateKey, parent *Certi parent = template } - derBytes, err := CreateCertificate(rand.Reader, template.asX509(), parent.asX509(), &key.PublicKey, parentKey) + derBytes, err := CreateCertificate(rand.Reader, template, parent, &key.PublicKey, parentKey) if err != nil { return nil, err } diff --git a/smx509/sec1.go b/smx509/sec1.go index 503388d..126a0c5 100644 --- a/smx509/sec1.go +++ b/smx509/sec1.go @@ -34,7 +34,7 @@ func ParseECPrivateKey(der []byte) (*ecdsa.PrivateKey, error) { return parseECPrivateKey(nil, der) } -// ParseSM2PrivateKey parses an SM2 private key +// ParseSM2PrivateKey parses an SM2 private key in SEC 1, ASN.1 DER form. func ParseSM2PrivateKey(der []byte) (*sm2.PrivateKey, error) { key, err := parseECPrivateKey(nil, der) if err != nil { @@ -43,6 +43,22 @@ func ParseSM2PrivateKey(der []byte) (*sm2.PrivateKey, error) { return new(sm2.PrivateKey).FromECPrivateKey(key) } +// ParseTypedECPrivateKey parses an EC private key in SEC 1, ASN.1 DER form. +// +// It returns a *ecdsa.PrivateKey or a *sm2.PrivateKey. +// +// This kind of key is commonly encoded in PEM blocks of type "EC PRIVATE KEY". +func ParseTypedECPrivateKey(der []byte) (interface{}, error) { + key, err := parseECPrivateKey(nil, der) + if err != nil { + return nil, err + } + if key.Curve == sm2.P256() { + return new(sm2.PrivateKey).FromECPrivateKey(key) + } + return key, nil +} + // MarshalECPrivateKey converts an EC private key to SEC 1, ASN.1 DER form. // // This kind of key is commonly encoded in PEM blocks of type "EC PRIVATE KEY". diff --git a/smx509/sec1_test.go b/smx509/sec1_test.go index 9faf4bc..5c60c93 100644 --- a/smx509/sec1_test.go +++ b/smx509/sec1_test.go @@ -2,6 +2,7 @@ package smx509 import ( "bytes" + "crypto/ecdsa" "crypto/rand" "encoding/hex" "fmt" @@ -75,3 +76,30 @@ func TestMarshalSM2PrivateKey(t *testing.T) { } fmt.Printf("%s\n", hex.EncodeToString(res)) } + +func TestParseTypedECPrivateKey(t *testing.T) { + for i, test := range ecKeyTests { + derBytes, _ := hex.DecodeString(test.derHex) + key, err := ParseTypedECPrivateKey(derBytes) + if err != nil { + t.Fatalf("#%d: failed to decode EC private key: %s", i, err) + } + var serialized []byte + switch privKey := key.(type) { + case *ecdsa.PrivateKey: + serialized, err = MarshalECPrivateKey(privKey) + if err != nil { + t.Fatalf("#%d: failed to encode EC private key: %s", i, err) + } + case *sm2.PrivateKey: + serialized, err = MarshalSM2PrivateKey(privKey) + if err != nil { + t.Fatalf("#%d: failed to encode SM2 private key: %s", i, err) + } + } + matches := bytes.Equal(serialized, derBytes) + if matches != test.shouldReserialize { + t.Fatalf("#%d: when serializing key: matches=%t, should match=%t: original %x, reserialized %x", i, matches, test.shouldReserialize, serialized, derBytes) + } + } +}