diff --git a/pkcs/cipher.go b/pkcs/cipher.go index 1227295..16f3641 100644 --- a/pkcs/cipher.go +++ b/pkcs/cipher.go @@ -48,7 +48,7 @@ func GetCipher(alg pkix.AlgorithmIdentifier) (Cipher, error) { } newCipher, ok := ciphers[oid] if !ok { - return nil, fmt.Errorf("pkcs: unsupported cipher (OID: %s)", oid) + return nil, fmt.Errorf("pbes: unsupported cipher (OID: %s)", oid) } return newCipher(), nil } @@ -147,7 +147,7 @@ func (c *cbcBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, cipherte var iv []byte if _, err := asn1.Unmarshal(parameters.FullBytes, &iv); err != nil { - return nil, errors.New("pkcs: invalid cipher parameters") + return nil, errors.New("pbes: invalid cipher parameters") } return cbcDecrypt(block, iv, ciphertext) @@ -233,7 +233,7 @@ func (c *gcmBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, cipherte return nil, err } if params.ICVLen != aead.Overhead() { - return nil, errors.New("pkcs: we do not support non-standard tag size") + return nil, errors.New("pbes: we do not support non-standard tag size") } return aead.Open(nil, params.Nonce, ciphertext, nil) diff --git a/pkcs/kdf_pbkdf2.go b/pkcs/kdf_pbkdf2.go index 047aa50..ed1b89f 100644 --- a/pkcs/kdf_pbkdf2.go +++ b/pkcs/kdf_pbkdf2.go @@ -19,6 +19,7 @@ import ( var ( oidPKCS5PBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} + oidSMPBKDF = asn1.ObjectIdentifier{1, 2, 156, 10197, 6, 4, 1, 5, 1} oidHMACWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 7} oidHMACWithSHA224 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 8} oidHMACWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9} @@ -33,11 +34,21 @@ func init() { RegisterKDF(oidPKCS5PBKDF2, func() KDFParameters { return new(pbkdf2Params) }) + RegisterKDF(oidSMPBKDF, func() KDFParameters { + return new(pbkdf2Params) + }) } -func newHashFromPRF(ai pkix.AlgorithmIdentifier) (func() hash.Hash, error) { +func newHashFromPRF(oidKDF asn1.ObjectIdentifier, ai pkix.AlgorithmIdentifier) (func() hash.Hash, error) { switch { - case len(ai.Algorithm) == 0 || ai.Algorithm.Equal(oidHMACWithSHA1): + case len(ai.Algorithm) == 0: // handle default case + switch { + case oidKDF.Equal(oidSMPBKDF): + return sm3.New, nil + default: + return sha1.New, nil + } + case ai.Algorithm.Equal(oidHMACWithSHA1): return sha1.New, nil case ai.Algorithm.Equal(oidHMACWithSHA224): return sha256.New224, nil @@ -54,7 +65,7 @@ func newHashFromPRF(ai pkix.AlgorithmIdentifier) (func() hash.Hash, error) { case ai.Algorithm.Equal(oidHMACWithSM3): return sm3.New, nil default: - return nil, errors.New("pkcs8: unsupported hash function") + return nil, errors.New("pbes/pbkdf2: unsupported hash function") } } @@ -94,9 +105,18 @@ func newPRFParamFromHash(h Hash) (pkix.AlgorithmIdentifier, error) { Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil } - return pkix.AlgorithmIdentifier{}, errors.New("pkcs8: unsupported hash function") + return pkix.AlgorithmIdentifier{}, errors.New("pbes/pbkdf2: unsupported hash function") } +// PBKDF2-params ::= SEQUENCE { +// salt CHOICE { +// specified OCTET STRING, +// otherSource AlgorithmIdentifier {{PBKDF2-SaltSources}} +// }, +// iterationCount INTEGER (1..MAX), +// keyLength INTEGER (1..MAX) OPTIONAL, +// prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1 +//} type pbkdf2Params struct { Salt []byte IterationCount int @@ -104,8 +124,8 @@ type pbkdf2Params struct { PRF pkix.AlgorithmIdentifier `asn1:"optional"` } -func (p pbkdf2Params) DeriveKey(password []byte, size int) (key []byte, err error) { - h, err := newHashFromPRF(p.PRF) +func (p pbkdf2Params) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) { + h, err := newHashFromPRF(oidKDF, p.PRF) if err != nil { return nil, err } @@ -117,6 +137,27 @@ type PBKDF2Opts struct { SaltSize int IterationCount int HMACHash Hash + pbkdfOID asn1.ObjectIdentifier +} + +// NewPBKDF2Opts returns a new PBKDF2Opts with the specified parameters. +func NewPBKDF2Opts(hash Hash, saltSize, iterationCount int) PBKDF2Opts { + return PBKDF2Opts{ + SaltSize: saltSize, + IterationCount: iterationCount, + HMACHash: hash, + pbkdfOID: oidPKCS5PBKDF2, + } +} + +// NewSMPBKDF2Opts returns a new PBKDF2Opts (ShangMi PBKDF) with the specified parameters. +func NewSMPBKDF2Opts(saltSize, iterationCount int) PBKDF2Opts { + return PBKDF2Opts{ + SaltSize: saltSize, + IterationCount: iterationCount, + HMACHash: SM3, + pbkdfOID: oidSMPBKDF, + } } func (p PBKDF2Opts) DeriveKey(password, salt []byte, size int) ( @@ -136,5 +177,9 @@ func (p PBKDF2Opts) GetSaltSize() int { } func (p PBKDF2Opts) OID() asn1.ObjectIdentifier { - return oidPKCS5PBKDF2 + // If the OID is not set, use the default OID for PBKDF2 + if p.pbkdfOID == nil { + return oidPKCS5PBKDF2 + } + return p.pbkdfOID } diff --git a/pkcs/kdf_scrypt.go b/pkcs/kdf_scrypt.go index 822ed3a..e6af136 100644 --- a/pkcs/kdf_scrypt.go +++ b/pkcs/kdf_scrypt.go @@ -27,7 +27,7 @@ type scryptParams struct { ParallelizationParameter int } -func (p scryptParams) DeriveKey(password []byte, size int) (key []byte, err error) { +func (p scryptParams) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) { return scrypt.Key(password, p.Salt, p.CostParameter, p.BlockSize, p.ParallelizationParameter, size) } @@ -40,6 +40,16 @@ type ScryptOpts struct { ParallelizationParameter int } +// NewScryptOpts returns a new ScryptOpts with the specified parameters. +func NewScryptOpts(saltSize, costParameter, blockSize, parallelizationParameter int) ScryptOpts { + return ScryptOpts{ + SaltSize: saltSize, + CostParameter: costParameter, + BlockSize: blockSize, + ParallelizationParameter: parallelizationParameter, + } +} + func (p ScryptOpts) DeriveKey(password, salt []byte, size int) ( key []byte, params KDFParameters, err error) { diff --git a/pkcs/pkcs5_pbes1.go b/pkcs/pkcs5_pbes1.go index 90db1c8..82da347 100644 --- a/pkcs/pkcs5_pbes1.go +++ b/pkcs/pkcs5_pbes1.go @@ -48,7 +48,7 @@ func (pbes1 *PBES1) Key(password []byte) ([]byte, error) { case pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndDESCBC) || pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndRC2CBC): hash = sha1.New() default: - return nil, errors.New("pkcs5: unsupported pbes1 cipher") + return nil, errors.New("pbes: unsupported pbes1 cipher") } hash.Write(password) hash.Write(param.Salt) @@ -77,7 +77,7 @@ func (pbes1 *PBES1) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndRC2CBC): block, err = rc2.NewCipher(key[:8]) default: - return nil, nil, errors.New("pkcs5: unsupported pbes1 cipher") + return nil, nil, errors.New("pbes: unsupported pbes1 cipher") } if err != nil { return nil, nil, err diff --git a/pkcs/pkcs5_pbes2.go b/pkcs/pkcs5_pbes2.go index a2c452b..0c23797 100644 --- a/pkcs/pkcs5_pbes2.go +++ b/pkcs/pkcs5_pbes2.go @@ -16,7 +16,8 @@ import ( ) var ( - oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} + oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} + oidSMPBES = asn1.ObjectIdentifier{1, 2, 156, 10197, 6, 4, 1, 5, 2} ) // Hash identifies a cryptographic hash function that is implemented in another @@ -56,14 +57,18 @@ func (h Hash) New() hash.Hash { return sha512.New512_256() } - panic("pkcs5: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable") + panic("pbes: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable") } var ( - ErrPBEDecryption = errors.New("pkcs: decryption error, please verify the password and try again") + ErrPBEDecryption = errors.New("pbes: decryption error, please verify the password and try again") ) // PBKDF2Opts contains algorithm identifiers and related parameters for PBKDF2 key derivation function. +// PBES2-params ::= SEQUENCE { +// keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}}, +// encryptionScheme AlgorithmIdentifier {{PBES2-Encs}} +// } type PBES2Params struct { KeyDerivationFunc pkix.AlgorithmIdentifier EncryptionScheme pkix.AlgorithmIdentifier @@ -73,6 +78,7 @@ type PBES2Params struct { type PBES2Opts struct { Cipher KDFOpts + pbesOID asn1.ObjectIdentifier } // DefaultOpts are the default options for encrypting a key if none are given. @@ -83,7 +89,32 @@ var DefaultOpts = &PBES2Opts{ SaltSize: 16, IterationCount: 2048, HMACHash: SHA256, + pbkdfOID: oidPKCS5PBKDF2, }, + pbesOID: oidPBES2, +} + +// NewPBES2Encrypter returns a new PBES2Encrypter with the given cipher and KDF options. +func NewPBESEncrypter(cipher Cipher, kdfOpts KDFOpts) PBESEncrypter { + return &PBES2Opts{ + Cipher: cipher, + KDFOpts: kdfOpts, + pbesOID: oidPBES2, + } +} + +// NewSMPBESEncrypterWithKDF returns a new SMPBESEncrypter (ShangMi PBES Encrypter) with the given KDF options. +func NewSMPBESEncrypterWithKDF(kdfOpts KDFOpts) PBESEncrypter { + return &PBES2Opts{ + Cipher: SM4CBC, + KDFOpts: kdfOpts, + pbesOID: oidSMPBES, + } +} + +// NewSMPBESEncrypter returns a new SMPBESEncrypter (ShangMi PBES Encrypter) with the given salt size and iteration count. +func NewSMPBESEncrypter(saltSize, iterationCount int) PBESEncrypter { + return NewSMPBESEncrypterWithKDF(NewSMPBKDF2Opts(saltSize, iterationCount)) } // KDFOpts contains options for a key derivation function. @@ -108,7 +139,7 @@ type PBESEncrypter interface { type KDFParameters interface { // DeriveKey derives a key of size bytes from the given password. // It uses the salt from the decoded parameters. - DeriveKey(password []byte, size int) (key []byte, err error) + DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) } var kdfs = make(map[string]func() KDFParameters) @@ -123,12 +154,12 @@ func (pbes2Params *PBES2Params) parseKeyDerivationFunc() (KDFParameters, error) oid := pbes2Params.KeyDerivationFunc.Algorithm.String() newParams, ok := kdfs[oid] if !ok { - return nil, fmt.Errorf("pkcs5: unsupported KDF (OID: %s)", oid) + return nil, fmt.Errorf("pbes: unsupported KDF (OID: %s)", oid) } params := newParams() _, err := asn1.Unmarshal(pbes2Params.KeyDerivationFunc.Parameters.FullBytes, params) if err != nil { - return nil, errors.New("pkcs5: invalid KDF parameters") + return nil, errors.New("pbes: invalid KDF parameters") } return params, nil } @@ -146,7 +177,7 @@ func (pbes2Params *PBES2Params) Decrypt(password, ciphertext []byte) ([]byte, KD } keySize := cipher.KeySize() - symkey, err := kdfParams.DeriveKey(password, keySize) + symkey, err := kdfParams.DeriveKey(pbes2Params.KeyDerivationFunc.Algorithm, password, keySize) if err != nil { return nil, nil, err } @@ -197,12 +228,22 @@ func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pki return nil, nil, err } encryptionAlgorithm := pkix.AlgorithmIdentifier{ - Algorithm: oidPBES2, + Algorithm: opts.pbesOID, Parameters: asn1.RawValue{FullBytes: marshalledEncryptionAlgorithmParams}, } + + // fallback to default + if len(encryptionAlgorithm.Algorithm) == 0 { + encryptionAlgorithm.Algorithm = oidPBES2 + } + return &encryptionAlgorithm, ciphertext, nil } func IsPBES2(algorithm pkix.AlgorithmIdentifier) bool { return oidPBES2.Equal(algorithm.Algorithm) } + +func IsSMPBES(algorithm pkix.AlgorithmIdentifier) bool { + return oidSMPBES.Equal(algorithm.Algorithm) +} diff --git a/pkcs8/pkcs8.go b/pkcs8/pkcs8.go index cd5e93d..5798450 100644 --- a/pkcs8/pkcs8.go +++ b/pkcs8/pkcs8.go @@ -64,7 +64,7 @@ func ParsePrivateKey(der []byte, password []byte) (any, pkcs.KDFParameters, erro var decryptedKey []byte var err error switch { - case pkcs.IsPBES2(privKey.EncryptionAlgorithm): + case pkcs.IsPBES2(privKey.EncryptionAlgorithm) || pkcs.IsSMPBES(privKey.EncryptionAlgorithm): var params pkcs.PBES2Params if _, err := asn1.Unmarshal(privKey.EncryptionAlgorithm.Parameters.FullBytes, ¶ms); err != nil { return nil, nil, errors.New("pkcs8: invalid PBES2 parameters") diff --git a/pkcs8/pkcs8_test.go b/pkcs8/pkcs8_test.go index 758626c..cb67fb5 100644 --- a/pkcs8/pkcs8_test.go +++ b/pkcs8/pkcs8_test.go @@ -770,7 +770,7 @@ func TestParseLegacyPBES1PrivateKey(t *testing.T) { if err != nil { t.Errorf("ParsePKCS8PrivateKey returned: %s", err) } - + block, _ = pem.Decode([]byte(encryptedPBEWithSha1AndRC2_64)) _, err = pkcs8.ParsePKCS8PrivateKey(block.Bytes, []byte("12345678")) if err != nil { @@ -788,3 +788,20 @@ func TestParseLegacyPBES1PrivateKey(t *testing.T) { t.Errorf("should have failed") } } + +func TestShangMiPBES(t *testing.T) { + block, _ := pem.Decode([]byte(encryptedPBEWithMD5AndDES)) + priv, err := pkcs8.ParsePKCS8PrivateKey(block.Bytes, []byte("12345678")) + if err != nil { + t.Errorf("ParsePKCS8PrivateKey returned: %s", err) + } + + der, err := pkcs8.MarshalPrivateKey(priv, []byte("12345678"), pkcs.NewSMPBESEncrypter(16, 2048)) + if err != nil { + t.Fatalf("MarshalPrivateKey returned: %s", err) + } + _, _, err = pkcs8.ParsePrivateKey(der, []byte("12345678")) + if err != nil { + t.Fatalf("ParsePrivateKey returned: %s", err) + } +}