mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
internal/bigmod: add more //go:norace annotations and refactoring
This commit is contained in:
parent
0d56114869
commit
865159d86a
@ -18,6 +18,15 @@ const (
|
||||
_S = _W / 8
|
||||
)
|
||||
|
||||
// Note: These functions make many loops over all the words in a Nat.
|
||||
// These loops used to be in assembly, invisible to -race, -asan, and -msan,
|
||||
// but now they are in Go and incur significant overhead in those modes.
|
||||
// To bring the old performance back, we mark all functions that loop
|
||||
// over Nat words with //go:norace. Because //go:norace does not
|
||||
// propagate across inlining, we must also mark functions that inline
|
||||
// //go:norace functions - specifically, those that inline add, addMulVVW,
|
||||
// assign, cmpGeq, rshift1, and sub.
|
||||
|
||||
// choice represents a constant-time boolean. The value of choice is always
|
||||
// either 1 or 0. We use an int instead of bool in order to make decisions in
|
||||
// constant time by turning it into a mask.
|
||||
@ -40,14 +49,6 @@ func ctEq(x, y uint) choice {
|
||||
return not(choice(c1 | c2))
|
||||
}
|
||||
|
||||
// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
|
||||
// function does not depend on its inputs.
|
||||
func ctGeq(x, y uint) choice {
|
||||
// If x < y, then x - y generates a carry.
|
||||
_, carry := bits.Sub(x, y, 0)
|
||||
return not(choice(carry))
|
||||
}
|
||||
|
||||
// Nat represents an arbitrary natural number
|
||||
//
|
||||
// Each Nat has an announced length, which is the number of limbs it has stored.
|
||||
@ -84,6 +85,7 @@ func (x *Nat) expand(n int) *Nat {
|
||||
return x
|
||||
}
|
||||
extraLimbs := x.limbs[len(x.limbs):n]
|
||||
// clear(extraLimbs)
|
||||
for i := range extraLimbs {
|
||||
extraLimbs[i] = 0
|
||||
}
|
||||
@ -97,6 +99,7 @@ func (x *Nat) reset(n int) *Nat {
|
||||
x.limbs = make([]uint, n)
|
||||
return x
|
||||
}
|
||||
// clear(x.limbs)
|
||||
for i := range x.limbs {
|
||||
x.limbs[i] = 0
|
||||
}
|
||||
@ -131,7 +134,7 @@ func (x *Nat) trim() *Nat {
|
||||
}
|
||||
|
||||
// set assigns x = y, optionally resizing x to the appropriate size.
|
||||
func (x *Nat) Set(y *Nat) *Nat {
|
||||
func (x *Nat) set(y *Nat) *Nat {
|
||||
x.reset(len(y.limbs))
|
||||
copy(x.limbs, y.limbs)
|
||||
return x
|
||||
@ -164,12 +167,14 @@ func (x *Nat) Bytes(m *Modulus) []byte {
|
||||
// SetBytes returns an error if b >= m.
|
||||
//
|
||||
// The output will be resized to the size of m and overwritten.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
|
||||
x.resetFor(m)
|
||||
if err := x.setBytes(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if x.CmpGeq(m.nat) == yes {
|
||||
if x.cmpGeq(m.nat) == yes {
|
||||
return nil, errors.New("input overflows the modulus")
|
||||
}
|
||||
return x, nil
|
||||
@ -195,20 +200,6 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
|
||||
//
|
||||
// The output will be resized to the size of m and overwritten.
|
||||
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
|
||||
mMinusOne := NewNat().Set(m.nat)
|
||||
mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1
|
||||
one := NewNat().resetFor(m)
|
||||
one.limbs[0] = 1
|
||||
x.resetToBytes(b)
|
||||
x = NewNat().modNat(x, mMinusOne) // x = x mod (m-1)
|
||||
x.add(one) // we can safely add 1, no need to check overflow
|
||||
return x
|
||||
}
|
||||
|
||||
// bigEndianUint returns the contents of buf interpreted as a
|
||||
// big-endian encoded uint value.
|
||||
func bigEndianUint(buf []byte) uint {
|
||||
@ -309,8 +300,6 @@ func (x *Nat) IsMinusOne(m *Modulus) choice {
|
||||
}
|
||||
|
||||
// IsOdd returns 1 if x is odd, and 0 otherwise.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) IsOdd() choice {
|
||||
if len(x.limbs) == 0 {
|
||||
return no
|
||||
@ -333,12 +322,12 @@ func (x *Nat) TrailingZeroBitsVarTime() uint {
|
||||
return t
|
||||
}
|
||||
|
||||
// CmpGeq returns 1 if x >= y, and 0 otherwise.
|
||||
// cmpGeq returns 1 if x >= y, and 0 otherwise.
|
||||
//
|
||||
// Both operands must have the same announced length.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) CmpGeq(y *Nat) choice {
|
||||
func (x *Nat) cmpGeq(y *Nat) choice {
|
||||
// Eliminate bounds checks in the loop.
|
||||
size := len(x.limbs)
|
||||
xLimbs := x.limbs[:size]
|
||||
@ -564,6 +553,8 @@ func NewModulus(b []byte) (*Modulus, error) {
|
||||
|
||||
// NewModulusProduct creates a new Modulus from the product of two numbers
|
||||
// represented as big-endian byte slices. The result must be greater than one.
|
||||
//
|
||||
//go:norace
|
||||
func NewModulusProduct(a, b []byte) (*Modulus, error) {
|
||||
x := NewNat().resetToBytes(a)
|
||||
y := NewNat().resetToBytes(b)
|
||||
@ -602,30 +593,23 @@ func (m *Modulus) Nat() *Nat {
|
||||
// Make a copy so that the caller can't modify m.nat or alias it with
|
||||
// another Nat in a modulus operation.
|
||||
n := NewNat()
|
||||
n.Set(m.nat)
|
||||
n.set(m.nat)
|
||||
return n
|
||||
}
|
||||
|
||||
// shiftIn calculates x = x << _W + y mod m.
|
||||
//
|
||||
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
||||
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
||||
return x.shiftInNat(y, m.nat)
|
||||
}
|
||||
|
||||
// shiftIn calculates x = x << _W + y mod m.
|
||||
//
|
||||
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
|
||||
d := NewNat().reset(len(m.limbs))
|
||||
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
||||
d := NewNat().resetFor(m)
|
||||
|
||||
// Eliminate bounds checks in the loop.
|
||||
size := len(m.limbs)
|
||||
size := len(m.nat.limbs)
|
||||
xLimbs := x.limbs[:size]
|
||||
dLimbs := d.limbs[:size]
|
||||
mLimbs := m.limbs[:size]
|
||||
mLimbs := m.nat.limbs[:size]
|
||||
|
||||
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
|
||||
// from y. Effectively, it left-shifts x and adds y one bit at a time,
|
||||
@ -657,17 +641,10 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
|
||||
// This works regardless how large the value of x is.
|
||||
//
|
||||
// The output will be resized to the size of m and overwritten.
|
||||
//
|
||||
//go:norace
|
||||
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
||||
return out.modNat(x, m.nat)
|
||||
}
|
||||
|
||||
// Mod calculates out = x mod m.
|
||||
//
|
||||
// This works regardless how large the value of x is.
|
||||
//
|
||||
// The output will be resized to the size of m and overwritten.
|
||||
func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
|
||||
out.reset(len(m.limbs))
|
||||
out.resetFor(m)
|
||||
// Working our way from the most significant to the least significant limb,
|
||||
// we can insert each limb at the least significant position, shifting all
|
||||
// previous limbs left by _W. This way each limb will get shifted by the
|
||||
@ -676,7 +653,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
|
||||
i := len(x.limbs) - 1
|
||||
// For the first N - 1 limbs we can skip the actual shifting and position
|
||||
// them at the shifted position, which starts at min(N - 2, i).
|
||||
start := len(m.limbs) - 2
|
||||
start := len(m.nat.limbs) - 2
|
||||
if i < start {
|
||||
start = i
|
||||
}
|
||||
@ -686,7 +663,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
|
||||
}
|
||||
// We shift in the remaining limbs, reducing modulo m each time.
|
||||
for i >= 0 {
|
||||
out.shiftInNat(x.limbs[i], m)
|
||||
out.shiftIn(x.limbs[i], m)
|
||||
i--
|
||||
}
|
||||
return out
|
||||
@ -715,8 +692,10 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
|
||||
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
|
||||
//
|
||||
// x and m operands must have the same announced length.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
|
||||
t := NewNat().Set(x)
|
||||
t := NewNat().set(x)
|
||||
underflow := t.sub(m.nat)
|
||||
// We keep the result if x - m didn't underflow (meaning x >= m)
|
||||
// or if always was set.
|
||||
@ -728,10 +707,12 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
|
||||
//
|
||||
// The length of both operands must be the same as the modulus. Both operands
|
||||
// must already be reduced modulo m.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
|
||||
underflow := x.sub(y)
|
||||
// If the subtraction underflowed, add m.
|
||||
t := NewNat().Set(x)
|
||||
t := NewNat().set(x)
|
||||
t.add(m.nat)
|
||||
x.assign(choice(underflow), t)
|
||||
return x
|
||||
@ -752,6 +733,8 @@ func (x *Nat) SubOne(m *Modulus) *Nat {
|
||||
//
|
||||
// The length of both operands must be the same as the modulus. Both operands
|
||||
// must already be reduced modulo m.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
|
||||
overflow := x.add(y)
|
||||
x.maybeSubtractModulus(choice(overflow), m)
|
||||
@ -789,6 +772,8 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
|
||||
//
|
||||
// All inputs should be the same length and already reduced modulo m.
|
||||
// x will be resized to the size of m and overwritten.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
|
||||
n := len(m.nat.limbs)
|
||||
mLimbs := m.nat.limbs[:n]
|
||||
@ -946,11 +931,13 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
|
||||
//
|
||||
// The length of both operands must be the same as the modulus. Both operands
|
||||
// must already be reduced modulo m.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
||||
if m.odd {
|
||||
// A Montgomery multiplication by a value out of the Montgomery domain
|
||||
// takes the result out of Montgomery representation.
|
||||
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
||||
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
||||
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
||||
}
|
||||
n := len(m.nat.limbs)
|
||||
@ -1009,6 +996,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
||||
// to the size of m and overwritten. x must already be reduced modulo m.
|
||||
//
|
||||
// m must be odd, or Exp will panic.
|
||||
//
|
||||
//go:norace
|
||||
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
||||
if !m.odd {
|
||||
panic("bigmod: modulus for Exp must be odd")
|
||||
@ -1025,7 +1014,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
||||
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
||||
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
||||
}
|
||||
table[0].Set(x).montgomeryRepresentation(m)
|
||||
table[0].set(x).montgomeryRepresentation(m)
|
||||
for i := 1; i < len(table); i++ {
|
||||
table[i].montgomeryMul(table[i-1], table[0], m)
|
||||
}
|
||||
@ -1071,8 +1060,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
|
||||
// For short exponents, precomputing a table and using a window like in Exp
|
||||
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
|
||||
// chain, skipping the initial run of zeroes.
|
||||
xR := NewNat().Set(x).montgomeryRepresentation(m)
|
||||
out.Set(xR)
|
||||
xR := NewNat().set(x).montgomeryRepresentation(m)
|
||||
out.set(xR)
|
||||
for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ {
|
||||
out.montgomeryMul(out, out, m)
|
||||
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
|
||||
@ -1088,6 +1077,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
|
||||
//
|
||||
// a must be reduced modulo m, but doesn't need to have the same size. The
|
||||
// output will be resized to the size of m and overwritten.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
||||
// This is the extended binary GCD algorithm described in the Handbook of
|
||||
// Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound
|
||||
@ -1121,7 +1112,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
||||
return x, false
|
||||
}
|
||||
|
||||
u := NewNat().Set(a).ExpandFor(m)
|
||||
u := NewNat().set(a).ExpandFor(m)
|
||||
v := m.Nat()
|
||||
|
||||
A := NewNat().reset(len(m.nat.limbs))
|
||||
@ -1148,7 +1139,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
||||
// If both u and v are odd, subtract the smaller from the larger.
|
||||
// If u = v, we need to subtract from v to hit the modified exit condition.
|
||||
if u.IsOdd() == yes && v.IsOdd() == yes {
|
||||
if v.CmpGeq(u) == no {
|
||||
if v.cmpGeq(u) == no {
|
||||
u.sub(v)
|
||||
A.Add(C, m)
|
||||
B.Add(D, &Modulus{nat: a})
|
||||
@ -1189,7 +1180,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
||||
if u.IsOne() == no {
|
||||
return x, false
|
||||
}
|
||||
return x.Set(A), true
|
||||
return x.set(A), true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
31
internal/bigmod/nat_extension.go
Normal file
31
internal/bigmod/nat_extension.go
Normal file
@ -0,0 +1,31 @@
|
||||
package bigmod
|
||||
|
||||
func (x *Nat) Set(y *Nat) *Nat {
|
||||
return x.set(y)
|
||||
}
|
||||
|
||||
// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
|
||||
//
|
||||
// The output will be resized to the size of m and overwritten.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
|
||||
mMinusOne := NewNat().set(m.nat)
|
||||
mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1
|
||||
mMinusOneM, _ := NewModulus(mMinusOne.Bytes(m))
|
||||
one := NewNat().resetFor(m)
|
||||
one.limbs[0] = 1
|
||||
x.resetToBytes(b)
|
||||
x = NewNat().Mod(x, mMinusOneM) // x = x mod (m-1)
|
||||
x.add(one) // we can safely add 1, no need to check overflow
|
||||
return x
|
||||
}
|
||||
|
||||
// CmpGeq returns 1 if x >= y, and 0 otherwise.
|
||||
//
|
||||
// Both operands must have the same announced length.
|
||||
//
|
||||
//go:norace
|
||||
func (x *Nat) CmpGeq(y *Nat) choice {
|
||||
return x.cmpGeq(y)
|
||||
}
|
@ -61,9 +61,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
|
||||
|
||||
func testModAddCommutative(a *Nat, b *Nat) bool {
|
||||
m := maxModulus(uint(len(a.limbs)))
|
||||
aPlusB := new(Nat).Set(a)
|
||||
aPlusB := new(Nat).set(a)
|
||||
aPlusB.Add(b, m)
|
||||
bPlusA := new(Nat).Set(b)
|
||||
bPlusA := new(Nat).set(b)
|
||||
bPlusA.Add(a, m)
|
||||
return aPlusB.Equal(bPlusA) == 1
|
||||
}
|
||||
@ -77,7 +77,7 @@ func TestModAddCommutative(t *testing.T) {
|
||||
|
||||
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
|
||||
m := maxModulus(uint(len(a.limbs)))
|
||||
original := new(Nat).Set(a)
|
||||
original := new(Nat).set(a)
|
||||
a.Sub(b, m)
|
||||
a.Add(b, m)
|
||||
return a.Equal(original) == 1
|
||||
@ -97,9 +97,9 @@ func TestMontgomeryRoundtrip(t *testing.T) {
|
||||
aPlusOne := new(big.Int).SetBytes(natBytes(a))
|
||||
aPlusOne.Add(aPlusOne, big.NewInt(1))
|
||||
m, _ := NewModulus(aPlusOne.Bytes())
|
||||
monty := new(Nat).Set(a)
|
||||
monty := new(Nat).set(a)
|
||||
monty.montgomeryRepresentation(m)
|
||||
aAgain := new(Nat).Set(monty)
|
||||
aAgain := new(Nat).set(monty)
|
||||
aAgain.montgomeryMul(monty, one, m)
|
||||
if a.Equal(aAgain) != 1 {
|
||||
t.Errorf("%v != %v", a, aAgain)
|
||||
|
Loading…
x
Reference in New Issue
Block a user