refactor point marshal/unmarshal

This commit is contained in:
Sun Yimin 2022-05-11 16:58:36 +08:00 committed by GitHub
parent dafbb30c6e
commit 5b5942db84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 46 deletions

View File

@ -10,8 +10,6 @@ import (
"sync"
)
var zero = big.NewInt(0)
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
bytes := value.Bytes()
byteLen := (curve.Params().BitSize + 7) >> 3
@ -30,31 +28,16 @@ func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
copy(buffer[1:], toBytes(curve, x))
if getLastBitOfY(x, y) > 0 {
buffer[0] = compressed03
} else {
buffer[0] = compressed02
}
buffer[0] = byte(y.Bit(0)) | compressed02
return buffer
}
func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
buffer := elliptic.Marshal(curve, x, y)
if getLastBitOfY(x, y) > 0 {
buffer[0] = mixed07
} else {
buffer[0] = mixed06
}
buffer[0] = byte(y.Bit(0)) | mixed06
return buffer
}
func getLastBitOfY(x, y *big.Int) uint {
if x.Cmp(zero) == 0 {
return 0
}
return y.Bit(0)
}
func toPointXY(bytes []byte) *big.Int {
return new(big.Int).SetBytes(bytes)
}
@ -91,13 +74,14 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e
x := toPointXY(bytes[1 : 1+byteLen])
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
if !curve.IsOnCurve(x, y) {
return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name)
}
return x, y, 1 + byteLen*2, nil
case compressed02, compressed03:
if len(bytes) < 1+byteLen {
return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
}
// Make sure it's NIST curve or SM2 P-256 curve
if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
// y² = x³ - 3x + b, prime curves
x := toPointXY(bytes[1 : 1+byteLen])
@ -105,9 +89,11 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e
if err != nil {
return nil, nil, 0, err
}
if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) {
y.Sub(curve.Params().P, y)
if byte(y.Bit(0)) != bytes[0]&1 {
y.Neg(y).Mod(y, curve.Params().P)
}
if !curve.IsOnCurve(x, y) {
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name)
}
return x, y, 1 + byteLen, nil
}

View File

@ -31,29 +31,6 @@ func Test_toBytes(t *testing.T) {
}
}
func Test_getLastBitOfY(t *testing.T) {
type args struct {
y string
}
tests := []struct {
name string
args args
want uint
}{
// TODO: Add test cases.
{"0", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, 0},
{"1", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865ff"}, 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
y, _ := new(big.Int).SetString(tt.args.y, 16)
if got := getLastBitOfY(y, y); got != tt.want {
t.Errorf("getLastBitOfY() = %v, want %v", got, tt.want)
}
})
}
}
func Test_toPointXY(t *testing.T) {
type args struct {
bytes string