internal/sm9/bn256: remove useless code

This commit is contained in:
Sun Yimin 2025-03-26 09:49:30 +08:00 committed by GitHub
parent a7c4473a48
commit e8a847e005
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 9 additions and 1238 deletions

View File

@ -2,6 +2,14 @@ package bn256
import "math/big"
func bigFromHex(s string) *big.Int {
b, ok := new(big.Int).SetString(s, 16)
if !ok {
panic("sm9: internal error: invalid encoding")
}
return b
}
// u is the BN parameter that determines the prime: 600000000058f98a.
var u = bigFromHex("600000000058f98a")

View File

@ -1,182 +0,0 @@
package bn256
import (
"io"
"math/big"
)
// A Curve represents a short-form Weierstrass curve with a=0.
//
// The behavior of Add, Double, and ScalarMult when the input is not a point on
// the curve is undefined.
//
// Note that the conventional point at infinity (0, 0) is not considered on the
// curve, although it can be returned by Add, Double, ScalarMult, or
// ScalarBaseMult (but not the Unmarshal or UnmarshalCompressed functions).
type Curve interface {
// Params returns the parameters for the curve.
Params() *CurveParams
// IsOnCurve reports whether the given (x,y) lies on the curve.
IsOnCurve(x, y *big.Int) bool
// Add returns the sum of (x1,y1) and (x2,y2)
Add(x1, y1, x2, y2 *big.Int) (x, y *big.Int)
// Double returns 2*(x,y)
Double(x1, y1 *big.Int) (x, y *big.Int)
// ScalarMult returns k*(Bx,By) where k is a number in big-endian form.
ScalarMult(x1, y1 *big.Int, k []byte) (x, y *big.Int)
// ScalarBaseMult returns k*G, where G is the base point of the group
// and k is an integer in big-endian form.
ScalarBaseMult(k []byte) (x, y *big.Int)
}
var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
// GenerateKey returns a public/private key pair. The private key is
// generated using the given reader, which must return random data.
func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
N := curve.Params().N
bitSize := N.BitLen()
byteLen := (bitSize + 7) / 8
priv = make([]byte, byteLen)
for x == nil {
_, err = io.ReadFull(rand, priv)
if err != nil {
return
}
// We have to mask off any excess bits in the case that the size of the
// underlying field is not a whole number of bytes.
priv[0] &= mask[bitSize%8]
// This is because, in tests, rand will return all zeros and we don't
// want to get the point at infinity and loop forever.
priv[1] ^= 0x42
// If the scalar is out of range, sample another random number.
if new(big.Int).SetBytes(priv).Cmp(N) >= 0 {
continue
}
x, y = curve.ScalarBaseMult(priv)
}
return
}
// Marshal converts a point on the curve into the uncompressed form specified in
// SEC 1, Version 2.0, Section 2.3.3. If the point is not on the curve (or is
// the conventional point at infinity), the behavior is undefined.
func Marshal(curve Curve, x, y *big.Int) []byte {
panicIfNotOnCurve(curve, x, y)
byteLen := (curve.Params().BitSize + 7) / 8
ret := make([]byte, 1+2*byteLen)
ret[0] = 4 // uncompressed point
x.FillBytes(ret[1 : 1+byteLen])
y.FillBytes(ret[1+byteLen : 1+2*byteLen])
return ret
}
// MarshalCompressed converts a point on the curve into the compressed form
// specified in SEC 1, Version 2.0, Section 2.3.3. If the point is not on the
// curve (or is the conventional point at infinity), the behavior is undefined.
func MarshalCompressed(curve Curve, x, y *big.Int) []byte {
panicIfNotOnCurve(curve, x, y)
byteLen := (curve.Params().BitSize + 7) / 8
compressed := make([]byte, 1+byteLen)
compressed[0] = byte(y.Bit(0)) | 2
x.FillBytes(compressed[1:])
return compressed
}
// unmarshaler is implemented by curves with their own constant-time Unmarshal.
//
// There isn't an equivalent interface for Marshal/MarshalCompressed because
// that doesn't involve any mathematical operations, only FillBytes and Bit.
type unmarshaler interface {
Unmarshal([]byte) (x, y *big.Int)
UnmarshalCompressed([]byte) (x, y *big.Int)
}
// Unmarshal converts a point, serialized by Marshal, into an x, y pair. It is
// an error if the point is not in uncompressed form, is not on the curve, or is
// the point at infinity. On error, x = nil.
func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
if c, ok := curve.(unmarshaler); ok {
return c.Unmarshal(data)
}
byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+2*byteLen {
return nil, nil
}
if data[0] != 4 { // uncompressed form
return nil, nil
}
p := curve.Params().P
x = new(big.Int).SetBytes(data[1 : 1+byteLen])
y = new(big.Int).SetBytes(data[1+byteLen:])
if x.Cmp(p) >= 0 || y.Cmp(p) >= 0 {
return nil, nil
}
if !curve.IsOnCurve(x, y) {
return nil, nil
}
return
}
// UnmarshalCompressed converts a point, serialized by MarshalCompressed, into
// an x, y pair. It is an error if the point is not in compressed form, is not
// on the curve, or is the point at infinity. On error, x = nil.
func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) {
if c, ok := curve.(unmarshaler); ok {
return c.UnmarshalCompressed(data)
}
byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+byteLen {
return nil, nil
}
if data[0] != 2 && data[0] != 3 { // compressed form
return nil, nil
}
p := curve.Params().P
x = new(big.Int).SetBytes(data[1:])
if x.Cmp(p) >= 0 {
return nil, nil
}
// y² = x³ + b
y = curve.Params().polynomial(x)
y = y.ModSqrt(y, p)
if y == nil {
return nil, nil
}
if byte(y.Bit(0)) != data[0]&1 {
y.Neg(y).Mod(y, p)
}
if !curve.IsOnCurve(x, y) {
return nil, nil
}
return
}
func panicIfNotOnCurve(curve Curve, x, y *big.Int) {
// (0, 0) is the point at infinity by convention. It's ok to operate on it,
// although IsOnCurve is documented to return false for it. See Issue 37294.
if x.Sign() == 0 && y.Sign() == 0 {
return
}
if !curve.IsOnCurve(x, y) {
panic("sm9/elliptic: attempted operation on invalid point")
}
}
func bigFromHex(s string) *big.Int {
b, ok := new(big.Int).SetString(s, 16)
if !ok {
panic("sm9/elliptic: internal error: invalid encoding")
}
return b
}

View File

@ -375,179 +375,3 @@ func (e *G1) Equal(other *G1) bool {
func (e *G1) IsOnCurve() bool {
return e.p.IsOnCurve()
}
type G1Curve struct {
params *CurveParams
g G1
}
var g1Curve = &G1Curve{
params: &CurveParams{
Name: "sm9",
BitSize: 256,
P: bigFromHex("B640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457D"),
N: bigFromHex("B640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25"),
B: bigFromHex("0000000000000000000000000000000000000000000000000000000000000005"),
Gx: bigFromHex("93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD"),
Gy: bigFromHex("21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616"),
},
g: G1{},
}
func (g1 *G1Curve) pointFromAffine(x, y *big.Int) (a *G1, err error) {
a = &G1{&curvePoint{}}
if x.Sign() == 0 {
a.p.SetInfinity()
return a, nil
}
// Reject values that would not get correctly encoded.
if x.Sign() < 0 || y.Sign() < 0 {
return a, errors.New("negative coordinate")
}
if x.BitLen() > g1.params.BitSize || y.BitLen() > g1.params.BitSize {
return a, errors.New("overflowing coordinate")
}
var buf [32]byte
x.FillBytes(buf[:])
a.p.x = *newGFpFromBytes(buf[:])
y.FillBytes(buf[:])
a.p.y = *newGFpFromBytes(buf[:])
a.p.z = *newGFp(1)
a.p.t = *newGFp(1)
if !a.p.IsOnCurve() {
return a, errors.New("point not on G1 curve")
}
return a, nil
}
func (g1 *G1Curve) Params() *CurveParams {
return g1.params
}
// normalizeScalar brings the scalar within the byte size of the order of the
// curve, as expected by the nistec scalar multiplication functions.
func (curve *G1Curve) normalizeScalar(scalar []byte) []byte {
byteSize := (curve.params.N.BitLen() + 7) / 8
s := new(big.Int).SetBytes(scalar)
if len(scalar) > byteSize {
s.Mod(s, curve.params.N)
}
out := make([]byte, byteSize)
return s.FillBytes(out)
}
func (g1 *G1Curve) ScalarBaseMult(scalar []byte) (*big.Int, *big.Int) {
scalar = g1.normalizeScalar(scalar)
p, err := g1.g.ScalarBaseMult(scalar)
if err != nil {
panic("sm9: g1 rejected normalized scalar")
}
res := p.Marshal()
return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
}
func (g1 *G1Curve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) {
a, err := g1.pointFromAffine(Bx, By)
if err != nil {
panic("sm9: ScalarMult was called on an invalid point")
}
scalar = g1.normalizeScalar(scalar)
p, err := g1.g.ScalarMult(a, scalar)
if err != nil {
panic("sm9: g1 rejected normalized scalar")
}
res := p.Marshal()
return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
}
func (g1 *G1Curve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
a, err := g1.pointFromAffine(x1, y1)
if err != nil {
panic("sm9: Add was called on an invalid point")
}
b, err := g1.pointFromAffine(x2, y2)
if err != nil {
panic("sm9: Add was called on an invalid point")
}
res := g1.g.Add(a, b).Marshal()
return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
}
func (g1 *G1Curve) Double(x, y *big.Int) (*big.Int, *big.Int) {
a, err := g1.pointFromAffine(x, y)
if err != nil {
panic("sm9: Double was called on an invalid point")
}
res := g1.g.Double(a).Marshal()
return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
}
func (g1 *G1Curve) IsOnCurve(x, y *big.Int) bool {
_, err := g1.pointFromAffine(x, y)
return err == nil
}
func (curve *G1Curve) UnmarshalCompressed(data []byte) (x, y *big.Int) {
if len(data) != 33 || (data[0] != 2 && data[0] != 3) {
return nil, nil
}
r := &gfP{}
r.Unmarshal(data[1:33])
if lessThanP(r) == 0 {
return nil, nil
}
x = new(big.Int).SetBytes(data[1:33])
p := &curvePoint{}
montEncode(r, r)
p.x = *r
p.z = *newGFp(1)
p.t = *newGFp(1)
y2 := &gfP{}
gfpMul(y2, r, r)
gfpMul(y2, y2, r)
gfpAdd(y2, y2, curveB)
y2.Sqrt(y2)
p.y = *y2
if !p.IsOnCurve() {
return nil, nil
}
montDecode(y2, y2)
ret := make([]byte, 32)
y2.Marshal(ret)
y = new(big.Int).SetBytes(ret)
if byte(y.Bit(0)) != data[0]&1 {
gfpNeg(y2, y2)
y2.Marshal(ret)
y.SetBytes(ret)
}
return x, y
}
func (curve *G1Curve) Unmarshal(data []byte) (x, y *big.Int) {
if len(data) != 65 || (data[0] != 4) {
return nil, nil
}
x1 := &gfP{}
x1.Unmarshal(data[1:33])
y1 := &gfP{}
y1.Unmarshal(data[33:])
if lessThanP(x1) == 0 || lessThanP(y1) == 0 {
return nil, nil
}
montEncode(x1, x1)
montEncode(y1, y1)
p := &curvePoint{
x: *x1,
y: *y1,
z: *newGFp(1),
t: *newGFp(1),
}
if !p.IsOnCurve() {
return nil, nil
}
x = new(big.Int).SetBytes(data[1:33])
y = new(big.Int).SetBytes(data[33:])
return x, y
}

View File

@ -1,13 +1,7 @@
package bn256
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"math/big"
"testing"
"time"
)
func TestG1AddNeg(t *testing.T) {
@ -237,471 +231,6 @@ var baseMultTests = []g1BaseMultTest{
},
}
func TestG1BaseMult(t *testing.T) {
g1 := g1Curve
g1Generic := g1.Params()
scalars := make([]*big.Int, 0, len(baseMultTests)+1)
for i := 1; i <= 20; i++ {
k := new(big.Int).SetInt64(int64(i))
scalars = append(scalars, k)
}
for _, e := range baseMultTests {
k, _ := new(big.Int).SetString(e.k, 10)
scalars = append(scalars, k)
}
k := new(big.Int).SetInt64(1)
k.Lsh(k, 500)
scalars = append(scalars, k)
for i, k := range scalars {
x, y := g1.ScalarBaseMult(k.Bytes())
x2, y2 := g1Generic.ScalarBaseMult(k.Bytes())
if x.Cmp(x2) != 0 || y.Cmp(y2) != 0 {
t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, x, y, x2, y2)
}
if testing.Short() && i > 5 {
break
}
}
}
func TestG1ScalarMult(t *testing.T) {
checkScalar := func(t *testing.T, scalar []byte) {
p1, err := (&G1{}).ScalarBaseMult(scalar)
fatalIfErr(t, err)
p2, err := (&G1{}).ScalarMult(Gen1, scalar)
fatalIfErr(t, err)
p1.p.MakeAffine()
p2.p.MakeAffine()
if !p1.Equal(p2) {
t.Error("[k]G != ScalarBaseMult(k)")
}
d := new(big.Int).SetBytes(scalar)
d.Sub(Order, d)
d.Mod(d, Order)
g1, err := (&G1{}).ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
fatalIfErr(t, err)
g1.Add(g1, p1)
g1.p.MakeAffine()
if !g1.p.IsInfinity() {
t.Error("[N - k]G + [k]G != ∞")
}
}
byteLen := len(Order.Bytes())
bitLen := Order.BitLen()
t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
t.Run("1", func(t *testing.T) {
checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
})
t.Run("N-6", func(t *testing.T) {
checkScalar(t, new(big.Int).Sub(Order, big.NewInt(6)).Bytes())
})
t.Run("N-1", func(t *testing.T) {
checkScalar(t, new(big.Int).Sub(Order, big.NewInt(1)).Bytes())
})
t.Run("N", func(t *testing.T) { checkScalar(t, Order.Bytes()) })
t.Run("N+1", func(t *testing.T) {
checkScalar(t, new(big.Int).Add(Order, big.NewInt(1)).Bytes())
})
t.Run("N+22", func(t *testing.T) {
checkScalar(t, new(big.Int).Add(Order, big.NewInt(22)).Bytes())
})
t.Run("all1s", func(t *testing.T) {
s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
s.Sub(s, big.NewInt(1))
checkScalar(t, s.Bytes())
})
if testing.Short() {
return
}
for i := 0; i < bitLen; i++ {
t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
s := new(big.Int).Lsh(big.NewInt(1), uint(i))
checkScalar(t, s.FillBytes(make([]byte, byteLen)))
})
}
for i := 0; i <= 64; i++ {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
})
}
// Test N-64...N+64 since they risk overlapping with precomputed table values
// in the final additions.
for i := int64(-64); i <= 64; i++ {
t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
checkScalar(t, new(big.Int).Add(Order, big.NewInt(i)).Bytes())
})
}
}
func fatalIfErr(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
func TestFuzz(t *testing.T) {
g1 := g1Curve
g1Generic := g1.Params()
var scalar1 [32]byte
var scalar2 [32]byte
var timeout *time.Timer
if testing.Short() {
timeout = time.NewTimer(10 * time.Millisecond)
} else {
timeout = time.NewTimer(2 * time.Second)
}
for {
select {
case <-timeout.C:
return
default:
}
io.ReadFull(rand.Reader, scalar1[:])
io.ReadFull(rand.Reader, scalar2[:])
x, y := g1.ScalarBaseMult(scalar1[:])
x2, y2 := g1Generic.ScalarBaseMult(scalar1[:])
xx, yy := g1.ScalarMult(x, y, scalar2[:])
xx2, yy2 := g1Generic.ScalarMult(x2, y2, scalar2[:])
if x.Cmp(x2) != 0 || y.Cmp(y2) != 0 {
t.Fatalf("ScalarBaseMult does not match reference result with scalar: %x, please report this error to https://github.com/emmansun/gmsm/issues", scalar1)
}
if xx.Cmp(xx2) != 0 || yy.Cmp(yy2) != 0 {
t.Fatalf("ScalarMult does not match reference result with scalars: %x and %x, please report this error to https://github.com/emmansun/gmsm/issues", scalar1, scalar2)
}
}
}
func TestG1OnCurve(t *testing.T) {
if !g1Curve.IsOnCurve(g1Curve.Params().Gx, g1Curve.Params().Gy) {
t.Error("basepoint is not on the curve")
}
}
func TestOffCurve(t *testing.T) {
x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
if g1Curve.IsOnCurve(x, y) {
t.Errorf("point off curve is claimed to be on the curve")
}
byteLen := (g1Curve.Params().BitSize + 7) / 8
b := make([]byte, 1+2*byteLen)
b[0] = 4 // uncompressed point
x.FillBytes(b[1 : 1+byteLen])
y.FillBytes(b[1+byteLen : 1+2*byteLen])
x1, y1 := Unmarshal(g1Curve, b)
if x1 != nil || y1 != nil {
t.Errorf("unmarshaling a point not on the curve succeeded")
}
}
func isInfinity(x, y *big.Int) bool {
return x.Sign() == 0 && y.Sign() == 0
}
func TestInfinity(t *testing.T) {
x0, y0 := new(big.Int), new(big.Int)
xG, yG := g1Curve.Params().Gx, g1Curve.Params().Gy
if !isInfinity(g1Curve.ScalarMult(xG, yG, g1Curve.Params().N.Bytes())) {
t.Errorf("x^q != ∞")
}
if !isInfinity(g1Curve.ScalarMult(xG, yG, []byte{0})) {
t.Errorf("x^0 != ∞")
}
if !isInfinity(g1Curve.ScalarMult(x0, y0, []byte{1, 2, 3})) {
t.Errorf("∞^k != ∞")
}
if !isInfinity(g1Curve.ScalarMult(x0, y0, []byte{0})) {
t.Errorf("∞^0 != ∞")
}
if !isInfinity(g1Curve.ScalarBaseMult(g1Curve.Params().N.Bytes())) {
t.Errorf("b^q != ∞")
}
if !isInfinity(g1Curve.ScalarBaseMult([]byte{0})) {
t.Errorf("b^0 != ∞")
}
if !isInfinity(g1Curve.Double(x0, y0)) {
t.Errorf("2∞ != ∞")
}
// There is no other point of order two on the NIST curves (as they have
// cofactor one), so Double can't otherwise return the point at infinity.
nMinusOne := new(big.Int).Sub(g1Curve.Params().N, big.NewInt(1))
x, y := g1Curve.ScalarMult(xG, yG, nMinusOne.Bytes())
x, y = g1Curve.Add(x, y, xG, yG)
if !isInfinity(x, y) {
t.Errorf("x^(q-1) + x != ∞")
}
x, y = g1Curve.Add(xG, yG, x0, y0)
if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
t.Errorf("x+∞ != x")
}
x, y = g1Curve.Add(x0, y0, xG, yG)
if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
t.Errorf("∞+x != x")
}
if !g1Curve.IsOnCurve(x0, y0) {
t.Errorf("IsOnCurve(∞) != true")
}
/*
if xx, yy := Unmarshal(g1Curve, Marshal(g1Curve, x0, y0)); xx == nil || yy == nil {
t.Errorf("Unmarshal(Marshal(∞)) did return an error")
}
// We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are
// two valid points with x = 0.
if xx, yy := Unmarshal(g1Curve, []byte{0x00}); xx != nil || yy != nil {
t.Errorf("Unmarshal(∞) did not return an error")
}
byteLen := (g1Curve.Params().BitSize + 7) / 8
buf := make([]byte, byteLen*2+1)
buf[0] = 4 // Uncompressed format.
if xx, yy := Unmarshal(g1Curve, buf); xx == nil || yy == nil {
t.Errorf("Unmarshal((0,0)) did return an error")
}
*/
}
func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
tests := []struct {
name string
curve Curve
}{
{"g1", g1Curve},
{"g1/Params", g1Curve.params},
}
for _, test := range tests {
curve := test.curve
t.Run(test.name, func(t *testing.T) {
t.Parallel()
f(t, curve)
})
}
}
func TestMarshal(t *testing.T) {
testAllCurves(t, func(t *testing.T, curve Curve) {
_, x, y, err := GenerateKey(curve, rand.Reader)
if err != nil {
t.Fatal(err)
}
serialized := Marshal(curve, x, y)
xx, yy := Unmarshal(curve, serialized)
if xx == nil {
t.Fatal("failed to unmarshal")
}
if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
t.Fatal("unmarshal returned different values")
}
})
}
func TestMarshalCompressed(t *testing.T) {
testAllCurves(t, func(t *testing.T, curve Curve) {
_, x, y, err := GenerateKey(curve, rand.Reader)
if err != nil {
t.Fatal(err)
}
testMarshalCompressed(t, curve, x, y, nil)
})
}
func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
if !curve.IsOnCurve(x, y) {
t.Fatal("invalid test point")
}
got := MarshalCompressed(curve, x, y)
if want != nil && !bytes.Equal(got, want) {
t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
}
X, Y := UnmarshalCompressed(curve, got)
if X == nil || Y == nil {
t.Fatalf("UnmarshalCompressed failed unexpectedly")
}
if !curve.IsOnCurve(X, Y) {
t.Error("UnmarshalCompressed returned a point not on the curve")
}
if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
}
}
func TestInvalidCoordinates(t *testing.T) {
checkIsOnCurveFalse := func(name string, x, y *big.Int) {
if g1Curve.IsOnCurve(x, y) {
t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
}
}
p := g1Curve.Params().P
_, x, y, _ := GenerateKey(g1Curve, rand.Reader)
xx, yy := new(big.Int), new(big.Int)
// Check if the sign is getting dropped.
xx.Neg(x)
checkIsOnCurveFalse("-x, y", xx, y)
yy.Neg(y)
checkIsOnCurveFalse("x, -y", x, yy)
// Check if negative values are reduced modulo P.
xx.Sub(x, p)
checkIsOnCurveFalse("x-P, y", xx, y)
yy.Sub(y, p)
checkIsOnCurveFalse("x, y-P", x, yy)
/*
// Check if positive values are reduced modulo P.
xx.Add(x, p)
checkIsOnCurveFalse("x+P, y", xx, y)
yy.Add(y, p)
checkIsOnCurveFalse("x, y+P", x, yy)
*/
// Check if the overflow is dropped.
xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
// Check if P is treated like zero (if possible).
// y^2 = x^3 + B
// y = mod_sqrt(x^3 + B)
// y = mod_sqrt(B) if x = 0
// If there is no modsqrt, there is no point with x = 0, can't test x = P.
if yy := new(big.Int).ModSqrt(g1Curve.Params().B, p); yy != nil {
if !g1Curve.IsOnCurve(big.NewInt(0), yy) {
t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
}
checkIsOnCurveFalse("P, y", p, yy)
}
}
func TestLargeIsOnCurve(t *testing.T) {
large := big.NewInt(1)
large.Lsh(large, 1000)
if g1Curve.IsOnCurve(large, large) {
t.Errorf("(2^1000, 2^1000) is reported on the curve")
}
}
func Test_G1MarshalCompressed(t *testing.T) {
e, e2 := &G1{}, &G1{}
ret := e.MarshalCompressed()
_, err := e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if !e2.p.IsInfinity() {
t.Errorf("not same")
}
e.p.Set(curveGen)
ret = e.MarshalCompressed()
_, err = e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z {
t.Errorf("not same")
}
e.p.Neg(e.p)
ret = e.MarshalCompressed()
_, err = e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z {
t.Errorf("not same")
}
}
func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) {
tests := []struct {
name string
curve Curve
}{
{"sm9", g1Curve},
{"sm9Parmas", g1Curve.Params()},
}
for _, test := range tests {
curve := test.curve
b.Run(test.name, func(b *testing.B) {
f(b, curve)
})
}
}
func BenchmarkScalarBaseMult(b *testing.B) {
benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
priv, _, _, _ := GenerateKey(curve, rand.Reader)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
x, _ := curve.ScalarBaseMult(priv)
// Prevent the compiler from optimizing out the operation.
priv[0] ^= byte(x.Bits()[0])
}
})
}
func BenchmarkScalarMult(b *testing.B) {
benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
_, x, y, _ := GenerateKey(curve, rand.Reader)
priv, _, _, _ := GenerateKey(curve, rand.Reader)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
x, y = curve.ScalarMult(x, y, priv)
}
})
}
func BenchmarkMarshalUnmarshal(b *testing.B) {
benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
_, x, y, _ := GenerateKey(curve, rand.Reader)
b.Run("Uncompressed", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
buf := Marshal(curve, x, y)
xx, yy := Unmarshal(curve, buf)
if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
b.Error("Unmarshal output different from Marshal input")
}
}
})
b.Run("Compressed", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
buf := MarshalCompressed(curve, x, y)
xx, yy := UnmarshalCompressed(curve, buf)
if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
b.Error("Unmarshal output different from Marshal input")
}
}
})
})
}
func BenchmarkAddPoint(b *testing.B) {
p1 := &curvePoint{}
curvePointDouble(p1, curveGen)

View File

@ -1,238 +0,0 @@
package bn256
import "math/big"
// CurveParams contains the parameters of an elliptic curve and also provides
// a generic, non-constant time implementation of Curve.
type CurveParams struct {
P *big.Int // the order of the underlying field
N *big.Int // the order of the base point
B *big.Int // the constant of the curve equation
Gx, Gy *big.Int // (x,y) of the base point
BitSize int // the size of the underlying field
Name string // the canonical name of the curve
}
func (curve *CurveParams) Params() *CurveParams {
return curve
}
// CurveParams operates, internally, on Jacobian coordinates. For a given
// (x, y) position on the curve, the Jacobian coordinates are (x1, y1, z1)
// where x = x1/z1² and y = y1/z1³. The greatest speedups come when the whole
// calculation can be performed within the transform (as in ScalarMult and
// ScalarBaseMult). But even for Add and Double, it's faster to apply and
// reverse the transform than to operate in affine coordinates.
// polynomial returns x³ + b.
func (curve *CurveParams) polynomial(x *big.Int) *big.Int {
x3 := new(big.Int).Mul(x, x)
x3.Mul(x3, x)
x3.Add(x3, curve.B)
x3.Mod(x3, curve.P)
return x3
}
func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
if x.Sign() < 0 || x.Cmp(curve.P) >= 0 ||
y.Sign() < 0 || y.Cmp(curve.P) >= 0 {
return false
}
// y² = x³ + b
y2 := new(big.Int).Mul(y, y)
y2.Mod(y2, curve.P)
return curve.polynomial(x).Cmp(y2) == 0
}
// zForAffine returns a Jacobian Z value for the affine point (x, y). If x and
// y are zero, it assumes that they represent the point at infinity because (0,
// 0) is not on the any of the curves handled here.
func zForAffine(x, y *big.Int) *big.Int {
z := new(big.Int)
if x.Sign() != 0 || y.Sign() != 0 {
z.SetInt64(1)
}
return z
}
// affineFromJacobian reverses the Jacobian transform. See the comment at the
// top of the file. If the point is ∞ it returns 0, 0.
func (curve *CurveParams) affineFromJacobian(x, y, z *big.Int) (xOut, yOut *big.Int) {
if z.Sign() == 0 {
return new(big.Int), new(big.Int)
}
zinv := new(big.Int).ModInverse(z, curve.P)
zinvsq := new(big.Int).Mul(zinv, zinv)
xOut = new(big.Int).Mul(x, zinvsq)
xOut.Mod(xOut, curve.P)
zinvsq.Mul(zinvsq, zinv)
yOut = new(big.Int).Mul(y, zinvsq)
yOut.Mod(yOut, curve.P)
return
}
func (curve *CurveParams) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
z1 := zForAffine(x1, y1)
z2 := zForAffine(x2, y2)
return curve.affineFromJacobian(curve.addJacobian(x1, y1, z1, x2, y2, z2))
}
// addJacobian takes two points in Jacobian coordinates, (x1, y1, z1) and
// (x2, y2, z2) and returns their sum, also in Jacobian form.
func (curve *CurveParams) addJacobian(x1, y1, z1, x2, y2, z2 *big.Int) (*big.Int, *big.Int, *big.Int) {
// See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-add-2007-bl
x3, y3, z3 := new(big.Int), new(big.Int), new(big.Int)
if z1.Sign() == 0 {
x3.Set(x2)
y3.Set(y2)
z3.Set(z2)
return x3, y3, z3
}
if z2.Sign() == 0 {
x3.Set(x1)
y3.Set(y1)
z3.Set(z1)
return x3, y3, z3
}
z1z1 := new(big.Int).Mul(z1, z1)
z1z1.Mod(z1z1, curve.P)
z2z2 := new(big.Int).Mul(z2, z2)
z2z2.Mod(z2z2, curve.P)
u1 := new(big.Int).Mul(x1, z2z2)
u1.Mod(u1, curve.P)
u2 := new(big.Int).Mul(x2, z1z1)
u2.Mod(u2, curve.P)
h := new(big.Int).Sub(u2, u1)
xEqual := h.Sign() == 0
if h.Sign() == -1 {
h.Add(h, curve.P)
}
i := new(big.Int).Lsh(h, 1)
i.Mul(i, i)
j := new(big.Int).Mul(h, i)
s1 := new(big.Int).Mul(y1, z2)
s1.Mul(s1, z2z2)
s1.Mod(s1, curve.P)
s2 := new(big.Int).Mul(y2, z1)
s2.Mul(s2, z1z1)
s2.Mod(s2, curve.P)
r := new(big.Int).Sub(s2, s1)
if r.Sign() == -1 {
r.Add(r, curve.P)
}
yEqual := r.Sign() == 0
if xEqual && yEqual {
return curve.doubleJacobian(x1, y1, z1)
}
r.Lsh(r, 1)
v := new(big.Int).Mul(u1, i)
x3.Set(r)
x3.Mul(x3, x3)
x3.Sub(x3, j)
x3.Sub(x3, v)
x3.Sub(x3, v)
x3.Mod(x3, curve.P)
y3.Set(r)
v.Sub(v, x3)
y3.Mul(y3, v)
s1.Mul(s1, j)
s1.Lsh(s1, 1)
y3.Sub(y3, s1)
y3.Mod(y3, curve.P)
z3.Add(z1, z2)
z3.Mul(z3, z3)
z3.Sub(z3, z1z1)
z3.Sub(z3, z2z2)
z3.Mul(z3, h)
z3.Mod(z3, curve.P)
return x3, y3, z3
}
func (curve *CurveParams) Double(x1, y1 *big.Int) (*big.Int, *big.Int) {
z1 := zForAffine(x1, y1)
return curve.affineFromJacobian(curve.doubleJacobian(x1, y1, z1))
}
// doubleJacobian takes a point in Jacobian coordinates, (x, y, z), and
// returns its double, also in Jacobian form.
func (curve *CurveParams) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int, *big.Int) {
// See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3
a := new(big.Int).Mul(x, x)
a.Mod(a, curve.P)
b := new(big.Int).Mul(y, y)
b.Mod(b, curve.P)
c := new(big.Int).Mul(b, b)
c.Mod(c, curve.P)
d := new(big.Int).Add(x, b)
d.Mul(d, d)
d.Sub(d, a)
d.Sub(d, c)
d.Lsh(d, 1)
if d.Sign() < 0 {
d.Add(d, curve.P)
} else {
d.Mod(d, curve.P)
}
e := new(big.Int).Lsh(a, 1)
e.Add(e, a)
f := new(big.Int).Mul(e, e)
x3 := new(big.Int).Lsh(d, 1)
x3.Sub(f, x3)
if x3.Sign() < 0 {
x3.Add(x3, curve.P)
} else {
x3.Mod(x3, curve.P)
}
y3 := new(big.Int).Sub(d, x3)
y3.Mul(y3, e)
c.Lsh(c, 3)
y3.Sub(y3, c)
if y3.Sign() < 0 {
y3.Add(y3, curve.P)
} else {
y3.Mod(y3, curve.P)
}
z3 := new(big.Int).Mul(y, z)
z3.Lsh(z3, 1)
z3.Mod(z3, curve.P)
return x3, y3, z3
}
func (curve *CurveParams) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) {
Bz := new(big.Int).SetInt64(1)
x, y, z := new(big.Int), new(big.Int), new(big.Int)
for _, byte := range k {
for bitNum := 0; bitNum < 8; bitNum++ {
x, y, z = curve.doubleJacobian(x, y, z)
if byte&0x80 == 0x80 {
x, y, z = curve.addJacobian(Bx, By, Bz, x, y, z)
}
byte <<= 1
}
}
return curve.affineFromJacobian(x, y, z)
}
func (curve *CurveParams) ScalarBaseMult(k []byte) (*big.Int, *big.Int) {
return curve.ScalarMult(curve.Gx, curve.Gy, k)
}

View File

@ -1,170 +0,0 @@
package bn256
import (
"encoding/hex"
"fmt"
"math/big"
"testing"
)
var secp256k1Params = &CurveParams{
Name: "secp256k1",
BitSize: 256,
P: bigFromHex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"),
N: bigFromHex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"),
B: bigFromHex("0000000000000000000000000000000000000000000000000000000000000007"),
Gx: bigFromHex("79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"),
Gy: bigFromHex("483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8"),
}
var sm9CurveParams = &CurveParams{
Name: "sm9",
BitSize: 256,
P: bigFromHex("B640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457D"),
N: bigFromHex("B640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25"),
B: bigFromHex("0000000000000000000000000000000000000000000000000000000000000005"),
Gx: bigFromHex("93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD"),
Gy: bigFromHex("21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616"),
}
type baseMultTest struct {
k string
x, y string
}
var s256BaseMultTests = []baseMultTest{
{
"AA5E28D6A97A2479A65527F7290311A3624D4CC0FA1578598EE3C2613BF99522",
"34F9460F0E4F08393D192B3C5133A6BA099AA0AD9FD54EBCCFACDFA239FF49C6",
"B71EA9BD730FD8923F6D25A7A91E7DD7728A960686CB5A901BB419E0F2CA232",
},
{
"7E2B897B8CEBC6361663AD410835639826D590F393D90A9538881735256DFAE3",
"D74BF844B0862475103D96A611CF2D898447E288D34B360BC885CB8CE7C00575",
"131C670D414C4546B88AC3FF664611B1C38CEB1C21D76369D7A7A0969D61D97D",
},
{
"6461E6DF0FE7DFD05329F41BF771B86578143D4DD1F7866FB4CA7E97C5FA945D",
"E8AECC370AEDD953483719A116711963CE201AC3EB21D3F3257BB48668C6A72F",
"C25CAF2F0EBA1DDB2F0F3F47866299EF907867B7D27E95B3873BF98397B24EE1",
},
{
"376A3A2CDCD12581EFFF13EE4AD44C4044B8A0524C42422A7E1E181E4DEECCEC",
"14890E61FCD4B0BD92E5B36C81372CA6FED471EF3AA60A3E415EE4FE987DABA1",
"297B858D9F752AB42D3BCA67EE0EB6DCD1C2B7B0DBE23397E66ADC272263F982",
},
{
"1B22644A7BE026548810C378D0B2994EEFA6D2B9881803CB02CEFF865287D1B9",
"F73C65EAD01C5126F28F442D087689BFA08E12763E0CEC1D35B01751FD735ED3",
"F449A8376906482A84ED01479BD18882B919C140D638307F0C0934BA12590BDE",
},
}
func TestBaseMult(t *testing.T) {
for i, e := range s256BaseMultTests {
k, ok := new(big.Int).SetString(e.k, 16)
if !ok {
t.Errorf("%d: bad value for k: %s", i, e.k)
}
x, y := secp256k1Params.ScalarBaseMult(k.Bytes())
if fmt.Sprintf("%X", x) != e.x || fmt.Sprintf("%X", y) != e.y {
t.Errorf("%d: bad output for k=%s: got (%X, %X), want (%s, %s)", i, e.k, x, y, e.x, e.y)
}
}
}
func TestOnCurve(t *testing.T) {
if !secp256k1Params.IsOnCurve(secp256k1Params.Gx, secp256k1Params.Gy) {
t.Errorf("point is not on curve")
}
if !sm9CurveParams.IsOnCurve(sm9CurveParams.Gx, sm9CurveParams.Gy) {
t.Errorf("point is not on curve")
}
}
func TestPMode4And8(t *testing.T) {
res := new(big.Int).Mod(sm9CurveParams.P, big.NewInt(4))
if res.Int64() != 1 {
t.Errorf("p mod 4 != 1")
}
res = new(big.Int).Mod(sm9CurveParams.P, big.NewInt(6))
if res.Int64() != 1 {
t.Errorf("p mod 6 != 1")
}
res = new(big.Int).Mod(sm9CurveParams.P, big.NewInt(8))
if res.Int64() != 5 {
t.Errorf("p mod 8 != 5")
}
res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(1))
res.Div(res, big.NewInt(2))
if hex.EncodeToString(res.Bytes()) != "5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2be" {
t.Errorf("expected %v, got %v\n", "5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2be", hex.EncodeToString(res.Bytes()))
}
res = new(big.Int).Add(sm9CurveParams.P, big.NewInt(1))
res.Div(res, big.NewInt(2))
if hex.EncodeToString(res.Bytes()) != "5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2bf" {
t.Errorf("expected %v, got %v\n", "5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2bf", hex.EncodeToString(res.Bytes()))
}
res = new(big.Int).Add(sm9CurveParams.P, big.NewInt(1))
res.Div(res, big.NewInt(3))
if hex.EncodeToString(res.Bytes()) != "3cc0000000e137a5f201391aa72f97c1b5fb866e5e28fa494c7a890d4bc5c1d4" {
t.Errorf("expected %v, got %v\n", "3cc0000000e137a5f201391aa72f97c1b5fb866e5e28fa494c7a890d4bc5c1d4", hex.EncodeToString(res.Bytes()))
}
res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(1))
res.Div(res, big.NewInt(4))
if hex.EncodeToString(res.Bytes()) != "2d90000000a8e9bc7580ead3fd63b1d1487ca4d2c69ebbb6f95be6c9f8d4515f" {
t.Errorf("expected %v, got %v\n", "2d90000000a8e9bc7580ead3fd63b1d1487ca4d2c69ebbb6f95be6c9f8d4515f", hex.EncodeToString(res.Bytes()))
}
res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(1))
res.Div(res, big.NewInt(6))
if hex.EncodeToString(res.Bytes()) != "1e60000000709bd2f9009c8d5397cbe0dafdc3372f147d24a63d4486a5e2e0ea" {
t.Errorf("expected %v, got %v\n", "1e60000000709bd2f9009c8d5397cbe0dafdc3372f147d24a63d4486a5e2e0ea", hex.EncodeToString(res.Bytes()))
}
res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(1))
res.Div(res, big.NewInt(3))
if hex.EncodeToString(res.Bytes()) != "3cc0000000e137a5f201391aa72f97c1b5fb866e5e28fa494c7a890d4bc5c1d4" {
t.Errorf("expected %v, got %v\n", "3cc0000000e137a5f201391aa72f97c1b5fb866e5e28fa494c7a890d4bc5c1d4", hex.EncodeToString(res.Bytes()))
}
res = new(big.Int).Mul(sm9CurveParams.P, sm9CurveParams.P)
res.Sub(res, big.NewInt(1))
res.Div(res, big.NewInt(3))
if hex.EncodeToString(res.Bytes()) != "2b3fb0000140abbbc71510370c6fa2b194d4665ff95c18014568b07bbd19fb54f0b9aded6fea5b670c35d6b4e3b966415456a4a8503c6361c90d41b4e8a78a58" {
t.Errorf("expected %v, got %v\n", "2b3fb0000140abbbc71510370c6fa2b194d4665ff95c18014568b07bbd19fb54f0b9aded6fea5b670c35d6b4e3b966415456a4a8503c6361c90d41b4e8a78a58", hex.EncodeToString(res.Bytes()))
}
res = new(big.Int).Mul(sm9CurveParams.P, sm9CurveParams.P)
res.Sub(res, big.NewInt(1))
res.Div(res, big.NewInt(2))
if hex.EncodeToString(res.Bytes()) != "40df880001e10199aa9f985292a7740a5f3e998ff60a2401e81d08b99ba6f8ff691684e427df891a9250c20f55961961fe81f6fc785a9512ad93e28f5cfb4f84" {
t.Errorf("expected %v, got %v\n", "40df880001e10199aa9f985292a7740a5f3e998ff60a2401e81d08b99ba6f8ff691684e427df891a9250c20f55961961fe81f6fc785a9512ad93e28f5cfb4f84", hex.EncodeToString(res.Bytes()))
}
res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(5))
res.Div(res, big.NewInt(8))
if hex.EncodeToString(res.Bytes()) != "16c80000005474de3ac07569feb1d8e8a43e5269634f5ddb7cadf364fc6a28af" {
t.Errorf("expected %v, got %v\n", "16c80000005474de3ac07569feb1d8e8a43e5269634f5ddb7cadf364fc6a28af", hex.EncodeToString(res.Bytes()))
}
res.Exp(big.NewInt(2), res, sm9CurveParams.P)
if hex.EncodeToString(res.Bytes()) != "800db90d149e875b5b564505fe88efba5223f2bf170cc61fea968b3df63edd75" {
t.Errorf("expected %v, got %v\n", "800db90d149e875b5b564505fe88efba5223f2bf170cc61fea968b3df63edd75", hex.EncodeToString(res.Bytes()))
}
res.Mul(u, big.NewInt(6))
res.Add(res, big.NewInt(5))
if hex.EncodeToString(res.Bytes()) != "02400000000215d941" {
t.Errorf("expected %v, got %v\n", "02400000000215d941", hex.EncodeToString(res.Bytes()))
}
res.Mul(u, big.NewInt(6))
res.Mul(res, u)
res.Add(res, big.NewInt(1))
if hex.EncodeToString(res.Bytes()) != "d8000000019062ed0000b98b0cb27659" {
t.Errorf("expected %v, got %v\n", "d8000000019062ed0000b98b0cb27659", hex.EncodeToString(res.Bytes()))
}
}