diff --git a/pkcs/cipher.go b/pkcs/cipher.go index 4c049f7..1227295 100644 --- a/pkcs/cipher.go +++ b/pkcs/cipher.go @@ -3,24 +3,27 @@ package pkcs import ( "crypto/cipher" - "crypto/rand" "crypto/x509/pkix" "encoding/asn1" "errors" "fmt" + "io" smcipher "github.com/emmansun/gmsm/cipher" "github.com/emmansun/gmsm/padding" ) -// Cipher represents a cipher for encrypting the key material. +// Cipher represents a cipher for encrypting the key material +// which is used in PBES2. type Cipher interface { // KeySize returns the key size of the cipher, in bytes. KeySize() int - // Encrypt encrypts the key material. - Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) - // Decrypt decrypts the key material. - Decrypt(key []byte, parameters *asn1.RawValue, encryptedKey []byte) ([]byte, error) + // Encrypt encrypts the key material. The returned AlgorithmIdentifier is + // the algorithm identifier used for encryption including parameters. + Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) + // Decrypt decrypts the key material. The parameters are the parameters from the + // DER-encoded AlgorithmIdentifier's. + Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error) // OID returns the OID of the cipher specified. OID() asn1.ObjectIdentifier } @@ -33,6 +36,7 @@ func RegisterCipher(oid asn1.ObjectIdentifier, cipher func() Cipher) { ciphers[oid.String()] = cipher } +// GetCipher returns an instance of the cipher specified by the given algorithm identifier. func GetCipher(alg pkix.AlgorithmIdentifier) (Cipher, error) { oid := alg.Algorithm.String() if oid == oidSM4.String() { @@ -67,7 +71,7 @@ type ecbBlockCipher struct { baseBlockCipher } -func (ecb *ecbBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { +func (ecb *ecbBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { block, err := ecb.newBlock(key) if err != nil { return nil, nil, err @@ -106,15 +110,17 @@ type cbcBlockCipher struct { ivSize int } -func (c *cbcBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { +func (c *cbcBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { block, err := c.newBlock(key) if err != nil { return nil, nil, err } - iv, err := genRandom(c.ivSize) - if err != nil { + + iv := make([]byte, c.ivSize) + if _, err := rand.Read(iv); err != nil { return nil, nil, err } + ciphertext, err := cbcEncrypt(block, iv, plaintext) if err != nil { return nil, nil, err @@ -133,7 +139,7 @@ func (c *cbcBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifi return &encryptionScheme, ciphertext, nil } -func (c *cbcBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, encryptedKey []byte) ([]byte, error) { +func (c *cbcBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error) { block, err := c.newBlock(key) if err != nil { return nil, err @@ -144,7 +150,7 @@ func (c *cbcBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, encrypte return nil, errors.New("pkcs: invalid cipher parameters") } - return cbcDecrypt(block, iv, encryptedKey) + return cbcDecrypt(block, iv, ciphertext) } func cbcEncrypt(block cipher.Block, iv, plaintext []byte) ([]byte, error) { @@ -170,21 +176,23 @@ type gcmBlockCipher struct { } // https://datatracker.ietf.org/doc/rfc5084/ -// GCMParameters ::= SEQUENCE { -// aes-nonce OCTET STRING, -- recommended size is 12 octets -// aes-ICVlen AES-GCM-ICVlen DEFAULT 12 } +// +// GCMParameters ::= SEQUENCE { +// aes-nonce OCTET STRING, -- recommended size is 12 octets +// aes-ICVlen AES-GCM-ICVlen DEFAULT 12 } type gcmParameters struct { Nonce []byte ICVLen int `asn1:"default:12,optional"` } -func (c *gcmBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { +func (c *gcmBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { block, err := c.newBlock(key) if err != nil { return nil, nil, err } - nonce, err := genRandom(c.nonceSize) - if err != nil { + + nonce := make([]byte, c.nonceSize) + if _, err := rand.Read(nonce); err != nil { return nil, nil, err } @@ -210,7 +218,7 @@ func (c *gcmBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifi return &encryptionAlgorithm, ciphertext, nil } -func (c *gcmBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, encryptedKey []byte) ([]byte, error) { +func (c *gcmBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error) { block, err := c.newBlock(key) if err != nil { return nil, err @@ -228,11 +236,5 @@ func (c *gcmBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, encrypte return nil, errors.New("pkcs: we do not support non-standard tag size") } - return aead.Open(nil, params.Nonce, encryptedKey, nil) -} - -func genRandom(len int) ([]byte, error) { - value := make([]byte, len) - _, err := rand.Read(value) - return value, err + return aead.Open(nil, params.Nonce, ciphertext, nil) } diff --git a/pkcs/cipher_test.go b/pkcs/cipher_test.go index c1c9746..de4779d 100644 --- a/pkcs/cipher_test.go +++ b/pkcs/cipher_test.go @@ -2,6 +2,7 @@ package pkcs import ( "bytes" + "crypto/rand" "crypto/x509/pkix" "encoding/asn1" "testing" @@ -36,7 +37,7 @@ func TestGetCipher(t *testing.T) { func TestInvalidKeyLen(t *testing.T) { plaintext := []byte("Hello World") invalidKey := []byte("123456") - _, _, err := SM4ECB.Encrypt(invalidKey, plaintext) + _, _, err := SM4ECB.Encrypt(rand.Reader, invalidKey, plaintext) if err == nil { t.Errorf("should be error") } @@ -44,7 +45,7 @@ func TestInvalidKeyLen(t *testing.T) { if err == nil { t.Errorf("should be error") } - _, _, err = SM4CBC.Encrypt(invalidKey, plaintext) + _, _, err = SM4CBC.Encrypt(rand.Reader, invalidKey, plaintext) if err == nil { t.Errorf("should be error") } @@ -52,7 +53,7 @@ func TestInvalidKeyLen(t *testing.T) { if err == nil { t.Errorf("should be error") } - _, _, err = SM4GCM.Encrypt(invalidKey, plaintext) + _, _, err = SM4GCM.Encrypt(rand.Reader, invalidKey, plaintext) if err == nil { t.Errorf("should be error") } diff --git a/pkcs8/kdf_pbkdf2.go b/pkcs/kdf_pbkdf2.go similarity index 96% rename from pkcs8/kdf_pbkdf2.go rename to pkcs/kdf_pbkdf2.go index d579036..047aa50 100644 --- a/pkcs8/kdf_pbkdf2.go +++ b/pkcs/kdf_pbkdf2.go @@ -1,141 +1,140 @@ -package pkcs8 - -// -// Reference https://datatracker.ietf.org/doc/html/rfc8018#section-5.2 -// - -import ( - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "crypto/x509/pkix" - "encoding/asn1" - "errors" - "hash" - - "github.com/emmansun/gmsm/sm3" - "golang.org/x/crypto/pbkdf2" -) - -// http://gmssl.org/docs/oid.html -var ( - oidPKCS5PBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} - oidHMACWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 7} - oidHMACWithSHA224 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 8} - oidHMACWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9} - oidHMACWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 10} - oidHMACWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 11} - oidHMACWithSHA512_224 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 12} - oidHMACWithSHA512_256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 13} - oidHMACWithSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 401, 2} -) - -func init() { - RegisterKDF(oidPKCS5PBKDF2, func() KDFParameters { - return new(pbkdf2Params) - }) -} - -func newHashFromPRF(ai pkix.AlgorithmIdentifier) (func() hash.Hash, error) { - switch { - case len(ai.Algorithm) == 0 || ai.Algorithm.Equal(oidHMACWithSHA1): - return sha1.New, nil - case ai.Algorithm.Equal(oidHMACWithSHA224): - return sha256.New224, nil - case ai.Algorithm.Equal(oidHMACWithSHA256): - return sha256.New, nil - case ai.Algorithm.Equal(oidHMACWithSHA384): - return sha512.New384, nil - case ai.Algorithm.Equal(oidHMACWithSHA512): - return sha512.New, nil - case ai.Algorithm.Equal(oidHMACWithSHA512_224): - return sha512.New512_224, nil - case ai.Algorithm.Equal(oidHMACWithSHA512_256): - return sha512.New512_256, nil - case ai.Algorithm.Equal(oidHMACWithSM3): - return sm3.New, nil - default: - return nil, errors.New("pkcs8: unsupported hash function") - } -} - -func newPRFParamFromHash(h Hash) (pkix.AlgorithmIdentifier, error) { - switch h { - case SHA1: - return pkix.AlgorithmIdentifier{ - Algorithm: oidHMACWithSHA1, - Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil - case SHA224: - return pkix.AlgorithmIdentifier{ - Algorithm: oidHMACWithSHA224, - Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil - case SHA256: - return pkix.AlgorithmIdentifier{ - Algorithm: oidHMACWithSHA256, - Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil - case SHA384: - return pkix.AlgorithmIdentifier{ - Algorithm: oidHMACWithSHA384, - Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil - case SHA512: - return pkix.AlgorithmIdentifier{ - Algorithm: oidHMACWithSHA512, - Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil - case SHA512_224: - return pkix.AlgorithmIdentifier{ - Algorithm: oidHMACWithSHA512_224, - Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil - case SHA512_256: - return pkix.AlgorithmIdentifier{ - Algorithm: oidHMACWithSHA512_256, - Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil - case SM3: - return pkix.AlgorithmIdentifier{ - Algorithm: oidHMACWithSM3, - Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil - - } - return pkix.AlgorithmIdentifier{}, errors.New("pkcs8: unsupported hash function") -} - -type pbkdf2Params struct { - Salt []byte - IterationCount int - KeyLen int `asn1:"optional"` - PRF pkix.AlgorithmIdentifier `asn1:"optional"` -} - -func (p pbkdf2Params) DeriveKey(password []byte, size int) (key []byte, err error) { - h, err := newHashFromPRF(p.PRF) - if err != nil { - return nil, err - } - return pbkdf2.Key(password, p.Salt, p.IterationCount, size, h), nil -} - -// PBKDF2Opts contains options for the PBKDF2 key derivation function. -type PBKDF2Opts struct { - SaltSize int - IterationCount int - HMACHash Hash -} - -func (p PBKDF2Opts) DeriveKey(password, salt []byte, size int) ( - key []byte, params KDFParameters, err error) { - - key = pbkdf2.Key(password, salt, p.IterationCount, size, p.HMACHash.New) - prfParam, err := newPRFParamFromHash(p.HMACHash) - if err != nil { - return nil, nil, err - } - params = pbkdf2Params{salt, p.IterationCount, size, prfParam} - return key, params, nil -} - -func (p PBKDF2Opts) GetSaltSize() int { - return p.SaltSize -} - -func (p PBKDF2Opts) OID() asn1.ObjectIdentifier { - return oidPKCS5PBKDF2 -} +package pkcs + +// +// Reference https://datatracker.ietf.org/doc/html/rfc8018#section-5.2 +// + +import ( + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "hash" + + "github.com/emmansun/gmsm/sm3" + "golang.org/x/crypto/pbkdf2" +) + +var ( + oidPKCS5PBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} + oidHMACWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 7} + oidHMACWithSHA224 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 8} + oidHMACWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9} + oidHMACWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 10} + oidHMACWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 11} + oidHMACWithSHA512_224 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 12} + oidHMACWithSHA512_256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 13} + oidHMACWithSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 401, 2} +) + +func init() { + RegisterKDF(oidPKCS5PBKDF2, func() KDFParameters { + return new(pbkdf2Params) + }) +} + +func newHashFromPRF(ai pkix.AlgorithmIdentifier) (func() hash.Hash, error) { + switch { + case len(ai.Algorithm) == 0 || ai.Algorithm.Equal(oidHMACWithSHA1): + return sha1.New, nil + case ai.Algorithm.Equal(oidHMACWithSHA224): + return sha256.New224, nil + case ai.Algorithm.Equal(oidHMACWithSHA256): + return sha256.New, nil + case ai.Algorithm.Equal(oidHMACWithSHA384): + return sha512.New384, nil + case ai.Algorithm.Equal(oidHMACWithSHA512): + return sha512.New, nil + case ai.Algorithm.Equal(oidHMACWithSHA512_224): + return sha512.New512_224, nil + case ai.Algorithm.Equal(oidHMACWithSHA512_256): + return sha512.New512_256, nil + case ai.Algorithm.Equal(oidHMACWithSM3): + return sm3.New, nil + default: + return nil, errors.New("pkcs8: unsupported hash function") + } +} + +func newPRFParamFromHash(h Hash) (pkix.AlgorithmIdentifier, error) { + switch h { + case SHA1: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA1, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + case SHA224: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA224, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + case SHA256: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA256, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + case SHA384: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA384, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + case SHA512: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA512, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + case SHA512_224: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA512_224, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + case SHA512_256: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA512_256, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + case SM3: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSM3, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + + } + return pkix.AlgorithmIdentifier{}, errors.New("pkcs8: unsupported hash function") +} + +type pbkdf2Params struct { + Salt []byte + IterationCount int + KeyLen int `asn1:"optional"` + PRF pkix.AlgorithmIdentifier `asn1:"optional"` +} + +func (p pbkdf2Params) DeriveKey(password []byte, size int) (key []byte, err error) { + h, err := newHashFromPRF(p.PRF) + if err != nil { + return nil, err + } + return pbkdf2.Key(password, p.Salt, p.IterationCount, size, h), nil +} + +// PBKDF2Opts contains options for the PBKDF2 key derivation function. +type PBKDF2Opts struct { + SaltSize int + IterationCount int + HMACHash Hash +} + +func (p PBKDF2Opts) DeriveKey(password, salt []byte, size int) ( + key []byte, params KDFParameters, err error) { + + key = pbkdf2.Key(password, salt, p.IterationCount, size, p.HMACHash.New) + prfParam, err := newPRFParamFromHash(p.HMACHash) + if err != nil { + return nil, nil, err + } + params = pbkdf2Params{salt, p.IterationCount, size, prfParam} + return key, params, nil +} + +func (p PBKDF2Opts) GetSaltSize() int { + return p.SaltSize +} + +func (p PBKDF2Opts) OID() asn1.ObjectIdentifier { + return oidPKCS5PBKDF2 +} diff --git a/pkcs8/kdf_scrypt.go b/pkcs/kdf_scrypt.go similarity index 94% rename from pkcs8/kdf_scrypt.go rename to pkcs/kdf_scrypt.go index 102c34f..822ed3a 100644 --- a/pkcs8/kdf_scrypt.go +++ b/pkcs/kdf_scrypt.go @@ -1,66 +1,66 @@ -package pkcs8 - -// -// Reference https://datatracker.ietf.org/doc/html/rfc7914 -// - -import ( - "encoding/asn1" - - "golang.org/x/crypto/scrypt" -) - -var ( - oidScrypt = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11591, 4, 11} -) - -func init() { - RegisterKDF(oidScrypt, func() KDFParameters { - return new(scryptParams) - }) -} - -type scryptParams struct { - Salt []byte - CostParameter int - BlockSize int - ParallelizationParameter int -} - -func (p scryptParams) DeriveKey(password []byte, size int) (key []byte, err error) { - return scrypt.Key(password, p.Salt, p.CostParameter, p.BlockSize, - p.ParallelizationParameter, size) -} - -// ScryptOpts contains options for the scrypt key derivation function. -type ScryptOpts struct { - SaltSize int - CostParameter int - BlockSize int - ParallelizationParameter int -} - -func (p ScryptOpts) DeriveKey(password, salt []byte, size int) ( - key []byte, params KDFParameters, err error) { - - key, err = scrypt.Key(password, salt, p.CostParameter, p.BlockSize, - p.ParallelizationParameter, size) - if err != nil { - return nil, nil, err - } - params = scryptParams{ - BlockSize: p.BlockSize, - CostParameter: p.CostParameter, - ParallelizationParameter: p.ParallelizationParameter, - Salt: salt, - } - return key, params, nil -} - -func (p ScryptOpts) GetSaltSize() int { - return p.SaltSize -} - -func (p ScryptOpts) OID() asn1.ObjectIdentifier { - return oidScrypt -} +package pkcs + +// +// Reference https://datatracker.ietf.org/doc/html/rfc7914 +// + +import ( + "encoding/asn1" + + "golang.org/x/crypto/scrypt" +) + +var ( + oidScrypt = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11591, 4, 11} +) + +func init() { + RegisterKDF(oidScrypt, func() KDFParameters { + return new(scryptParams) + }) +} + +type scryptParams struct { + Salt []byte + CostParameter int + BlockSize int + ParallelizationParameter int +} + +func (p scryptParams) DeriveKey(password []byte, size int) (key []byte, err error) { + return scrypt.Key(password, p.Salt, p.CostParameter, p.BlockSize, + p.ParallelizationParameter, size) +} + +// ScryptOpts contains options for the scrypt key derivation function. +type ScryptOpts struct { + SaltSize int + CostParameter int + BlockSize int + ParallelizationParameter int +} + +func (p ScryptOpts) DeriveKey(password, salt []byte, size int) ( + key []byte, params KDFParameters, err error) { + + key, err = scrypt.Key(password, salt, p.CostParameter, p.BlockSize, + p.ParallelizationParameter, size) + if err != nil { + return nil, nil, err + } + params = scryptParams{ + BlockSize: p.BlockSize, + CostParameter: p.CostParameter, + ParallelizationParameter: p.ParallelizationParameter, + Salt: salt, + } + return key, params, nil +} + +func (p ScryptOpts) GetSaltSize() int { + return p.SaltSize +} + +func (p ScryptOpts) OID() asn1.ObjectIdentifier { + return oidScrypt +} diff --git a/pkcs/pbes2.go b/pkcs/pbes2.go new file mode 100644 index 0000000..22b087e --- /dev/null +++ b/pkcs/pbes2.go @@ -0,0 +1,200 @@ +package pkcs + +import ( + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "hash" + "io" + "strconv" + + "github.com/emmansun/gmsm/sm3" +) + +var ( + oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} +) + +// Hash identifies a cryptographic hash function that is implemented in another +// package. +type Hash uint + +const ( + SHA1 Hash = 1 + iota + SHA224 + SHA256 + SHA384 + SHA512 + SHA512_224 + SHA512_256 + SM3 +) + +// New returns a new hash.Hash calculating the given hash function. New panics +// if the hash function is not linked into the binary. +func (h Hash) New() hash.Hash { + switch h { + case SM3: + return sm3.New() + case SHA1: + return sha1.New() + case SHA224: + return sha256.New224() + case SHA256: + return sha256.New() + case SHA384: + return sha512.New384() + case SHA512: + return sha512.New() + case SHA512_224: + return sha512.New512_224() + case SHA512_256: + return sha512.New512_256() + + } + panic("pkcs5: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable") +} + +// PBKDF2Opts contains algorithm identifiers and related parameters for PBKDF2 key derivation function. +type PBES2Params struct { + KeyDerivationFunc pkix.AlgorithmIdentifier + EncryptionScheme pkix.AlgorithmIdentifier +} + +// PBES2Opts contains options for encrypting a key using PBES2. +type PBES2Opts struct { + Cipher + KDFOpts +} + +// DefaultOpts are the default options for encrypting a key if none are given. +// The defaults can be changed by the library user. +var DefaultOpts = &PBES2Opts{ + Cipher: AES256CBC, + KDFOpts: PBKDF2Opts{ + SaltSize: 16, + IterationCount: 2048, + HMACHash: SHA256, + }, +} + +// KDFOpts contains options for a key derivation function. +// An implementation of this interface must be specified when encrypting a PKCS#8 key. +type KDFOpts interface { + // DeriveKey derives a key of size bytes from the given password and salt. + // It returns the key and the ASN.1-encodable parameters used. + DeriveKey(password, salt []byte, size int) (key []byte, params KDFParameters, err error) + // GetSaltSize returns the salt size specified. + GetSaltSize() int + // OID returns the OID of the KDF specified. + OID() asn1.ObjectIdentifier +} + +// KDFParameters contains parameters (salt, etc.) for a key deriviation function. +// It must be a ASN.1-decodable structure. +// An implementation of this interface is created when decoding an encrypted PKCS#8 key. +type KDFParameters interface { + // DeriveKey derives a key of size bytes from the given password. + // It uses the salt from the decoded parameters. + DeriveKey(password []byte, size int) (key []byte, err error) +} + +var kdfs = make(map[string]func() KDFParameters) + +// RegisterKDF registers a function that returns a new instance of the given KDF +// parameters. This allows the library to support client-provided KDFs. +func RegisterKDF(oid asn1.ObjectIdentifier, params func() KDFParameters) { + kdfs[oid.String()] = params +} + +func (pbes2Params *PBES2Params) parseKeyDerivationFunc() (KDFParameters, error) { + oid := pbes2Params.KeyDerivationFunc.Algorithm.String() + newParams, ok := kdfs[oid] + if !ok { + return nil, fmt.Errorf("pkcs5: unsupported KDF (OID: %s)", oid) + } + params := newParams() + _, err := asn1.Unmarshal(pbes2Params.KeyDerivationFunc.Parameters.FullBytes, params) + if err != nil { + return nil, errors.New("pkcs5: invalid KDF parameters") + } + return params, nil +} + +// Decrypt decrypts the given ciphertext using the given password and the options specified. +func (pbes2Params *PBES2Params) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, error) { + cipher, err := GetCipher(pbes2Params.EncryptionScheme) + if err != nil { + return nil, nil, err + } + + kdfParams, err := pbes2Params.parseKeyDerivationFunc() + if err != nil { + return nil, nil, err + } + + keySize := cipher.KeySize() + symkey, err := kdfParams.DeriveKey(password, keySize) + if err != nil { + return nil, nil, err + } + + plaintext, err := cipher.Decrypt(symkey, &pbes2Params.EncryptionScheme.Parameters, ciphertext) + if err != nil { + return nil, nil, err + } + return plaintext, kdfParams, nil +} + +// Encrypt encrypts the given plaintext using the given password and the options specified. +func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { + // Generate a random salt + salt := make([]byte, opts.KDFOpts.GetSaltSize()) + if _, err := rand.Read(salt); err != nil { + return nil, nil, err + } + + // Derive the key + encAlg := opts.Cipher + key, kdfParams, err := opts.KDFOpts.DeriveKey(password, salt, encAlg.KeySize()) + if err != nil { + return nil, nil, err + } + + // Encrypt the plaintext + encryptionScheme, ciphertext, err := encAlg.Encrypt(rand, key, plaintext) + if err != nil { + return nil, nil, err + } + + marshalledParams, err := asn1.Marshal(kdfParams) + if err != nil { + return nil, nil, err + } + keyDerivationFunc := pkix.AlgorithmIdentifier{ + Algorithm: opts.KDFOpts.OID(), + Parameters: asn1.RawValue{FullBytes: marshalledParams}, + } + + encryptionAlgorithmParams := PBES2Params{ + EncryptionScheme: *encryptionScheme, + KeyDerivationFunc: keyDerivationFunc, + } + marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams) + if err != nil { + return nil, nil, err + } + encryptionAlgorithm := pkix.AlgorithmIdentifier{ + Algorithm: oidPBES2, + Parameters: asn1.RawValue{FullBytes: marshalledEncryptionAlgorithmParams}, + } + return &encryptionAlgorithm, ciphertext, nil +} + +func IsPBES2(algorithm pkix.AlgorithmIdentifier) bool { + return oidPBES2.Equal(algorithm.Algorithm) +} diff --git a/pkcs7/encrypt.go b/pkcs7/encrypt.go index c3043f8..f8766f2 100644 --- a/pkcs7/encrypt.go +++ b/pkcs7/encrypt.go @@ -1,6 +1,7 @@ package pkcs7 import ( + "crypto/rand" "encoding/asn1" "errors" @@ -36,7 +37,7 @@ func encryptUsingPSK(cipher pkcs.Cipher, content []byte, key []byte, contentType return nil, ErrPSKNotProvided } - id, ciphertext, err := cipher.Encrypt(key, content) + id, ciphertext, err := cipher.Encrypt(rand.Reader, key, content) if err != nil { return nil, err } diff --git a/pkcs7/envelope.go b/pkcs7/envelope.go index c92d6fa..d224390 100644 --- a/pkcs7/envelope.go +++ b/pkcs7/envelope.go @@ -121,7 +121,7 @@ func NewEnvelopedData(cipher pkcs.Cipher, content []byte) (*EnvelopedData, error return nil, err } - id, ciphertext, err := cipher.Encrypt(key, content) + id, ciphertext, err := cipher.Encrypt(rand.Reader, key, content) if err != nil { return nil, err } @@ -148,7 +148,7 @@ func NewSM2EnvelopedData(cipher pkcs.Cipher, content []byte) (*EnvelopedData, er return nil, err } - id, ciphertext, err := cipher.Encrypt(key, content) + id, ciphertext, err := cipher.Encrypt(rand.Reader, key, content) if err != nil { return nil, err } diff --git a/pkcs7/sign_enveloped.go b/pkcs7/sign_enveloped.go index c3bbfde..01d15ae 100644 --- a/pkcs7/sign_enveloped.go +++ b/pkcs7/sign_enveloped.go @@ -147,7 +147,7 @@ func NewSignedAndEnvelopedData(data []byte, cipher pkcs.Cipher) (*SignedAndEnvel return nil, err } - id, ciphertext, err := cipher.Encrypt(key, data) + id, ciphertext, err := cipher.Encrypt(rand.Reader, key, data) if err != nil { return nil, err } diff --git a/pkcs8/pkcs8.go b/pkcs8/pkcs8.go index 5d1a2b1..089eda8 100644 --- a/pkcs8/pkcs8.go +++ b/pkcs8/pkcs8.go @@ -5,103 +5,29 @@ import ( "crypto/ecdsa" "crypto/rand" "crypto/rsa" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" "crypto/x509/pkix" "encoding/asn1" "encoding/pem" "errors" - "fmt" - "hash" - "strconv" "github.com/emmansun/gmsm/pkcs" "github.com/emmansun/gmsm/sm2" - "github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm9" "github.com/emmansun/gmsm/smx509" ) -// Hash identifies a cryptographic hash function that is implemented in another -// package. -type Hash uint +type Opts = pkcs.PBES2Opts +type PBKDF2Opts = pkcs.PBKDF2Opts +type ScryptOpts = pkcs.ScryptOpts -const ( - SHA1 Hash = 1 + iota - SHA224 - SHA256 - SHA384 - SHA512 - SHA512_224 - SHA512_256 - SM3 -) - -// New returns a new hash.Hash calculating the given hash function. New panics -// if the hash function is not linked into the binary. -func (h Hash) New() hash.Hash { - switch h { - case SM3: - return sm3.New() - case SHA1: - return sha1.New() - case SHA224: - return sha256.New224() - case SHA256: - return sha256.New() - case SHA384: - return sha512.New384() - case SHA512: - return sha512.New() - case SHA512_224: - return sha512.New512_224() - case SHA512_256: - return sha512.New512_256() - - } - panic("pkcs8: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable") -} - -// DefaultOpts are the default options for encrypting a key if none are given. -// The defaults can be changed by the library user. -var DefaultOpts = &Opts{ - Cipher: pkcs.AES256CBC, - KDFOpts: PBKDF2Opts{ - SaltSize: 8, - IterationCount: 10000, - HMACHash: SHA256, - }, -} - -// KDFOpts contains options for a key derivation function. -// An implementation of this interface must be specified when encrypting a PKCS#8 key. -type KDFOpts interface { - // DeriveKey derives a key of size bytes from the given password and salt. - // It returns the key and the ASN.1-encodable parameters used. - DeriveKey(password, salt []byte, size int) (key []byte, params KDFParameters, err error) - // GetSaltSize returns the salt size specified. - GetSaltSize() int - // OID returns the OID of the KDF specified. - OID() asn1.ObjectIdentifier -} - -// KDFParameters contains parameters (salt, etc.) for a key deriviation function. -// It must be a ASN.1-decodable structure. -// An implementation of this interface is created when decoding an encrypted PKCS#8 key. -type KDFParameters interface { - // DeriveKey derives a key of size bytes from the given password. - // It uses the salt from the decoded parameters. - DeriveKey(password []byte, size int) (key []byte, err error) -} - -var kdfs = make(map[string]func() KDFParameters) - -// RegisterKDF registers a function that returns a new instance of the given KDF -// parameters. This allows the library to support client-provided KDFs. -func RegisterKDF(oid asn1.ObjectIdentifier, params func() KDFParameters) { - kdfs[oid.String()] = params -} +var SM3 = pkcs.SM3 +var SHA1 = pkcs.SHA1 +var SHA224 = pkcs.SHA224 +var SHA256 = pkcs.SHA256 +var SHA384 = pkcs.SHA384 +var SHA512 = pkcs.SHA512 +var SHA512_224 = pkcs.SHA512_224 +var SHA512_256 = pkcs.SHA512_256 // for encrypted private-key information type encryptedPrivateKeyInfo struct { @@ -109,40 +35,10 @@ type encryptedPrivateKeyInfo struct { EncryptedData []byte } -// Opts contains options for encrypting a PKCS#8 key. -type Opts struct { - Cipher pkcs.Cipher - KDFOpts KDFOpts -} - -// Unecrypted PKCS8 -var ( - oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} -) - -type pbes2Params struct { - KeyDerivationFunc pkix.AlgorithmIdentifier - EncryptionScheme pkix.AlgorithmIdentifier -} - -func parseKeyDerivationFunc(keyDerivationFunc pkix.AlgorithmIdentifier) (KDFParameters, error) { - oid := keyDerivationFunc.Algorithm.String() - newParams, ok := kdfs[oid] - if !ok { - return nil, fmt.Errorf("pkcs8: unsupported KDF (OID: %s)", oid) - } - params := newParams() - _, err := asn1.Unmarshal(keyDerivationFunc.Parameters.FullBytes, params) - if err != nil { - return nil, errors.New("pkcs8: invalid KDF parameters") - } - return params, nil -} - // ParsePrivateKey parses a DER-encoded PKCS#8 private key. // Password can be nil. // This is equivalent to ParsePKCS8PrivateKey. -func ParsePrivateKey(der []byte, password []byte) (any, KDFParameters, error) { +func ParsePrivateKey(der []byte, password []byte) (any, pkcs.KDFParameters, error) { // No password provided, assume the private key is unencrypted if len(password) == 0 { privateKey, err := smx509.ParsePKCS8PrivateKey(der) @@ -158,33 +54,16 @@ func ParsePrivateKey(der []byte, password []byte) (any, KDFParameters, error) { return nil, nil, errors.New("pkcs8: only PKCS #5 v2.0 supported") } - if !privKey.EncryptionAlgorithm.Algorithm.Equal(oidPBES2) { + if !pkcs.IsPBES2(privKey.EncryptionAlgorithm) { return nil, nil, errors.New("pkcs8: only PBES2 supported") } - var params pbes2Params + var params pkcs.PBES2Params if _, err := asn1.Unmarshal(privKey.EncryptionAlgorithm.Parameters.FullBytes, ¶ms); err != nil { return nil, nil, errors.New("pkcs8: invalid PBES2 parameters") } - cipher, err := pkcs.GetCipher(params.EncryptionScheme) - if err != nil { - return nil, nil, err - } - - kdfParams, err := parseKeyDerivationFunc(params.KeyDerivationFunc) - if err != nil { - return nil, nil, err - } - - keySize := cipher.KeySize() - symkey, err := kdfParams.DeriveKey(password, keySize) - if err != nil { - return nil, nil, err - } - - encryptedKey := privKey.EncryptedData - decryptedKey, err := cipher.Decrypt(symkey, ¶ms.EncryptionScheme.Parameters, encryptedKey) + decryptedKey, kdfParams, err := params.Decrypt(password, privKey.EncryptedData) if err != nil { return nil, nil, err } @@ -204,7 +83,7 @@ func MarshalPrivateKey(priv any, password []byte, opts *Opts) ([]byte, error) { } if opts == nil { - opts = DefaultOpts + opts = pkcs.DefaultOpts } // Convert private key into PKCS8 format @@ -213,47 +92,13 @@ func MarshalPrivateKey(priv any, password []byte, opts *Opts) ([]byte, error) { return nil, err } - encAlg := opts.Cipher - salt := make([]byte, opts.KDFOpts.GetSaltSize()) - _, err = rand.Read(salt) + encryptionAlgorithm, encryptedKey, err := opts.Encrypt(rand.Reader, password, pkey) if err != nil { return nil, err } - key, kdfParams, err := opts.KDFOpts.DeriveKey(password, salt, encAlg.KeySize()) - if err != nil { - return nil, err - } - - encryptionScheme, encryptedKey, err := encAlg.Encrypt(key, pkey) - if err != nil { - return nil, err - } - - marshalledParams, err := asn1.Marshal(kdfParams) - if err != nil { - return nil, err - } - keyDerivationFunc := pkix.AlgorithmIdentifier{ - Algorithm: opts.KDFOpts.OID(), - Parameters: asn1.RawValue{FullBytes: marshalledParams}, - } - - encryptionAlgorithmParams := pbes2Params{ - EncryptionScheme: *encryptionScheme, - KeyDerivationFunc: keyDerivationFunc, - } - marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams) - if err != nil { - return nil, err - } - encryptionAlgorithm := pkix.AlgorithmIdentifier{ - Algorithm: oidPBES2, - Parameters: asn1.RawValue{FullBytes: marshalledEncryptionAlgorithmParams}, - } - encryptedPkey := encryptedPrivateKeyInfo{ - EncryptionAlgorithm: encryptionAlgorithm, + EncryptionAlgorithm: *encryptionAlgorithm, EncryptedData: encryptedKey, }