From c7f3aa3b6ef3ab14fb5d9257edefdd659b1a8e27 Mon Sep 17 00:00:00 2001 From: Emman Date: Thu, 2 Dec 2021 17:33:39 +0800 Subject: [PATCH] support ASN.1 format --- sm2/sm2.go | 193 ++++++++++++++++++++++++++++++++++++++---------- sm2/sm2_test.go | 103 ++++++++++++++++++++++++-- 2 files changed, 253 insertions(+), 43 deletions(-) diff --git a/sm2/sm2.go b/sm2/sm2.go index 86c61b7..2f3eea4 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -62,23 +62,32 @@ const ( C1C2C3 ) +type ciphertextEncoding byte + +const ( + ENCODING_PLAIN ciphertextEncoding = iota + ENCODING_ASN1 +) + // EncrypterOpts encryption options type EncrypterOpts struct { + CiphertextEncoding ciphertextEncoding PointMarshalMode pointMarshalMode CiphertextSplicingOrder ciphertextSplicingOrder } // DecrypterOpts decryption options type DecrypterOpts struct { + CiphertextEncoding ciphertextEncoding CipherTextSplicingOrder ciphertextSplicingOrder } -func NewEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder ciphertextSplicingOrder) *EncrypterOpts { - return &EncrypterOpts{marhsalMode, splicingOrder} +func NewPlainEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder ciphertextSplicingOrder) *EncrypterOpts { + return &EncrypterOpts{ENCODING_PLAIN, marhsalMode, splicingOrder} } -func NewDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts { - return &DecrypterOpts{splicingOrder} +func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts { + return &DecrypterOpts{ENCODING_PLAIN, splicingOrder} } func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { @@ -92,7 +101,11 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte } } -var defaultEncrypterOpts = &EncrypterOpts{MarshalUncompressed, C1C3C2} +var defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2} + +var ASN1EncrypterOpts = &EncrypterOpts{ENCODING_ASN1, MarshalUncompressed, C1C3C2} + +var ASN1DecrypterOpts = &DecrypterOpts{ENCODING_ASN1, C1C3C2} // directSigning is a standard Hash value that signals that no pre-hashing // should be performed. @@ -237,6 +250,11 @@ func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte { return md.Sum(nil) } +// sm2 encrypt and output ASN.1 result +func EncryptASN1(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) { + return Encrypt(random, pub, msg, ASN1EncrypterOpts) +} + // Encrypt sm2 encrypt implementation func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) { curve := pub.Curve @@ -285,12 +303,23 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter //A7, C3 = hash(x2||M||y2) c3 := calculateC3(curve, x2, y2, msg) - if opts.CiphertextSplicingOrder == C1C3C2 { - // c1 || c3 || c2 - return append(append(c1, c3...), c2...), nil + if opts.CiphertextEncoding == ENCODING_PLAIN { + if opts.CiphertextSplicingOrder == C1C3C2 { + // c1 || c3 || c2 + return append(append(c1, c3...), c2...), nil + } + // c1 || c2 || c3 + return append(append(c1, c2...), c3...), nil + } else { // ASN.1 format will force C3 C2 order + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(x1) + b.AddASN1BigInt(y1) + b.AddASN1OctetString(c3) + b.AddASN1OctetString(c2) + }) + return b.Bytes() } - // c1 || c2 || c3 - return append(append(c1, c2...), c3...), nil } } @@ -314,11 +343,59 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { return decrypt(priv, ciphertext, nil) } +func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) { + var ( + x1, y1 = &big.Int{}, &big.Int{} + c2, c3 []byte + inner cryptobyte.String + ) + input := cryptobyte.String(ciphertext) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(x1) || + !inner.ReadASN1Integer(y1) || + !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || + !inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) || + !inner.Empty() { + return nil, errors.New("SM2: invalid asn1 format ciphertext") + } + return rawDecrypt(priv, x1, y1, c2, c3) +} + +func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) { + curve := priv.Curve + x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) + msgLen := len(c2) + t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + if !success { + return nil, errors.New("SM2: invalid cipher text") + } + + //B5, calculate msg = c2 ^ t + msg := make([]byte, msgLen) + for i := 0; i < msgLen; i++ { + msg[i] = c2[i] ^ t[i] + } + u := calculateC3(curve, x2, y2, msg) + for i := 0; i < sm3.Size; i++ { + if c3[i] != u[i] { + return nil, errors.New("SM2: invalid hash value") + } + } + return msg, nil +} + func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { splicingOrder := C1C3C2 if opts != nil { + if opts.CiphertextEncoding == ENCODING_ASN1 { + return decryptASN1(priv, ciphertext) + } splicingOrder = opts.CipherTextSplicingOrder } + if ciphertext[0] == 0x30 { + return decryptASN1(priv, ciphertext) + } ciphertextLen := len(ciphertext) if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { return nil, errors.New("SM2: invalid ciphertext length") @@ -330,45 +407,85 @@ func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, return nil, err } - //B2 is ignored - //B3, calculate x2, y2 - x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) - //B4, calculate t=KDF(x2||y2, klen) var c2, c3 []byte if splicingOrder == C1C3C2 { c2 = ciphertext[c3Start+sm3.Size:] - } else { - c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] - } - msgLen := len(c2) - t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) - if !success { - return nil, errors.New("SM2: invalid cipher text") - } - - //B5, calculate msg = c2 ^ t - msg := make([]byte, msgLen) - for i := 0; i < msgLen; i++ { - msg[i] = c2[i] ^ t[i] - } - - //B6, calculate hash and compare it - if splicingOrder == C1C3C2 { c3 = ciphertext[c3Start : c3Start+sm3.Size] } else { + c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] c3 = ciphertext[ciphertextLen-sm3.Size:] } - u := calculateC3(curve, x2, y2, msg) - for i := 0; i < sm3.Size; i++ { - if c3[i] != u[i] { - return nil, errors.New("SM2: invalid hash value") - } - } - return msg, nil + return rawDecrypt(priv, x1, y1, c2, c3) } +// utility method to convert ASN.1 encoding ciphertext to plain encoding format +func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { + if opts == nil { + opts = defaultEncrypterOpts + } + var ( + x1, y1 = &big.Int{}, &big.Int{} + c2, c3 []byte + inner cryptobyte.String + ) + input := cryptobyte.String(ciphertext) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(x1) || + !inner.ReadASN1Integer(y1) || + !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || + !inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) || + !inner.Empty() { + return nil, errors.New("SM2: invalid asn1 format ciphertext") + } + curve := P256() + c1 := opts.PointMarshalMode.mashal(curve, x1, y1) + if opts.CiphertextSplicingOrder == C1C3C2 { + // c1 || c3 || c2 + return append(append(c1, c3...), c2...), nil + } + // c1 || c2 || c3 + return append(append(c1, c2...), c3...), nil +} + +// utility method to convert plain encoding ciphertext to ASN.1 encoding format +func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) { + if ciphertext[0] == 0x30 { + return nil, errors.New("SM2: invalid plain encoding ciphertext") + } + curve := P256() + ciphertextLen := len(ciphertext) + if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { + return nil, errors.New("SM2: invalid ciphertext length") + } + // get C1, and check C1 + x1, y1, c3Start, err := bytes2Point(curve, ciphertext) + if err != nil { + return nil, err + } + + var c2, c3 []byte + + if from == C1C3C2 { + c2 = ciphertext[c3Start+sm3.Size:] + c3 = ciphertext[c3Start : c3Start+sm3.Size] + } else { + c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] + c3 = ciphertext[ciphertextLen-sm3.Size:] + } + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(x1) + b.AddASN1BigInt(y1) + b.AddASN1OctetString(c3) + b.AddASN1OctetString(c2) + }) + return b.Bytes() +} + +// utility method func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) { curve := P256() if from == to { diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index 438eece..545c342 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -48,11 +48,11 @@ func Test_SplicingOrder(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewEncrypterOpts(MarshalUncompressed, tt.from)) + ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewPlainEncrypterOpts(MarshalUncompressed, tt.from)) if err != nil { t.Fatalf("encrypt failed %v", err) } - plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.from)) + plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from)) if err != nil { t.Fatalf("decrypt failed %v", err) } @@ -65,7 +65,7 @@ func Test_SplicingOrder(t *testing.T) { if err != nil { t.Fatalf("adjust splicing order failed %v", err) } - plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.to)) + plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to)) if err != nil { t.Fatalf("decrypt failed after adjust splicing order %v", err) } @@ -76,6 +76,99 @@ func Test_SplicingOrder(t *testing.T) { } } +func Test_encryptDecrypt_ASN1(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + }{ + // TODO: Add test cases. + {"less than 32", "emmansun"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encrypterOpts := ASN1EncrypterOpts + ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + +func Test_Ciphertext2ASN1(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + }{ + // TODO: Add test cases. + {"less than 32", "emmansun"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + ciphertext, err = PlainCiphertext2ASN1(ciphertext, C1C3C2) + if err != nil { + t.Fatalf("convert to ASN.1 failed %v", err) + } + plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + +func Test_ASN1Ciphertext2Plain(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + }{ + // TODO: Add test cases. + {"less than 32", "emmansun"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := EncryptASN1(rand.Reader, &priv.PublicKey, []byte(tt.plainText)) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil) + if err != nil { + t.Fatalf("convert to plain failed %v", err) + } + plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + func Test_encryptDecrypt(t *testing.T) { priv, _ := GenerateKey(rand.Reader) tests := []struct { @@ -101,7 +194,7 @@ func Test_encryptDecrypt(t *testing.T) { t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) } // compress mode - encrypterOpts := NewEncrypterOpts(MarshalCompressed, C1C3C2) + encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2) ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts) if err != nil { t.Fatalf("encrypt failed %v", err) @@ -115,7 +208,7 @@ func Test_encryptDecrypt(t *testing.T) { } // mixed mode - encrypterOpts = NewEncrypterOpts(MarshalMixed, C1C3C2) + encrypterOpts = NewPlainEncrypterOpts(MarshalMixed, C1C3C2) ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts) if err != nil { t.Fatalf("encrypt failed %v", err)