From 5b5942db8461f4d8a3faf83033a269f925aaf536 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Wed, 11 May 2022 16:58:36 +0800 Subject: [PATCH] refactor point marshal/unmarshal --- sm2/util.go | 32 +++++++++----------------------- sm2/util_test.go | 23 ----------------------- 2 files changed, 9 insertions(+), 46 deletions(-) diff --git a/sm2/util.go b/sm2/util.go index 4f7f3d6..b546ab5 100644 --- a/sm2/util.go +++ b/sm2/util.go @@ -10,8 +10,6 @@ import ( "sync" ) -var zero = big.NewInt(0) - func toBytes(curve elliptic.Curve, value *big.Int) []byte { bytes := value.Bytes() byteLen := (curve.Params().BitSize + 7) >> 3 @@ -30,31 +28,16 @@ func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { buffer := make([]byte, (curve.Params().BitSize+7)>>3+1) copy(buffer[1:], toBytes(curve, x)) - if getLastBitOfY(x, y) > 0 { - buffer[0] = compressed03 - } else { - buffer[0] = compressed02 - } + buffer[0] = byte(y.Bit(0)) | compressed02 return buffer } func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte { buffer := elliptic.Marshal(curve, x, y) - if getLastBitOfY(x, y) > 0 { - buffer[0] = mixed07 - } else { - buffer[0] = mixed06 - } + buffer[0] = byte(y.Bit(0)) | mixed06 return buffer } -func getLastBitOfY(x, y *big.Int) uint { - if x.Cmp(zero) == 0 { - return 0 - } - return y.Bit(0) -} - func toPointXY(bytes []byte) *big.Int { return new(big.Int).SetBytes(bytes) } @@ -91,13 +74,14 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e x := toPointXY(bytes[1 : 1+byteLen]) y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) if !curve.IsOnCurve(x, y) { - return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name) + return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name) } return x, y, 1 + byteLen*2, nil case compressed02, compressed03: if len(bytes) < 1+byteLen { return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) } + // Make sure it's NIST curve or SM2 P-256 curve if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) { // y² = x³ - 3x + b, prime curves x := toPointXY(bytes[1 : 1+byteLen]) @@ -105,9 +89,11 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e if err != nil { return nil, nil, 0, err } - - if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) { - y.Sub(curve.Params().P, y) + if byte(y.Bit(0)) != bytes[0]&1 { + y.Neg(y).Mod(y, curve.Params().P) + } + if !curve.IsOnCurve(x, y) { + return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name) } return x, y, 1 + byteLen, nil } diff --git a/sm2/util_test.go b/sm2/util_test.go index 5f42478..6828f24 100644 --- a/sm2/util_test.go +++ b/sm2/util_test.go @@ -31,29 +31,6 @@ func Test_toBytes(t *testing.T) { } } -func Test_getLastBitOfY(t *testing.T) { - type args struct { - y string - } - tests := []struct { - name string - args args - want uint - }{ - // TODO: Add test cases. - {"0", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, 0}, - {"1", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865ff"}, 1}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - y, _ := new(big.Int).SetString(tt.args.y, 16) - if got := getLastBitOfY(y, y); got != tt.want { - t.Errorf("getLastBitOfY() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_toPointXY(t *testing.T) { type args struct { bytes string