internal/bigmod: add more //go:norace annotations and refactoring

This commit is contained in:
Sun Yimin 2024-12-06 08:54:47 +08:00 committed by GitHub
parent 0d56114869
commit 865159d86a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 88 additions and 66 deletions

View File

@ -18,6 +18,15 @@ const (
_S = _W / 8 _S = _W / 8
) )
// Note: These functions make many loops over all the words in a Nat.
// These loops used to be in assembly, invisible to -race, -asan, and -msan,
// but now they are in Go and incur significant overhead in those modes.
// To bring the old performance back, we mark all functions that loop
// over Nat words with //go:norace. Because //go:norace does not
// propagate across inlining, we must also mark functions that inline
// //go:norace functions - specifically, those that inline add, addMulVVW,
// assign, cmpGeq, rshift1, and sub.
// choice represents a constant-time boolean. The value of choice is always // choice represents a constant-time boolean. The value of choice is always
// either 1 or 0. We use an int instead of bool in order to make decisions in // either 1 or 0. We use an int instead of bool in order to make decisions in
// constant time by turning it into a mask. // constant time by turning it into a mask.
@ -40,14 +49,6 @@ func ctEq(x, y uint) choice {
return not(choice(c1 | c2)) return not(choice(c1 | c2))
} }
// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
func ctGeq(x, y uint) choice {
// If x < y, then x - y generates a carry.
_, carry := bits.Sub(x, y, 0)
return not(choice(carry))
}
// Nat represents an arbitrary natural number // Nat represents an arbitrary natural number
// //
// Each Nat has an announced length, which is the number of limbs it has stored. // Each Nat has an announced length, which is the number of limbs it has stored.
@ -84,6 +85,7 @@ func (x *Nat) expand(n int) *Nat {
return x return x
} }
extraLimbs := x.limbs[len(x.limbs):n] extraLimbs := x.limbs[len(x.limbs):n]
// clear(extraLimbs)
for i := range extraLimbs { for i := range extraLimbs {
extraLimbs[i] = 0 extraLimbs[i] = 0
} }
@ -97,6 +99,7 @@ func (x *Nat) reset(n int) *Nat {
x.limbs = make([]uint, n) x.limbs = make([]uint, n)
return x return x
} }
// clear(x.limbs)
for i := range x.limbs { for i := range x.limbs {
x.limbs[i] = 0 x.limbs[i] = 0
} }
@ -131,7 +134,7 @@ func (x *Nat) trim() *Nat {
} }
// set assigns x = y, optionally resizing x to the appropriate size. // set assigns x = y, optionally resizing x to the appropriate size.
func (x *Nat) Set(y *Nat) *Nat { func (x *Nat) set(y *Nat) *Nat {
x.reset(len(y.limbs)) x.reset(len(y.limbs))
copy(x.limbs, y.limbs) copy(x.limbs, y.limbs)
return x return x
@ -164,12 +167,14 @@ func (x *Nat) Bytes(m *Modulus) []byte {
// SetBytes returns an error if b >= m. // SetBytes returns an error if b >= m.
// //
// The output will be resized to the size of m and overwritten. // The output will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
x.resetFor(m) x.resetFor(m)
if err := x.setBytes(b); err != nil { if err := x.setBytes(b); err != nil {
return nil, err return nil, err
} }
if x.CmpGeq(m.nat) == yes { if x.cmpGeq(m.nat) == yes {
return nil, errors.New("input overflows the modulus") return nil, errors.New("input overflows the modulus")
} }
return x, nil return x, nil
@ -195,20 +200,6 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
return x, nil return x, nil
} }
// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
//
// The output will be resized to the size of m and overwritten.
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
mMinusOne := NewNat().Set(m.nat)
mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1
one := NewNat().resetFor(m)
one.limbs[0] = 1
x.resetToBytes(b)
x = NewNat().modNat(x, mMinusOne) // x = x mod (m-1)
x.add(one) // we can safely add 1, no need to check overflow
return x
}
// bigEndianUint returns the contents of buf interpreted as a // bigEndianUint returns the contents of buf interpreted as a
// big-endian encoded uint value. // big-endian encoded uint value.
func bigEndianUint(buf []byte) uint { func bigEndianUint(buf []byte) uint {
@ -309,8 +300,6 @@ func (x *Nat) IsMinusOne(m *Modulus) choice {
} }
// IsOdd returns 1 if x is odd, and 0 otherwise. // IsOdd returns 1 if x is odd, and 0 otherwise.
//
//go:norace
func (x *Nat) IsOdd() choice { func (x *Nat) IsOdd() choice {
if len(x.limbs) == 0 { if len(x.limbs) == 0 {
return no return no
@ -333,12 +322,12 @@ func (x *Nat) TrailingZeroBitsVarTime() uint {
return t return t
} }
// CmpGeq returns 1 if x >= y, and 0 otherwise. // cmpGeq returns 1 if x >= y, and 0 otherwise.
// //
// Both operands must have the same announced length. // Both operands must have the same announced length.
// //
//go:norace //go:norace
func (x *Nat) CmpGeq(y *Nat) choice { func (x *Nat) cmpGeq(y *Nat) choice {
// Eliminate bounds checks in the loop. // Eliminate bounds checks in the loop.
size := len(x.limbs) size := len(x.limbs)
xLimbs := x.limbs[:size] xLimbs := x.limbs[:size]
@ -564,6 +553,8 @@ func NewModulus(b []byte) (*Modulus, error) {
// NewModulusProduct creates a new Modulus from the product of two numbers // NewModulusProduct creates a new Modulus from the product of two numbers
// represented as big-endian byte slices. The result must be greater than one. // represented as big-endian byte slices. The result must be greater than one.
//
//go:norace
func NewModulusProduct(a, b []byte) (*Modulus, error) { func NewModulusProduct(a, b []byte) (*Modulus, error) {
x := NewNat().resetToBytes(a) x := NewNat().resetToBytes(a)
y := NewNat().resetToBytes(b) y := NewNat().resetToBytes(b)
@ -602,30 +593,23 @@ func (m *Modulus) Nat() *Nat {
// Make a copy so that the caller can't modify m.nat or alias it with // Make a copy so that the caller can't modify m.nat or alias it with
// another Nat in a modulus operation. // another Nat in a modulus operation.
n := NewNat() n := NewNat()
n.Set(m.nat) n.set(m.nat)
return n return n
} }
// shiftIn calculates x = x << _W + y mod m.
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
return x.shiftInNat(y, m.nat)
}
// shiftIn calculates x = x << _W + y mod m. // shiftIn calculates x = x << _W + y mod m.
// //
// This assumes that x is already reduced mod m, and that y < 2^_W. // This assumes that x is already reduced mod m, and that y < 2^_W.
// //
//go:norace //go:norace
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat { func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
d := NewNat().reset(len(m.limbs)) d := NewNat().resetFor(m)
// Eliminate bounds checks in the loop. // Eliminate bounds checks in the loop.
size := len(m.limbs) size := len(m.nat.limbs)
xLimbs := x.limbs[:size] xLimbs := x.limbs[:size]
dLimbs := d.limbs[:size] dLimbs := d.limbs[:size]
mLimbs := m.limbs[:size] mLimbs := m.nat.limbs[:size]
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit // Each iteration of this loop computes x = 2x + b mod m, where b is a bit
// from y. Effectively, it left-shifts x and adds y one bit at a time, // from y. Effectively, it left-shifts x and adds y one bit at a time,
@ -657,17 +641,10 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
// This works regardless how large the value of x is. // This works regardless how large the value of x is.
// //
// The output will be resized to the size of m and overwritten. // The output will be resized to the size of m and overwritten.
//
//go:norace
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
return out.modNat(x, m.nat) out.resetFor(m)
}
// Mod calculates out = x mod m.
//
// This works regardless how large the value of x is.
//
// The output will be resized to the size of m and overwritten.
func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
out.reset(len(m.limbs))
// Working our way from the most significant to the least significant limb, // Working our way from the most significant to the least significant limb,
// we can insert each limb at the least significant position, shifting all // we can insert each limb at the least significant position, shifting all
// previous limbs left by _W. This way each limb will get shifted by the // previous limbs left by _W. This way each limb will get shifted by the
@ -676,7 +653,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
i := len(x.limbs) - 1 i := len(x.limbs) - 1
// For the first N - 1 limbs we can skip the actual shifting and position // For the first N - 1 limbs we can skip the actual shifting and position
// them at the shifted position, which starts at min(N - 2, i). // them at the shifted position, which starts at min(N - 2, i).
start := len(m.limbs) - 2 start := len(m.nat.limbs) - 2
if i < start { if i < start {
start = i start = i
} }
@ -686,7 +663,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
} }
// We shift in the remaining limbs, reducing modulo m each time. // We shift in the remaining limbs, reducing modulo m each time.
for i >= 0 { for i >= 0 {
out.shiftInNat(x.limbs[i], m) out.shiftIn(x.limbs[i], m)
i-- i--
} }
return out return out
@ -715,8 +692,10 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m. // overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
// //
// x and m operands must have the same announced length. // x and m operands must have the same announced length.
//
//go:norace
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
t := NewNat().Set(x) t := NewNat().set(x)
underflow := t.sub(m.nat) underflow := t.sub(m.nat)
// We keep the result if x - m didn't underflow (meaning x >= m) // We keep the result if x - m didn't underflow (meaning x >= m)
// or if always was set. // or if always was set.
@ -728,10 +707,12 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
// //
// The length of both operands must be the same as the modulus. Both operands // The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m. // must already be reduced modulo m.
//
//go:norace
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
underflow := x.sub(y) underflow := x.sub(y)
// If the subtraction underflowed, add m. // If the subtraction underflowed, add m.
t := NewNat().Set(x) t := NewNat().set(x)
t.add(m.nat) t.add(m.nat)
x.assign(choice(underflow), t) x.assign(choice(underflow), t)
return x return x
@ -752,6 +733,8 @@ func (x *Nat) SubOne(m *Modulus) *Nat {
// //
// The length of both operands must be the same as the modulus. Both operands // The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m. // must already be reduced modulo m.
//
//go:norace
func (x *Nat) Add(y *Nat, m *Modulus) *Nat { func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
overflow := x.add(y) overflow := x.add(y)
x.maybeSubtractModulus(choice(overflow), m) x.maybeSubtractModulus(choice(overflow), m)
@ -789,6 +772,8 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
// //
// All inputs should be the same length and already reduced modulo m. // All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten. // x will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
n := len(m.nat.limbs) n := len(m.nat.limbs)
mLimbs := m.nat.limbs[:n] mLimbs := m.nat.limbs[:n]
@ -946,11 +931,13 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
// //
// The length of both operands must be the same as the modulus. Both operands // The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m. // must already be reduced modulo m.
//
//go:norace
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
if m.odd { if m.odd {
// A Montgomery multiplication by a value out of the Montgomery domain // A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation. // takes the result out of Montgomery representation.
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
} }
n := len(m.nat.limbs) n := len(m.nat.limbs)
@ -1009,6 +996,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
// to the size of m and overwritten. x must already be reduced modulo m. // to the size of m and overwritten. x must already be reduced modulo m.
// //
// m must be odd, or Exp will panic. // m must be odd, or Exp will panic.
//
//go:norace
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
if !m.odd { if !m.odd {
panic("bigmod: modulus for Exp must be odd") panic("bigmod: modulus for Exp must be odd")
@ -1025,7 +1014,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
} }
table[0].Set(x).montgomeryRepresentation(m) table[0].set(x).montgomeryRepresentation(m)
for i := 1; i < len(table); i++ { for i := 1; i < len(table); i++ {
table[i].montgomeryMul(table[i-1], table[0], m) table[i].montgomeryMul(table[i-1], table[0], m)
} }
@ -1071,8 +1060,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
// For short exponents, precomputing a table and using a window like in Exp // For short exponents, precomputing a table and using a window like in Exp
// doesn't pay off. Instead, we do a simple conditional square-and-multiply // doesn't pay off. Instead, we do a simple conditional square-and-multiply
// chain, skipping the initial run of zeroes. // chain, skipping the initial run of zeroes.
xR := NewNat().Set(x).montgomeryRepresentation(m) xR := NewNat().set(x).montgomeryRepresentation(m)
out.Set(xR) out.set(xR)
for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ { for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ {
out.montgomeryMul(out, out, m) out.montgomeryMul(out, out, m)
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
@ -1088,6 +1077,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
// //
// a must be reduced modulo m, but doesn't need to have the same size. The // a must be reduced modulo m, but doesn't need to have the same size. The
// output will be resized to the size of m and overwritten. // output will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) { func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
// This is the extended binary GCD algorithm described in the Handbook of // This is the extended binary GCD algorithm described in the Handbook of
// Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound // Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound
@ -1121,7 +1112,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
return x, false return x, false
} }
u := NewNat().Set(a).ExpandFor(m) u := NewNat().set(a).ExpandFor(m)
v := m.Nat() v := m.Nat()
A := NewNat().reset(len(m.nat.limbs)) A := NewNat().reset(len(m.nat.limbs))
@ -1148,7 +1139,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
// If both u and v are odd, subtract the smaller from the larger. // If both u and v are odd, subtract the smaller from the larger.
// If u = v, we need to subtract from v to hit the modified exit condition. // If u = v, we need to subtract from v to hit the modified exit condition.
if u.IsOdd() == yes && v.IsOdd() == yes { if u.IsOdd() == yes && v.IsOdd() == yes {
if v.CmpGeq(u) == no { if v.cmpGeq(u) == no {
u.sub(v) u.sub(v)
A.Add(C, m) A.Add(C, m)
B.Add(D, &Modulus{nat: a}) B.Add(D, &Modulus{nat: a})
@ -1189,7 +1180,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
if u.IsOne() == no { if u.IsOne() == no {
return x, false return x, false
} }
return x.Set(A), true return x.set(A), true
} }
} }
} }

View File

@ -0,0 +1,31 @@
package bigmod
// Set assigns x = y, optionally resizing x to the appropriate size, and
// returns x.
//
// It is the exported entry point for the unexported set method and simply
// forwards to it.
func (x *Nat) Set(y *Nat) *Nat {
return x.set(y)
}
// SetOverflowedBytes assigns x = (b mod (m-1)) + 1, where b is a slice of
// big-endian bytes, and returns x.
//
// b may encode a value of any size: it is first reduced modulo m-1, so the
// result always lies in the range [1, m-1]. The output will be resized to the
// size of m and overwritten.
//
// m must be odd, because m-1 is formed by decrementing the lowest limb of m
// without propagating a borrow. SetOverflowedBytes panics if NewModulus
// rejects m-1.
//
//go:norace
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
	// m is odd, so its lowest limb is non-zero and the decrement cannot borrow.
	mMinusOne := NewNat().set(m.nat)
	mMinusOne.limbs[0]--
	mMinusOneM, err := NewModulus(mMinusOne.Bytes(m))
	if err != nil {
		// Fail with an explicit message rather than dereferencing a nil
		// modulus inside Mod.
		panic("bigmod: invalid modulus for SetOverflowedBytes: " + err.Error())
	}

	one := NewNat().resetFor(m)
	one.limbs[0] = 1

	// Parse b into a temporary rather than into x: Mod resets its output
	// before reading its input, so the two must not alias. Reducing into x
	// itself makes the receiver actually hold the result.
	t := NewNat().resetToBytes(b)
	x.Mod(t, mMinusOneM) // x = b mod (m-1)
	x.add(one)           // x <= m-2 here, so adding 1 cannot overflow
	return x
}
// CmpGeq returns 1 if x >= y, and 0 otherwise.
//
// Both operands must have the same announced length.
//
// It is the exported entry point for the unexported cmpGeq method and simply
// forwards to it. It carries its own //go:norace annotation because the
// directive does not propagate across inlining, so a function that may inline
// cmpGeq has to be marked as well.
//
//go:norace
func (x *Nat) CmpGeq(y *Nat) choice {
return x.cmpGeq(y)
}

View File

@ -61,9 +61,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
func testModAddCommutative(a *Nat, b *Nat) bool { func testModAddCommutative(a *Nat, b *Nat) bool {
m := maxModulus(uint(len(a.limbs))) m := maxModulus(uint(len(a.limbs)))
aPlusB := new(Nat).Set(a) aPlusB := new(Nat).set(a)
aPlusB.Add(b, m) aPlusB.Add(b, m)
bPlusA := new(Nat).Set(b) bPlusA := new(Nat).set(b)
bPlusA.Add(a, m) bPlusA.Add(a, m)
return aPlusB.Equal(bPlusA) == 1 return aPlusB.Equal(bPlusA) == 1
} }
@ -77,7 +77,7 @@ func TestModAddCommutative(t *testing.T) {
func testModSubThenAddIdentity(a *Nat, b *Nat) bool { func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
m := maxModulus(uint(len(a.limbs))) m := maxModulus(uint(len(a.limbs)))
original := new(Nat).Set(a) original := new(Nat).set(a)
a.Sub(b, m) a.Sub(b, m)
a.Add(b, m) a.Add(b, m)
return a.Equal(original) == 1 return a.Equal(original) == 1
@ -97,9 +97,9 @@ func TestMontgomeryRoundtrip(t *testing.T) {
aPlusOne := new(big.Int).SetBytes(natBytes(a)) aPlusOne := new(big.Int).SetBytes(natBytes(a))
aPlusOne.Add(aPlusOne, big.NewInt(1)) aPlusOne.Add(aPlusOne, big.NewInt(1))
m, _ := NewModulus(aPlusOne.Bytes()) m, _ := NewModulus(aPlusOne.Bytes())
monty := new(Nat).Set(a) monty := new(Nat).set(a)
monty.montgomeryRepresentation(m) monty.montgomeryRepresentation(m)
aAgain := new(Nat).Set(monty) aAgain := new(Nat).set(monty)
aAgain.montgomeryMul(monty, one, m) aAgain.montgomeryMul(monty, one, m)
if a.Equal(aAgain) != 1 { if a.Equal(aAgain) != 1 {
t.Errorf("%v != %v", a, aAgain) t.Errorf("%v != %v", a, aAgain)