diff --git a/internal/sm9/bn256/gfp.go b/internal/sm9/bn256/gfp.go index 2c614e8..ea55216 100644 --- a/internal/sm9/bn256/gfp.go +++ b/internal/sm9/bn256/gfp.go @@ -12,6 +12,11 @@ var zero = newGFp(0) var one = newGFp(1) var two = newGFp(2) +// newGFp creates a new gfP element from the given int64 value. +// If the input value is non-negative, it directly converts it to uint64. +// If the input value is negative, it converts the absolute value to uint64 +// and then negates the resulting gfP element. +// The resulting gfP element is then encoded in Montgomery form. func newGFp(x int64) (out *gfP) { if x >= 0 { out = &gfP{uint64(x)} @@ -24,6 +29,8 @@ func newGFp(x int64) (out *gfP) { return out } +// newGFpFromBytes creates a new gfP element from a byte slice. +// It unmarshals the byte slice into a gfP element, then encodes it in Montgomery form. func newGFpFromBytes(in []byte) (out *gfP) { out = &gfP{} gfpUnmarshal(out, (*[32]byte)(in)) @@ -40,6 +47,17 @@ func (e *gfP) Set(f *gfP) *gfP { return e } +// exp calculates the exponentiation of a given gfP element `f` raised to the power +// represented by the 256-bit integer `bits`. The result is stored in the gfP element `e`. +// +// The function uses a square-and-multiply algorithm to perform the exponentiation. +// It iterates over each bit of the 256-bit integer `bits`, and for each bit, it squares +// the current power and multiplies it to the result if the bit is set. +// +// Parameters: +// - f: The base gfP element to be exponentiated. +// - bits: A 256-bit integer represented as an array of 4 uint64 values, where bits[0] +// contains the least significant 64 bits and bits[3] contains the most significant 64 bits. func (e *gfP) exp(f *gfP, bits [4]uint64) { sum, power := &gfP{}, &gfP{} sum.Set(rN1) @@ -94,7 +112,12 @@ func (e *gfP) Sqrt(f *gfP) { e.Set(i) } +// Marshal serializes the gfP element into the provided byte slice. +// The output byte slice must be at least 32 bytes long. func (e *gfP) Marshal(out []byte) { + if len(out) < 32 { + panic("sm9: invalid out length") + } gfpMarshal((*[32]byte)(out), e) } @@ -110,6 +133,9 @@ func uint64IsZero(x uint64) int { return int(x & 1) } +// lessThanP returns 1 if the given gfP element x is less than the prime modulus p2, +// and 0 otherwise. It performs a subtraction of x from p2 and checks the borrow bit +// to determine if x is less than p2. func lessThanP(x *gfP) int { var b uint64 _, b = bits.Sub64(x[0], p2[0], b) @@ -119,19 +145,18 @@ func lessThanP(x *gfP) int { return int(b) } +// Unmarshal decodes a 32-byte big-endian representation of a gfP element. +// It returns an error if the input length is not 32 bytes or if the decoded +// value is not a valid gfP element (i.e., greater than or equal to the field prime). func (e *gfP) Unmarshal(in []byte) error { - gfpUnmarshal(e, (*[32]byte)(in)) - // Ensure the point respects the curve modulus - // TODO: Do we need to change it to constant time version ? - for i := 3; i >= 0; i-- { - if e[i] < p2[i] { - return nil - } - if e[i] > p2[i] { - return errors.New("sm9: coordinate exceeds modulus") - } + if len(in) < 32 { + return errors.New("sm9: invalid input length") } - return errors.New("sm9: coordinate equals modulus") + gfpUnmarshal(e, (*[32]byte)(in)) + if lessThanP(e) == 0 { + return errors.New("sm9: invalid gfP encoding") + } + return nil } func montEncode(c, a *gfP) { gfpMul(c, a, r2) } diff --git a/internal/sm9/bn256/gfp_test.go b/internal/sm9/bn256/gfp_test.go index fafd041..7505cc9 100644 --- a/internal/sm9/bn256/gfp_test.go +++ b/internal/sm9/bn256/gfp_test.go @@ -1,6 +1,7 @@ package bn256 import ( + "bytes" "encoding/hex" "math/big" "testing" @@ -276,6 +277,43 @@ func TestGfpNeg(t *testing.T) { } } +func TestGfpUnmarshal(t *testing.T) { + validHex := "85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141" + invalidHex := "b640000002a3a6f1d603ab4ff58ec74521f2934b1a7aeedbe56f9b27e351457d" + + t.Run("valid input", func(t *testing.T) { + x, _ := hex.DecodeString(validHex) + var out [32]byte + ret := &gfP{} + err := ret.Unmarshal(x[:]) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + ret.Marshal(out[:]) + if !bytes.Equal(out[:], x) { + t.Errorf("got %x, expected %x", out, x) + } + }) + + t.Run("invalid length", func(t *testing.T) { + x, _ := hex.DecodeString(validHex) + ret := &gfP{} + err := ret.Unmarshal(x[1:]) + if err == nil || err.Error() != "sm9: invalid input length" { + t.Errorf("expected error, got %v", err) + } + }) + + t.Run("invalid value", func(t *testing.T) { + x, _ := hex.DecodeString(invalidHex) + ret := &gfP{} + err := ret.Unmarshal(x[:]) + if err == nil || err.Error() != "sm9: invalid gfP encoding" { + t.Errorf("expected error, got %v", err) + } + }) +} + func BenchmarkGfPUnmarshal(b *testing.B) { x := newGFpFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596") b.ReportAllocs()