mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-27 20:56:18 +08:00
refactor point marshal/unmarshal
This commit is contained in:
parent
dafbb30c6e
commit
5b5942db84
32
sm2/util.go
32
sm2/util.go
@ -10,8 +10,6 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var zero = big.NewInt(0)
|
||||
|
||||
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
|
||||
bytes := value.Bytes()
|
||||
byteLen := (curve.Params().BitSize + 7) >> 3
|
||||
@ -30,31 +28,16 @@ func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||
func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||
buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
|
||||
copy(buffer[1:], toBytes(curve, x))
|
||||
if getLastBitOfY(x, y) > 0 {
|
||||
buffer[0] = compressed03
|
||||
} else {
|
||||
buffer[0] = compressed02
|
||||
}
|
||||
buffer[0] = byte(y.Bit(0)) | compressed02
|
||||
return buffer
|
||||
}
|
||||
|
||||
func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||
buffer := elliptic.Marshal(curve, x, y)
|
||||
if getLastBitOfY(x, y) > 0 {
|
||||
buffer[0] = mixed07
|
||||
} else {
|
||||
buffer[0] = mixed06
|
||||
}
|
||||
buffer[0] = byte(y.Bit(0)) | mixed06
|
||||
return buffer
|
||||
}
|
||||
|
||||
func getLastBitOfY(x, y *big.Int) uint {
|
||||
if x.Cmp(zero) == 0 {
|
||||
return 0
|
||||
}
|
||||
return y.Bit(0)
|
||||
}
|
||||
|
||||
func toPointXY(bytes []byte) *big.Int {
|
||||
return new(big.Int).SetBytes(bytes)
|
||||
}
|
||||
@ -91,13 +74,14 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e
|
||||
x := toPointXY(bytes[1 : 1+byteLen])
|
||||
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
|
||||
if !curve.IsOnCurve(x, y) {
|
||||
return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
|
||||
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name)
|
||||
}
|
||||
return x, y, 1 + byteLen*2, nil
|
||||
case compressed02, compressed03:
|
||||
if len(bytes) < 1+byteLen {
|
||||
return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
|
||||
}
|
||||
// Make sure it's NIST curve or SM2 P-256 curve
|
||||
if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
|
||||
// y² = x³ - 3x + b, prime curves
|
||||
x := toPointXY(bytes[1 : 1+byteLen])
|
||||
@ -105,9 +89,11 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) {
|
||||
y.Sub(curve.Params().P, y)
|
||||
if byte(y.Bit(0)) != bytes[0]&1 {
|
||||
y.Neg(y).Mod(y, curve.Params().P)
|
||||
}
|
||||
if !curve.IsOnCurve(x, y) {
|
||||
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name)
|
||||
}
|
||||
return x, y, 1 + byteLen, nil
|
||||
}
|
||||
|
@ -31,29 +31,6 @@ func Test_toBytes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getLastBitOfY(t *testing.T) {
|
||||
type args struct {
|
||||
y string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want uint
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{"0", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, 0},
|
||||
{"1", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865ff"}, 1},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
y, _ := new(big.Int).SetString(tt.args.y, 16)
|
||||
if got := getLastBitOfY(y, y); got != tt.want {
|
||||
t.Errorf("getLastBitOfY() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_toPointXY(t *testing.T) {
|
||||
type args struct {
|
||||
bytes string
|
||||
|
Loading…
x
Reference in New Issue
Block a user