From be62e3a042911d7d215b9fcfbb82231901ebc837 Mon Sep 17 00:00:00 2001 From: Emman Date: Wed, 16 Dec 2020 16:27:36 +0800 Subject: [PATCH] MAGIC - sm2, basic implementation --- sm2/sm2.go | 156 +++++++++++++++++++++++++++++++++++++++++++++++ sm2/sm2_test.go | 57 +++++++++++++++++ sm2/util.go | 120 ++++++++++++++++++++++++++++++++++++ sm2/util_test.go | 79 ++++++++++++++++++++++++ sm3/sm3.go | 4 +- 5 files changed, 414 insertions(+), 2 deletions(-) create mode 100644 sm2/sm2.go create mode 100644 sm2/sm2_test.go create mode 100644 sm2/util.go create mode 100644 sm2/util_test.go diff --git a/sm2/sm2.go b/sm2/sm2.go new file mode 100644 index 0000000..b6027d3 --- /dev/null +++ b/sm2/sm2.go @@ -0,0 +1,156 @@ +package sm2 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "encoding/binary" + "errors" + "fmt" + "gmsm/sm3" + "io" + "math/big" +) + +const ( + Uncompressed byte = 0x04 + Compressed_02 byte = 0x02 + Compressed_03 byte = 0x03 + Mixed_06 byte = 0x06 + Mixed_07 byte = 0x07 +) + +///////////////// below code ship from golan crypto/ecdsa //////////////////// +var one = new(big.Int).SetInt64(1) + +// randFieldElement returns a random element of the field underlying the given +// curve using the procedure given in [NSA] A.2.1. +func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) { + params := c.Params() + b := make([]byte, params.BitSize/8+8) + _, err = io.ReadFull(rand, b) + if err != nil { + return + } + + k = new(big.Int).SetBytes(b) + n := new(big.Int).Sub(params.N, one) + k.Mod(k, n) + k.Add(k, one) + return +} + +/////////////////////////////////////////////////////////////////////////////////// +func kdf(z []byte, len int) ([]byte, bool) { + limit := (len + sm3.Size - 1) / sm3.Size + sm3Hasher := sm3.New() + var countBytes [4]byte + var ct uint32 = 1 + k := make([]byte, len+sm3.Size-1) + for i := 0; i < limit; i++ { + binary.BigEndian.PutUint32(countBytes[:], ct) + sm3Hasher.Write(z) + sm3Hasher.Write(countBytes[:]) + copy(k[i*sm3.Size:], sm3Hasher.Sum(nil)) + ct++ + sm3Hasher.Reset() + } + for i := 0; i < len; i++ { + if k[i] != 0 { + return k[:len], true + } + } + return k, false +} + +func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte { + hasher := sm3.New() + hasher.Write(toBytes(curve, x2)) + hasher.Write(msg) + hasher.Write(toBytes(curve, y2)) + return hasher.Sum(nil) +} + +// Encrypt sm2 encrypt implementation +func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) { + curve := pub.Curve + msgLen := len(msg) + for { + //A1, generate random k + k, err := randFieldElement(curve, random) + if err != nil { + return nil, err + } + + //A2, calculate C1 = k * G + x1, y1 := curve.ScalarBaseMult(k.Bytes()) + c1 := point2CompressedBytes(curve, x1, y1) + + //A3, skipped + //A4, calculate k * P (point of Public Key) + x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) + + //A5, calculate t=KDF(x2||y2, klen) + t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + if !success { + fmt.Println("A5, failed to get valid t") + continue + } + + //A6, C2 = M + t; + c2 := make([]byte, msgLen) + for i := 0; i < msgLen; i++ { + c2[i] = msg[i] ^ t[i] + } + + //A7, C3 = hash(x2||M||y2) + c3 := calculateC3(curve, x2, y2, msg) + + return append(append(c1, c2...), c3...), nil + } +} + +// Decrypt sm2 decrypt implementation +func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) { + ciphertextLen := len(ciphertext) + if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { + return nil, errors.New("invalid ciphertext length") + } + curve := priv.Curve + // B1, get C1, and check C1 + x1, y1, c2Start, err := bytes2Point(curve, ciphertext) + if err != nil { + return nil, err + } + if !curve.IsOnCurve(x1, y1) { + return nil, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name) + } + + //B2 is ignored + //B3, calculate x2, y2 + x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) + + //B4, calculate t=KDF(x2||y2, klen) + c2 := ciphertext[c2Start : ciphertextLen-sm3.Size] + msgLen := len(c2) + t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + if !success { + return nil, errors.New("invalid cipher text") + } + + //B5, calculate msg = c2 ^ t + msg := make([]byte, msgLen) + for i := 0; i < msgLen; i++ { + msg[i] = c2[i] ^ t[i] + } + + //B6, calculate hash and compare it + c3 := ciphertext[ciphertextLen-sm3.Size:] + u := calculateC3(curve, x2, y2, msg) + for i := 0; i < sm3.Size; i++ { + if c3[i] != u[i] { + return nil, errors.New("invalid hash value") + } + } + + return msg, nil +} diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go new file mode 100644 index 0000000..5e5c8ae --- /dev/null +++ b/sm2/sm2_test.go @@ -0,0 +1,57 @@ +package sm2 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/hex" + "math/big" + "reflect" + "testing" +) + +func Test_kdf(t *testing.T) { + x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16) + y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16) + + expected := "006e30dae231b071dfad8aa379e90264491603" + + result, success := kdf(append(x2.Bytes(), y2.Bytes()...), 19) + if !success { + t.Fatalf("failed") + } + + resultStr := hex.EncodeToString(result) + + if expected != resultStr { + t.Fatalf("expected %s, real value %s", expected, resultStr) + } +} + +func Test_encryptDecrypt(t *testing.T) { + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + tests := []struct { + name string + plainText string + }{ + // TODO: Add test cases. + {"less than 32", "encryption standard"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText)) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err := Decrypt(priv, ciphertext) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} diff --git a/sm2/util.go b/sm2/util.go new file mode 100644 index 0000000..c8e6244 --- /dev/null +++ b/sm2/util.go @@ -0,0 +1,120 @@ +package sm2 + +import ( + "crypto/elliptic" + "errors" + "fmt" + "math/big" + "strings" +) + +var zero = new(big.Int).SetInt64(0) + +func toBytes(curve elliptic.Curve, value *big.Int) []byte { + bytes := value.Bytes() + byteLen := (curve.Params().BitSize + 7) >> 3 + if byteLen == len(bytes) { + return bytes + } + result := make([]byte, byteLen) + copy(result[byteLen-len(bytes):], bytes) + return result +} + +func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { + return elliptic.Marshal(curve, x, y) +} + +func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { + buffer := make([]byte, (curve.Params().BitSize+7)>>3+1) + copy(buffer[1:], toBytes(curve, x)) + if getLastBitOfY(x, y) > 0 { + buffer[0] = Compressed_03 + } else { + buffer[0] = Compressed_02 + } + return buffer +} + +func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte { + buffer := elliptic.Marshal(curve, x, y) + if getLastBitOfY(x, y) > 0 { + buffer[0] = Mixed_07 + } else { + buffer[0] = Mixed_06 + } + return buffer +} + +func getLastBitOfY(x, y *big.Int) uint { + if x.Cmp(zero) == 0 { + return 0 + } + return y.Bit(0) +} + +func toPointXY(bytes []byte) *big.Int { + return new(big.Int).SetBytes(bytes) +} + +func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) { + x3 := new(big.Int).Mul(x, x) + x3.Mul(x3, x) + + threeX := new(big.Int).Lsh(x, 1) + threeX.Add(threeX, x) + + x3.Sub(x3, threeX) + x3.Add(x3, curve.Params().B) + x3.Mod(x3, curve.Params().P) + y := x3.ModSqrt(x3, curve.Params().P) + + if y == nil { + return nil, errors.New("can't calculate y based on x") + } + return y, nil +} + +func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { + if len(bytes) < 1+(curve.Params().BitSize/8) { + return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes)) + } + format := bytes[0] + byteLen := (curve.Params().BitSize + 7) >> 3 + switch format { + case Uncompressed: + if len(bytes) < 1+byteLen*2 { + return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes)) + } + x := toPointXY(bytes[1 : 1+byteLen]) + y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) + return x, y, 1 + byteLen*2, nil + case Compressed_02, Compressed_03: + if len(bytes) < 1+byteLen { + return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) + } + if strings.HasPrefix(curve.Params().Name, "P-") { + // y² = x³ - 3x + b + x := toPointXY(bytes[1 : 1+byteLen]) + y, err := calculatePrimeCurveY(curve, x) + if err != nil { + return nil, nil, 0, err + } + + if (getLastBitOfY(x, y) > 0 && format == Compressed_02) || (getLastBitOfY(x, y) == 0 && format == Compressed_03) { + y.Sub(curve.Params().P, y) + } + return x, y, 1 + byteLen, nil + } + return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name) + case Mixed_06, Mixed_07: + // what's the mixed format purpose? + if len(bytes) < 1+byteLen*2 { + return nil, nil, 0, fmt.Errorf("invalid mixed bytes length %d", len(bytes)) + } + x := toPointXY(bytes[1 : 1+byteLen]) + y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) + return x, y, 1 + byteLen*2, nil + } + return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format) +} diff --git a/sm2/util_test.go b/sm2/util_test.go new file mode 100644 index 0000000..5f42478 --- /dev/null +++ b/sm2/util_test.go @@ -0,0 +1,79 @@ +package sm2 + +import ( + "crypto/elliptic" + "encoding/hex" + "math/big" + "reflect" + "testing" +) + +func Test_toBytes(t *testing.T) { + type args struct { + value string + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + {"less than 32", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, + {"equals 32", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v, _ := new(big.Int).SetString(tt.args.value, 16) + if got := toBytes(elliptic.P256(), v); !reflect.DeepEqual(hex.EncodeToString(got), tt.want) { + t.Errorf("toBytes() = %v, want %v", hex.EncodeToString(got), tt.want) + } + }) + } +} + +func Test_getLastBitOfY(t *testing.T) { + type args struct { + y string + } + tests := []struct { + name string + args args + want uint + }{ + // TODO: Add test cases. + {"0", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, 0}, + {"1", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865ff"}, 1}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + y, _ := new(big.Int).SetString(tt.args.y, 16) + if got := getLastBitOfY(y, y); got != tt.want { + t.Errorf("getLastBitOfY() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_toPointXY(t *testing.T) { + type args struct { + bytes string + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + {"has zero padding", args{"00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, + {"no zero padding", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bytes, _ := hex.DecodeString(tt.args.bytes) + expectedInt, _ := new(big.Int).SetString(tt.want, 16) + if got := toPointXY(bytes); !reflect.DeepEqual(got, expectedInt) { + t.Errorf("toPointXY() = %v, want %v", got, expectedInt) + } + }) + } +} diff --git a/sm3/sm3.go b/sm3/sm3.go index ea219b6..d2f8409 100644 --- a/sm3/sm3.go +++ b/sm3/sm3.go @@ -7,10 +7,10 @@ import ( ) // Size the size of a SM3 checksum in bytes. -const Size = 32 +const Size int = 32 // BlockSize the blocksize of SM3 in bytes. -const BlockSize = 64 +const BlockSize int = 64 const ( chunk = 64