mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-27 04:36:19 +08:00
internal/bigmod: add more //go:norace annotations and refactoring
This commit is contained in:
parent
0d56114869
commit
865159d86a
@ -18,6 +18,15 @@ const (
|
|||||||
_S = _W / 8
|
_S = _W / 8
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Note: These functions make many loops over all the words in a Nat.
|
||||||
|
// These loops used to be in assembly, invisible to -race, -asan, and -msan,
|
||||||
|
// but now they are in Go and incur significant overhead in those modes.
|
||||||
|
// To bring the old performance back, we mark all functions that loop
|
||||||
|
// over Nat words with //go:norace. Because //go:norace does not
|
||||||
|
// propagate across inlining, we must also mark functions that inline
|
||||||
|
// //go:norace functions - specifically, those that inline add, addMulVVW,
|
||||||
|
// assign, cmpGeq, rshift1, and sub.
|
||||||
|
|
||||||
// choice represents a constant-time boolean. The value of choice is always
|
// choice represents a constant-time boolean. The value of choice is always
|
||||||
// either 1 or 0. We use an int instead of bool in order to make decisions in
|
// either 1 or 0. We use an int instead of bool in order to make decisions in
|
||||||
// constant time by turning it into a mask.
|
// constant time by turning it into a mask.
|
||||||
@ -40,14 +49,6 @@ func ctEq(x, y uint) choice {
|
|||||||
return not(choice(c1 | c2))
|
return not(choice(c1 | c2))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
|
|
||||||
// function does not depend on its inputs.
|
|
||||||
func ctGeq(x, y uint) choice {
|
|
||||||
// If x < y, then x - y generates a carry.
|
|
||||||
_, carry := bits.Sub(x, y, 0)
|
|
||||||
return not(choice(carry))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Nat represents an arbitrary natural number
|
// Nat represents an arbitrary natural number
|
||||||
//
|
//
|
||||||
// Each Nat has an announced length, which is the number of limbs it has stored.
|
// Each Nat has an announced length, which is the number of limbs it has stored.
|
||||||
@ -84,6 +85,7 @@ func (x *Nat) expand(n int) *Nat {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
extraLimbs := x.limbs[len(x.limbs):n]
|
extraLimbs := x.limbs[len(x.limbs):n]
|
||||||
|
// clear(extraLimbs)
|
||||||
for i := range extraLimbs {
|
for i := range extraLimbs {
|
||||||
extraLimbs[i] = 0
|
extraLimbs[i] = 0
|
||||||
}
|
}
|
||||||
@ -97,6 +99,7 @@ func (x *Nat) reset(n int) *Nat {
|
|||||||
x.limbs = make([]uint, n)
|
x.limbs = make([]uint, n)
|
||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
// clear(x.limbs)
|
||||||
for i := range x.limbs {
|
for i := range x.limbs {
|
||||||
x.limbs[i] = 0
|
x.limbs[i] = 0
|
||||||
}
|
}
|
||||||
@ -131,7 +134,7 @@ func (x *Nat) trim() *Nat {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// set assigns x = y, optionally resizing x to the appropriate size.
|
// set assigns x = y, optionally resizing x to the appropriate size.
|
||||||
func (x *Nat) Set(y *Nat) *Nat {
|
func (x *Nat) set(y *Nat) *Nat {
|
||||||
x.reset(len(y.limbs))
|
x.reset(len(y.limbs))
|
||||||
copy(x.limbs, y.limbs)
|
copy(x.limbs, y.limbs)
|
||||||
return x
|
return x
|
||||||
@ -164,12 +167,14 @@ func (x *Nat) Bytes(m *Modulus) []byte {
|
|||||||
// SetBytes returns an error if b >= m.
|
// SetBytes returns an error if b >= m.
|
||||||
//
|
//
|
||||||
// The output will be resized to the size of m and overwritten.
|
// The output will be resized to the size of m and overwritten.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
|
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
|
||||||
x.resetFor(m)
|
x.resetFor(m)
|
||||||
if err := x.setBytes(b); err != nil {
|
if err := x.setBytes(b); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if x.CmpGeq(m.nat) == yes {
|
if x.cmpGeq(m.nat) == yes {
|
||||||
return nil, errors.New("input overflows the modulus")
|
return nil, errors.New("input overflows the modulus")
|
||||||
}
|
}
|
||||||
return x, nil
|
return x, nil
|
||||||
@ -195,20 +200,6 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
|
|||||||
return x, nil
|
return x, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
|
|
||||||
//
|
|
||||||
// The output will be resized to the size of m and overwritten.
|
|
||||||
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
|
|
||||||
mMinusOne := NewNat().Set(m.nat)
|
|
||||||
mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1
|
|
||||||
one := NewNat().resetFor(m)
|
|
||||||
one.limbs[0] = 1
|
|
||||||
x.resetToBytes(b)
|
|
||||||
x = NewNat().modNat(x, mMinusOne) // x = x mod (m-1)
|
|
||||||
x.add(one) // we can safely add 1, no need to check overflow
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// bigEndianUint returns the contents of buf interpreted as a
|
// bigEndianUint returns the contents of buf interpreted as a
|
||||||
// big-endian encoded uint value.
|
// big-endian encoded uint value.
|
||||||
func bigEndianUint(buf []byte) uint {
|
func bigEndianUint(buf []byte) uint {
|
||||||
@ -309,8 +300,6 @@ func (x *Nat) IsMinusOne(m *Modulus) choice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsOdd returns 1 if x is odd, and 0 otherwise.
|
// IsOdd returns 1 if x is odd, and 0 otherwise.
|
||||||
//
|
|
||||||
//go:norace
|
|
||||||
func (x *Nat) IsOdd() choice {
|
func (x *Nat) IsOdd() choice {
|
||||||
if len(x.limbs) == 0 {
|
if len(x.limbs) == 0 {
|
||||||
return no
|
return no
|
||||||
@ -333,12 +322,12 @@ func (x *Nat) TrailingZeroBitsVarTime() uint {
|
|||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
// CmpGeq returns 1 if x >= y, and 0 otherwise.
|
// cmpGeq returns 1 if x >= y, and 0 otherwise.
|
||||||
//
|
//
|
||||||
// Both operands must have the same announced length.
|
// Both operands must have the same announced length.
|
||||||
//
|
//
|
||||||
//go:norace
|
//go:norace
|
||||||
func (x *Nat) CmpGeq(y *Nat) choice {
|
func (x *Nat) cmpGeq(y *Nat) choice {
|
||||||
// Eliminate bounds checks in the loop.
|
// Eliminate bounds checks in the loop.
|
||||||
size := len(x.limbs)
|
size := len(x.limbs)
|
||||||
xLimbs := x.limbs[:size]
|
xLimbs := x.limbs[:size]
|
||||||
@ -564,6 +553,8 @@ func NewModulus(b []byte) (*Modulus, error) {
|
|||||||
|
|
||||||
// NewModulusProduct creates a new Modulus from the product of two numbers
|
// NewModulusProduct creates a new Modulus from the product of two numbers
|
||||||
// represented as big-endian byte slices. The result must be greater than one.
|
// represented as big-endian byte slices. The result must be greater than one.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func NewModulusProduct(a, b []byte) (*Modulus, error) {
|
func NewModulusProduct(a, b []byte) (*Modulus, error) {
|
||||||
x := NewNat().resetToBytes(a)
|
x := NewNat().resetToBytes(a)
|
||||||
y := NewNat().resetToBytes(b)
|
y := NewNat().resetToBytes(b)
|
||||||
@ -602,30 +593,23 @@ func (m *Modulus) Nat() *Nat {
|
|||||||
// Make a copy so that the caller can't modify m.nat or alias it with
|
// Make a copy so that the caller can't modify m.nat or alias it with
|
||||||
// another Nat in a modulus operation.
|
// another Nat in a modulus operation.
|
||||||
n := NewNat()
|
n := NewNat()
|
||||||
n.Set(m.nat)
|
n.set(m.nat)
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
// shiftIn calculates x = x << _W + y mod m.
|
|
||||||
//
|
|
||||||
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
|
||||||
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
|
||||||
return x.shiftInNat(y, m.nat)
|
|
||||||
}
|
|
||||||
|
|
||||||
// shiftIn calculates x = x << _W + y mod m.
|
// shiftIn calculates x = x << _W + y mod m.
|
||||||
//
|
//
|
||||||
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
||||||
//
|
//
|
||||||
//go:norace
|
//go:norace
|
||||||
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
|
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
||||||
d := NewNat().reset(len(m.limbs))
|
d := NewNat().resetFor(m)
|
||||||
|
|
||||||
// Eliminate bounds checks in the loop.
|
// Eliminate bounds checks in the loop.
|
||||||
size := len(m.limbs)
|
size := len(m.nat.limbs)
|
||||||
xLimbs := x.limbs[:size]
|
xLimbs := x.limbs[:size]
|
||||||
dLimbs := d.limbs[:size]
|
dLimbs := d.limbs[:size]
|
||||||
mLimbs := m.limbs[:size]
|
mLimbs := m.nat.limbs[:size]
|
||||||
|
|
||||||
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
|
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
|
||||||
// from y. Effectively, it left-shifts x and adds y one bit at a time,
|
// from y. Effectively, it left-shifts x and adds y one bit at a time,
|
||||||
@ -657,17 +641,10 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
|
|||||||
// This works regardless how large the value of x is.
|
// This works regardless how large the value of x is.
|
||||||
//
|
//
|
||||||
// The output will be resized to the size of m and overwritten.
|
// The output will be resized to the size of m and overwritten.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
||||||
return out.modNat(x, m.nat)
|
out.resetFor(m)
|
||||||
}
|
|
||||||
|
|
||||||
// Mod calculates out = x mod m.
|
|
||||||
//
|
|
||||||
// This works regardless how large the value of x is.
|
|
||||||
//
|
|
||||||
// The output will be resized to the size of m and overwritten.
|
|
||||||
func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
|
|
||||||
out.reset(len(m.limbs))
|
|
||||||
// Working our way from the most significant to the least significant limb,
|
// Working our way from the most significant to the least significant limb,
|
||||||
// we can insert each limb at the least significant position, shifting all
|
// we can insert each limb at the least significant position, shifting all
|
||||||
// previous limbs left by _W. This way each limb will get shifted by the
|
// previous limbs left by _W. This way each limb will get shifted by the
|
||||||
@ -676,7 +653,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
|
|||||||
i := len(x.limbs) - 1
|
i := len(x.limbs) - 1
|
||||||
// For the first N - 1 limbs we can skip the actual shifting and position
|
// For the first N - 1 limbs we can skip the actual shifting and position
|
||||||
// them at the shifted position, which starts at min(N - 2, i).
|
// them at the shifted position, which starts at min(N - 2, i).
|
||||||
start := len(m.limbs) - 2
|
start := len(m.nat.limbs) - 2
|
||||||
if i < start {
|
if i < start {
|
||||||
start = i
|
start = i
|
||||||
}
|
}
|
||||||
@ -686,7 +663,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
|
|||||||
}
|
}
|
||||||
// We shift in the remaining limbs, reducing modulo m each time.
|
// We shift in the remaining limbs, reducing modulo m each time.
|
||||||
for i >= 0 {
|
for i >= 0 {
|
||||||
out.shiftInNat(x.limbs[i], m)
|
out.shiftIn(x.limbs[i], m)
|
||||||
i--
|
i--
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
@ -715,8 +692,10 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
|
|||||||
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
|
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
|
||||||
//
|
//
|
||||||
// x and m operands must have the same announced length.
|
// x and m operands must have the same announced length.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
|
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
|
||||||
t := NewNat().Set(x)
|
t := NewNat().set(x)
|
||||||
underflow := t.sub(m.nat)
|
underflow := t.sub(m.nat)
|
||||||
// We keep the result if x - m didn't underflow (meaning x >= m)
|
// We keep the result if x - m didn't underflow (meaning x >= m)
|
||||||
// or if always was set.
|
// or if always was set.
|
||||||
@ -728,10 +707,12 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
|
|||||||
//
|
//
|
||||||
// The length of both operands must be the same as the modulus. Both operands
|
// The length of both operands must be the same as the modulus. Both operands
|
||||||
// must already be reduced modulo m.
|
// must already be reduced modulo m.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
|
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
|
||||||
underflow := x.sub(y)
|
underflow := x.sub(y)
|
||||||
// If the subtraction underflowed, add m.
|
// If the subtraction underflowed, add m.
|
||||||
t := NewNat().Set(x)
|
t := NewNat().set(x)
|
||||||
t.add(m.nat)
|
t.add(m.nat)
|
||||||
x.assign(choice(underflow), t)
|
x.assign(choice(underflow), t)
|
||||||
return x
|
return x
|
||||||
@ -752,6 +733,8 @@ func (x *Nat) SubOne(m *Modulus) *Nat {
|
|||||||
//
|
//
|
||||||
// The length of both operands must be the same as the modulus. Both operands
|
// The length of both operands must be the same as the modulus. Both operands
|
||||||
// must already be reduced modulo m.
|
// must already be reduced modulo m.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
|
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
|
||||||
overflow := x.add(y)
|
overflow := x.add(y)
|
||||||
x.maybeSubtractModulus(choice(overflow), m)
|
x.maybeSubtractModulus(choice(overflow), m)
|
||||||
@ -789,6 +772,8 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
|
|||||||
//
|
//
|
||||||
// All inputs should be the same length and already reduced modulo m.
|
// All inputs should be the same length and already reduced modulo m.
|
||||||
// x will be resized to the size of m and overwritten.
|
// x will be resized to the size of m and overwritten.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
|
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
|
||||||
n := len(m.nat.limbs)
|
n := len(m.nat.limbs)
|
||||||
mLimbs := m.nat.limbs[:n]
|
mLimbs := m.nat.limbs[:n]
|
||||||
@ -946,11 +931,13 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
|
|||||||
//
|
//
|
||||||
// The length of both operands must be the same as the modulus. Both operands
|
// The length of both operands must be the same as the modulus. Both operands
|
||||||
// must already be reduced modulo m.
|
// must already be reduced modulo m.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
||||||
if m.odd {
|
if m.odd {
|
||||||
// A Montgomery multiplication by a value out of the Montgomery domain
|
// A Montgomery multiplication by a value out of the Montgomery domain
|
||||||
// takes the result out of Montgomery representation.
|
// takes the result out of Montgomery representation.
|
||||||
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
||||||
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
||||||
}
|
}
|
||||||
n := len(m.nat.limbs)
|
n := len(m.nat.limbs)
|
||||||
@ -1009,6 +996,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
|||||||
// to the size of m and overwritten. x must already be reduced modulo m.
|
// to the size of m and overwritten. x must already be reduced modulo m.
|
||||||
//
|
//
|
||||||
// m must be odd, or Exp will panic.
|
// m must be odd, or Exp will panic.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
||||||
if !m.odd {
|
if !m.odd {
|
||||||
panic("bigmod: modulus for Exp must be odd")
|
panic("bigmod: modulus for Exp must be odd")
|
||||||
@ -1025,7 +1014,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
|||||||
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
||||||
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
||||||
}
|
}
|
||||||
table[0].Set(x).montgomeryRepresentation(m)
|
table[0].set(x).montgomeryRepresentation(m)
|
||||||
for i := 1; i < len(table); i++ {
|
for i := 1; i < len(table); i++ {
|
||||||
table[i].montgomeryMul(table[i-1], table[0], m)
|
table[i].montgomeryMul(table[i-1], table[0], m)
|
||||||
}
|
}
|
||||||
@ -1071,8 +1060,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
|
|||||||
// For short exponents, precomputing a table and using a window like in Exp
|
// For short exponents, precomputing a table and using a window like in Exp
|
||||||
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
|
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
|
||||||
// chain, skipping the initial run of zeroes.
|
// chain, skipping the initial run of zeroes.
|
||||||
xR := NewNat().Set(x).montgomeryRepresentation(m)
|
xR := NewNat().set(x).montgomeryRepresentation(m)
|
||||||
out.Set(xR)
|
out.set(xR)
|
||||||
for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ {
|
for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ {
|
||||||
out.montgomeryMul(out, out, m)
|
out.montgomeryMul(out, out, m)
|
||||||
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
|
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
|
||||||
@ -1088,6 +1077,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
|
|||||||
//
|
//
|
||||||
// a must be reduced modulo m, but doesn't need to have the same size. The
|
// a must be reduced modulo m, but doesn't need to have the same size. The
|
||||||
// output will be resized to the size of m and overwritten.
|
// output will be resized to the size of m and overwritten.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
||||||
// This is the extended binary GCD algorithm described in the Handbook of
|
// This is the extended binary GCD algorithm described in the Handbook of
|
||||||
// Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound
|
// Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound
|
||||||
@ -1121,7 +1112,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
|||||||
return x, false
|
return x, false
|
||||||
}
|
}
|
||||||
|
|
||||||
u := NewNat().Set(a).ExpandFor(m)
|
u := NewNat().set(a).ExpandFor(m)
|
||||||
v := m.Nat()
|
v := m.Nat()
|
||||||
|
|
||||||
A := NewNat().reset(len(m.nat.limbs))
|
A := NewNat().reset(len(m.nat.limbs))
|
||||||
@ -1148,7 +1139,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
|||||||
// If both u and v are odd, subtract the smaller from the larger.
|
// If both u and v are odd, subtract the smaller from the larger.
|
||||||
// If u = v, we need to subtract from v to hit the modified exit condition.
|
// If u = v, we need to subtract from v to hit the modified exit condition.
|
||||||
if u.IsOdd() == yes && v.IsOdd() == yes {
|
if u.IsOdd() == yes && v.IsOdd() == yes {
|
||||||
if v.CmpGeq(u) == no {
|
if v.cmpGeq(u) == no {
|
||||||
u.sub(v)
|
u.sub(v)
|
||||||
A.Add(C, m)
|
A.Add(C, m)
|
||||||
B.Add(D, &Modulus{nat: a})
|
B.Add(D, &Modulus{nat: a})
|
||||||
@ -1189,7 +1180,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
|
|||||||
if u.IsOne() == no {
|
if u.IsOne() == no {
|
||||||
return x, false
|
return x, false
|
||||||
}
|
}
|
||||||
return x.Set(A), true
|
return x.set(A), true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
31
internal/bigmod/nat_extension.go
Normal file
31
internal/bigmod/nat_extension.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package bigmod
|
||||||
|
|
||||||
|
func (x *Nat) Set(y *Nat) *Nat {
|
||||||
|
return x.set(y)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
|
||||||
|
//
|
||||||
|
// The output will be resized to the size of m and overwritten.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
|
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
|
||||||
|
mMinusOne := NewNat().set(m.nat)
|
||||||
|
mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1
|
||||||
|
mMinusOneM, _ := NewModulus(mMinusOne.Bytes(m))
|
||||||
|
one := NewNat().resetFor(m)
|
||||||
|
one.limbs[0] = 1
|
||||||
|
x.resetToBytes(b)
|
||||||
|
x = NewNat().Mod(x, mMinusOneM) // x = x mod (m-1)
|
||||||
|
x.add(one) // we can safely add 1, no need to check overflow
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
// CmpGeq returns 1 if x >= y, and 0 otherwise.
|
||||||
|
//
|
||||||
|
// Both operands must have the same announced length.
|
||||||
|
//
|
||||||
|
//go:norace
|
||||||
|
func (x *Nat) CmpGeq(y *Nat) choice {
|
||||||
|
return x.cmpGeq(y)
|
||||||
|
}
|
@ -61,9 +61,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
|
|||||||
|
|
||||||
func testModAddCommutative(a *Nat, b *Nat) bool {
|
func testModAddCommutative(a *Nat, b *Nat) bool {
|
||||||
m := maxModulus(uint(len(a.limbs)))
|
m := maxModulus(uint(len(a.limbs)))
|
||||||
aPlusB := new(Nat).Set(a)
|
aPlusB := new(Nat).set(a)
|
||||||
aPlusB.Add(b, m)
|
aPlusB.Add(b, m)
|
||||||
bPlusA := new(Nat).Set(b)
|
bPlusA := new(Nat).set(b)
|
||||||
bPlusA.Add(a, m)
|
bPlusA.Add(a, m)
|
||||||
return aPlusB.Equal(bPlusA) == 1
|
return aPlusB.Equal(bPlusA) == 1
|
||||||
}
|
}
|
||||||
@ -77,7 +77,7 @@ func TestModAddCommutative(t *testing.T) {
|
|||||||
|
|
||||||
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
|
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
|
||||||
m := maxModulus(uint(len(a.limbs)))
|
m := maxModulus(uint(len(a.limbs)))
|
||||||
original := new(Nat).Set(a)
|
original := new(Nat).set(a)
|
||||||
a.Sub(b, m)
|
a.Sub(b, m)
|
||||||
a.Add(b, m)
|
a.Add(b, m)
|
||||||
return a.Equal(original) == 1
|
return a.Equal(original) == 1
|
||||||
@ -97,9 +97,9 @@ func TestMontgomeryRoundtrip(t *testing.T) {
|
|||||||
aPlusOne := new(big.Int).SetBytes(natBytes(a))
|
aPlusOne := new(big.Int).SetBytes(natBytes(a))
|
||||||
aPlusOne.Add(aPlusOne, big.NewInt(1))
|
aPlusOne.Add(aPlusOne, big.NewInt(1))
|
||||||
m, _ := NewModulus(aPlusOne.Bytes())
|
m, _ := NewModulus(aPlusOne.Bytes())
|
||||||
monty := new(Nat).Set(a)
|
monty := new(Nat).set(a)
|
||||||
monty.montgomeryRepresentation(m)
|
monty.montgomeryRepresentation(m)
|
||||||
aAgain := new(Nat).Set(monty)
|
aAgain := new(Nat).set(monty)
|
||||||
aAgain.montgomeryMul(monty, one, m)
|
aAgain.montgomeryMul(monty, one, m)
|
||||||
if a.Equal(aAgain) != 1 {
|
if a.Equal(aAgain) != 1 {
|
||||||
t.Errorf("%v != %v", a, aAgain)
|
t.Errorf("%v != %v", a, aAgain)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user