mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-27 12:46:18 +08:00
internal/bigmod: add support for even moduli #280
This commit is contained in:
parent
dec688f7cc
commit
4df708a76b
@ -228,6 +228,19 @@ func (x *Nat) setBytes(b []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUint assigns x = y, and returns an error if y >= m.
|
||||||
|
//
|
||||||
|
// The output will be resized to the size of m and overwritten.
|
||||||
|
func (x *Nat) SetUint(y uint, m *Modulus) (*Nat, error) {
|
||||||
|
x.resetFor(m)
|
||||||
|
// Modulus is never zero, so always at least one limb.
|
||||||
|
x.limbs[0] = y
|
||||||
|
if x.CmpGeq(m.nat) == yes {
|
||||||
|
return nil, errors.New("input overflows the modulus")
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Equal returns 1 if x == y, and 0 otherwise.
|
// Equal returns 1 if x == y, and 0 otherwise.
|
||||||
//
|
//
|
||||||
// Both operands must have the same announced length.
|
// Both operands must have the same announced length.
|
||||||
@ -323,10 +336,8 @@ func (x *Nat) sub(y *Nat) (c uint) {
|
|||||||
|
|
||||||
// Modulus is used for modular arithmetic, precomputing relevant constants.
|
// Modulus is used for modular arithmetic, precomputing relevant constants.
|
||||||
//
|
//
|
||||||
// Moduli are assumed to be odd numbers. Moduli can also leak the exact
|
// A Modulus can leak the exact number of bits needed to store its value
|
||||||
// number of bits needed to store their value, and are stored without padding.
|
// and is stored without padding. Its actual value is still kept secret.
|
||||||
//
|
|
||||||
// Their actual value is still kept secret.
|
|
||||||
type Modulus struct {
|
type Modulus struct {
|
||||||
// The underlying natural number for this modulus.
|
// The underlying natural number for this modulus.
|
||||||
//
|
//
|
||||||
@ -334,6 +345,9 @@ type Modulus struct {
|
|||||||
// other natural number being used.
|
// other natural number being used.
|
||||||
nat *Nat
|
nat *Nat
|
||||||
leading int // number of leading zeros in the modulus
|
leading int // number of leading zeros in the modulus
|
||||||
|
|
||||||
|
// If m is even, the following fields are not set.
|
||||||
|
odd bool
|
||||||
m0inv uint // -nat.limbs[0]⁻¹ mod _W
|
m0inv uint // -nat.limbs[0]⁻¹ mod _W
|
||||||
rr *Nat // R*R for montgomeryRepresentation
|
rr *Nat // R*R for montgomeryRepresentation
|
||||||
}
|
}
|
||||||
@ -406,17 +420,20 @@ func minusInverseModW(x uint) uint {
|
|||||||
|
|
||||||
// NewModulus creates a new Modulus from a slice of big-endian bytes.
|
// NewModulus creates a new Modulus from a slice of big-endian bytes.
|
||||||
//
|
//
|
||||||
// The value must be odd. The number of significant bits (and nothing else) is
|
// The number of significant bits and whether the modulus is even is leaked
|
||||||
// leaked through timing side-channels.
|
// through timing side-channels.
|
||||||
func NewModulus(b []byte) (*Modulus, error) {
|
func NewModulus(b []byte) (*Modulus, error) {
|
||||||
if len(b) == 0 || b[len(b)-1]&1 != 1 {
|
|
||||||
return nil, errors.New("modulus must be > 0 and odd")
|
|
||||||
}
|
|
||||||
m := &Modulus{}
|
m := &Modulus{}
|
||||||
m.nat = NewNat().resetToBytes(b)
|
m.nat = NewNat().resetToBytes(b)
|
||||||
|
if len(m.nat.limbs) == 0 {
|
||||||
|
return nil, errors.New("modulus must be > 0")
|
||||||
|
}
|
||||||
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
||||||
|
if m.nat.limbs[0]&1 == 1 {
|
||||||
|
m.odd = true
|
||||||
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
||||||
m.rr = rr(m)
|
m.rr = rr(m)
|
||||||
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -775,17 +792,73 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
|
|||||||
// The length of both operands must be the same as the modulus. Both operands
|
// The length of both operands must be the same as the modulus. Both operands
|
||||||
// must already be reduced modulo m.
|
// must already be reduced modulo m.
|
||||||
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
||||||
|
if m.odd {
|
||||||
// A Montgomery multiplication by a value out of the Montgomery domain
|
// A Montgomery multiplication by a value out of the Montgomery domain
|
||||||
// takes the result out of Montgomery representation.
|
// takes the result out of Montgomery representation.
|
||||||
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
||||||
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
||||||
|
}
|
||||||
|
n := len(m.nat.limbs)
|
||||||
|
xLimbs := x.limbs[:n]
|
||||||
|
yLimbs := y.limbs[:n]
|
||||||
|
switch n {
|
||||||
|
default:
|
||||||
|
// Attempt to use a stack-allocated backing array.
|
||||||
|
T := make([]uint, 0, preallocLimbs*2)
|
||||||
|
if cap(T) < n*2 {
|
||||||
|
T = make([]uint, 0, n*2)
|
||||||
|
}
|
||||||
|
T = T[:n*2]
|
||||||
|
// T = x * y
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i])
|
||||||
|
}
|
||||||
|
// x = T mod m
|
||||||
|
return x.Mod(&Nat{limbs: T}, m)
|
||||||
|
// The following specialized cases follow the exact same algorithm, but
|
||||||
|
// optimized for the sizes most used in RSA. See montgomeryMul for details.
|
||||||
|
case 256 / _W: // optimization for 256 bits nat
|
||||||
|
const n = 256 / _W // compiler hint
|
||||||
|
T := make([]uint, n*2)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
T[n+i] = addMulVVW256(&T[i], &xLimbs[0], yLimbs[i])
|
||||||
|
}
|
||||||
|
return x.Mod(&Nat{limbs: T}, m)
|
||||||
|
case 1024 / _W:
|
||||||
|
const n = 1024 / _W // compiler hint
|
||||||
|
T := make([]uint, n*2)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i])
|
||||||
|
}
|
||||||
|
return x.Mod(&Nat{limbs: T}, m)
|
||||||
|
case 1536 / _W:
|
||||||
|
const n = 1536 / _W // compiler hint
|
||||||
|
T := make([]uint, n*2)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i])
|
||||||
|
}
|
||||||
|
return x.Mod(&Nat{limbs: T}, m)
|
||||||
|
case 2048 / _W:
|
||||||
|
const n = 2048 / _W // compiler hint
|
||||||
|
T := make([]uint, n*2)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i])
|
||||||
|
}
|
||||||
|
return x.Mod(&Nat{limbs: T}, m)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exp calculates out = x^e mod m.
|
// Exp calculates out = x^e mod m.
|
||||||
//
|
//
|
||||||
// The exponent e is represented in big-endian order. The output will be resized
|
// The exponent e is represented in big-endian order. The output will be resized
|
||||||
// to the size of m and overwritten. x must already be reduced modulo m.
|
// to the size of m and overwritten. x must already be reduced modulo m.
|
||||||
|
//
|
||||||
|
// m must be odd, or Exp will panic.
|
||||||
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
||||||
|
if !m.odd {
|
||||||
|
panic("bigmod: modulus for Exp must be odd")
|
||||||
|
}
|
||||||
|
|
||||||
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
|
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
|
||||||
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
|
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
|
||||||
// Using bit sizes that don't divide 8 are more complex to implement, but
|
// Using bit sizes that don't divide 8 are more complex to implement, but
|
||||||
@ -834,7 +907,12 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
|||||||
//
|
//
|
||||||
// The output will be resized to the size of m and overwritten. x must already
|
// The output will be resized to the size of m and overwritten. x must already
|
||||||
// be reduced modulo m. This leaks the exponent through timing side-channels.
|
// be reduced modulo m. This leaks the exponent through timing side-channels.
|
||||||
|
//
|
||||||
|
// m must be odd, or ExpShortVarTime will panic.
|
||||||
func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
|
func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
|
||||||
|
if !m.odd {
|
||||||
|
panic("bigmod: modulus for ExpShortVarTime must be odd")
|
||||||
|
}
|
||||||
// For short exponents, precomputing a table and using a window like in Exp
|
// For short exponents, precomputing a table and using a window like in Exp
|
||||||
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
|
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
|
||||||
// chain, skipping the initial run of zeroes.
|
// chain, skipping the initial run of zeroes.
|
||||||
|
@ -6,6 +6,7 @@ package bigmod
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
cryptorand "crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
@ -17,6 +18,19 @@ import (
|
|||||||
"testing/quick"
|
"testing/quick"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// setBig assigns x = n, optionally resizing n to the appropriate size.
|
||||||
|
//
|
||||||
|
// The announced length of x is set based on the actual bit size of the input,
|
||||||
|
// ignoring leading zeroes.
|
||||||
|
func (x *Nat) setBig(n *big.Int) *Nat {
|
||||||
|
limbs := n.Bits()
|
||||||
|
x.reset(len(limbs))
|
||||||
|
for i := range limbs {
|
||||||
|
x.limbs[i] = uint(limbs[i])
|
||||||
|
}
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
func (n *Nat) String() string {
|
func (n *Nat) String() string {
|
||||||
var limbs []string
|
var limbs []string
|
||||||
for i := range n.limbs {
|
for i := range n.limbs {
|
||||||
@ -312,19 +326,6 @@ func TestExpShort(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// setBig assigns x = n, optionally resizing n to the appropriate size.
|
|
||||||
//
|
|
||||||
// The announced length of x is set based on the actual bit size of the input,
|
|
||||||
// ignoring leading zeroes.
|
|
||||||
func (x *Nat) setBig(n *big.Int) *Nat {
|
|
||||||
limbs := n.Bits()
|
|
||||||
x.reset(len(limbs))
|
|
||||||
for i := range limbs {
|
|
||||||
x.limbs[i] = uint(limbs[i])
|
|
||||||
}
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMulReductions tests that Mul reduces results equal or slightly greater
|
// TestMulReductions tests that Mul reduces results equal or slightly greater
|
||||||
// than the modulus. Some Montgomery algorithms don't and need extra care to
|
// than the modulus. Some Montgomery algorithms don't and need extra care to
|
||||||
// return correct results. See https://go.dev/issue/13907.
|
// return correct results. See https://go.dev/issue/13907.
|
||||||
@ -353,6 +354,52 @@ func TestMulReductions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMul(t *testing.T) {
|
||||||
|
t.Run("760", func(t *testing.T) { testMul(t, 760/8) })
|
||||||
|
t.Run("256", func(t *testing.T) { testMul(t, 256/8) })
|
||||||
|
t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) })
|
||||||
|
t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) })
|
||||||
|
t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) })
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMul(t *testing.T, n int) {
|
||||||
|
a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
|
||||||
|
cryptorand.Read(a)
|
||||||
|
cryptorand.Read(b)
|
||||||
|
cryptorand.Read(m)
|
||||||
|
// Pick the highest as the modulus.
|
||||||
|
if bytes.Compare(a, m) > 0 {
|
||||||
|
a, m = m, a
|
||||||
|
}
|
||||||
|
if bytes.Compare(b, m) > 0 {
|
||||||
|
b, m = m, b
|
||||||
|
}
|
||||||
|
M, err := NewModulus(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
A, err := NewNat().SetBytes(a, M)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
B, err := NewNat().SetBytes(b, M)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
A.Mul(B, M)
|
||||||
|
ABytes := A.Bytes(M)
|
||||||
|
mBig := new(big.Int).SetBytes(m)
|
||||||
|
aBig := new(big.Int).SetBytes(a)
|
||||||
|
bBig := new(big.Int).SetBytes(b)
|
||||||
|
nBig := new(big.Int).Mul(aBig, bBig)
|
||||||
|
nBig.Mod(nBig, mBig)
|
||||||
|
nBigBytes := make([]byte, len(ABytes))
|
||||||
|
nBig.FillBytes(nBigBytes)
|
||||||
|
if !bytes.Equal(ABytes, nBigBytes) {
|
||||||
|
t.Errorf("got %x, want %x", ABytes, nBigBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func natBytes(n *Nat) []byte {
|
func natBytes(n *Nat) []byte {
|
||||||
return n.Bytes(maxModulus(uint(len(n.limbs))))
|
return n.Bytes(maxModulus(uint(len(n.limbs))))
|
||||||
}
|
}
|
||||||
|
@ -309,6 +309,8 @@ func encodingCiphertextASN1(C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, err
|
|||||||
// Most applications should use [crypto/rand.Reader] as rand. Note that the
|
// Most applications should use [crypto/rand.Reader] as rand. Note that the
|
||||||
// returned key does not depend deterministically on the bytes read from rand,
|
// returned key does not depend deterministically on the bytes read from rand,
|
||||||
// and may change between calls and/or between versions.
|
// and may change between calls and/or between versions.
|
||||||
|
//
|
||||||
|
// According GB/T 32918.1-2016, the private key must be in [1, n-2].
|
||||||
func GenerateKey(rand io.Reader) (*PrivateKey, error) {
|
func GenerateKey(rand io.Reader) (*PrivateKey, error) {
|
||||||
randutil.MaybeReadByte(rand)
|
randutil.MaybeReadByte(rand)
|
||||||
|
|
||||||
@ -331,6 +333,8 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) {
|
|||||||
// NewPrivateKey checks that key is valid and returns a SM2 PrivateKey.
|
// NewPrivateKey checks that key is valid and returns a SM2 PrivateKey.
|
||||||
//
|
//
|
||||||
// key - the private key byte slice, the length must be 32 for SM2.
|
// key - the private key byte slice, the length must be 32 for SM2.
|
||||||
|
//
|
||||||
|
// According GB/T 32918.1-2016, the private key must be in [1, n-2].
|
||||||
func NewPrivateKey(key []byte) (*PrivateKey, error) {
|
func NewPrivateKey(key []byte) (*PrivateKey, error) {
|
||||||
c := p256()
|
c := p256()
|
||||||
if len(key) != c.N.Size() {
|
if len(key) != c.N.Size() {
|
||||||
@ -364,6 +368,8 @@ func NewPrivateKeyFromInt(key *big.Int) (*PrivateKey, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPublicKey checks that key is valid and returns a PublicKey.
|
// NewPublicKey checks that key is valid and returns a PublicKey.
|
||||||
|
//
|
||||||
|
// According GB/T 32918.1-2016, the private key must be in [1, n-2].
|
||||||
func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) {
|
func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) {
|
||||||
c := p256()
|
c := p256()
|
||||||
// Reject the point at infinity and compressed encodings.
|
// Reject the point at infinity and compressed encodings.
|
||||||
@ -598,7 +604,7 @@ func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, erro
|
|||||||
dp1Bytes []byte
|
dp1Bytes []byte
|
||||||
)
|
)
|
||||||
priv.inverseOfKeyPlus1Once.Do(func() {
|
priv.inverseOfKeyPlus1Once.Do(func() {
|
||||||
oneNat, _ = bigmod.NewNat().SetBytes(one.Bytes(), c.N)
|
oneNat, _ = bigmod.NewNat().SetUint(1, c.N)
|
||||||
dp1Inv, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
|
dp1Inv, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dp1Inv.Add(oneNat, c.N)
|
dp1Inv.Add(oneNat, c.N)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user