mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-27 04:36:19 +08:00
internal/sm2ec: make SetBytes constant time #309
This commit is contained in:
parent
89962cf1e3
commit
5ade794e6b
@ -134,6 +134,8 @@ package fiat
|
|||||||
import (
|
import (
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
_subtle "github.com/emmansun/gmsm/internal/subtle"
|
||||||
)
|
)
|
||||||
|
|
||||||
// {{ .Element }} is an integer modulo {{ .Prime }}.
|
// {{ .Element }} is an integer modulo {{ .Prime }}.
|
||||||
@ -202,13 +204,8 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) {
|
|||||||
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
|
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
|
||||||
var minusOneEncoding = new({{ .Element }}).Sub(
|
var minusOneEncoding = new({{ .Element }}).Sub(
|
||||||
new({{ .Element }}), new({{ .Element }}).One()).Bytes()
|
new({{ .Element }}), new({{ .Element }}).One()).Bytes()
|
||||||
for i := range v {
|
if _subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
|
||||||
if v[i] < minusOneEncoding[i] {
|
return nil, errors.New("invalid {{ .Element }} encoding")
|
||||||
break
|
|
||||||
}
|
|
||||||
if v[i] > minusOneEncoding[i] {
|
|
||||||
return nil, errors.New("invalid {{ .Element }} encoding")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
var in [{{ .Prefix }}ElementLen]byte
|
var in [{{ .Prefix }}ElementLen]byte
|
||||||
copy(in[:], v)
|
copy(in[:], v)
|
||||||
|
@ -9,6 +9,8 @@ package fiat
|
|||||||
import (
|
import (
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
_subtle "github.com/emmansun/gmsm/internal/subtle"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SM2P256Element is an integer modulo 2^256 - 2^224 - 2^96 + 2^64 - 1.
|
// SM2P256Element is an integer modulo 2^256 - 2^224 - 2^96 + 2^64 - 1.
|
||||||
@ -78,13 +80,8 @@ func (e *SM2P256Element) SetBytes(v []byte) (*SM2P256Element, error) {
|
|||||||
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
|
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
|
||||||
var minusOneEncoding = new(SM2P256Element).Sub(
|
var minusOneEncoding = new(SM2P256Element).Sub(
|
||||||
new(SM2P256Element), new(SM2P256Element).One()).Bytes()
|
new(SM2P256Element), new(SM2P256Element).One()).Bytes()
|
||||||
for i := range v {
|
if _subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
|
||||||
if v[i] < minusOneEncoding[i] {
|
return nil, errors.New("invalid {{ .Element }} encoding")
|
||||||
break
|
|
||||||
}
|
|
||||||
if v[i] > minusOneEncoding[i] {
|
|
||||||
return nil, errors.New("invalid SM2P256Element encoding")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var in [sm2p256ElementLen]byte
|
var in [sm2p256ElementLen]byte
|
||||||
|
@ -1,5 +1,11 @@
|
|||||||
package subtle
|
package subtle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/bits"
|
||||||
|
|
||||||
|
"github.com/emmansun/gmsm/internal/byteorder"
|
||||||
|
)
|
||||||
|
|
||||||
func ConstantTimeAllZero(bytes []byte) int {
|
func ConstantTimeAllZero(bytes []byte) int {
|
||||||
var b uint8
|
var b uint8
|
||||||
for _, v := range bytes {
|
for _, v := range bytes {
|
||||||
@ -7,3 +13,34 @@ func ConstantTimeAllZero(bytes []byte) int {
|
|||||||
}
|
}
|
||||||
return int((uint32(b) - 1) >> 31)
|
return int((uint32(b) - 1) >> 31)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConstantTimeLessOrEqBytes returns 1 if x <= y and 0 otherwise. The comparison
|
||||||
|
// is lexigraphical, or big-endian. The time taken is a function of the length of
|
||||||
|
// the slices and is independent of the contents. If the lengths of x and y do not
|
||||||
|
// match it returns 0 immediately.
|
||||||
|
func ConstantTimeLessOrEqBytes(x, y []byte) int {
|
||||||
|
if len(x) != len(y) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do a constant time subtraction chain y - x.
|
||||||
|
// If there is no borrow at the end, then x <= y.
|
||||||
|
var b uint64
|
||||||
|
for len(x) > 8 {
|
||||||
|
x0 := byteorder.BEUint64(x[len(x)-8:])
|
||||||
|
y0 := byteorder.BEUint64(y[len(y)-8:])
|
||||||
|
_, b = bits.Sub64(y0, x0, b)
|
||||||
|
x = x[:len(x)-8]
|
||||||
|
y = y[:len(y)-8]
|
||||||
|
}
|
||||||
|
if len(x) > 0 {
|
||||||
|
xb := make([]byte, 8)
|
||||||
|
yb := make([]byte, 8)
|
||||||
|
copy(xb[8-len(x):], x)
|
||||||
|
copy(yb[8-len(y):], y)
|
||||||
|
x0 := byteorder.BEUint64(xb)
|
||||||
|
y0 := byteorder.BEUint64(yb)
|
||||||
|
_, b = bits.Sub64(y0, x0, b)
|
||||||
|
}
|
||||||
|
return int(b ^ 1)
|
||||||
|
}
|
||||||
|
@ -1,10 +1,101 @@
|
|||||||
package subtle
|
package subtle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestConstantTimeLessOrEqBytes(t *testing.T) {
|
||||||
|
r := rand.Reader
|
||||||
|
for l := 0; l < 20; l++ {
|
||||||
|
a := make([]byte, l)
|
||||||
|
b := make([]byte, l)
|
||||||
|
empty := make([]byte, l)
|
||||||
|
r.Read(a)
|
||||||
|
r.Read(b)
|
||||||
|
exp := 0
|
||||||
|
if bytes.Compare(a, b) <= 0 {
|
||||||
|
exp = 1
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(a, b); got != exp {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", a, b, got, exp)
|
||||||
|
}
|
||||||
|
exp = 0
|
||||||
|
if bytes.Compare(b, a) <= 0 {
|
||||||
|
exp = 1
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(b, a); got != exp {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", b, a, got, exp)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(empty, a); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, a, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(empty, b); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, b, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(a, a); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, a, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(b, b); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, b, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(empty, empty); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, empty, got)
|
||||||
|
}
|
||||||
|
if l == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
max := make([]byte, l)
|
||||||
|
for i := range max {
|
||||||
|
max[i] = 0xff
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(a, max); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, max, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(b, max); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, max, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(empty, max); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, max, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(max, max); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", max, max, got)
|
||||||
|
}
|
||||||
|
aPlusOne := make([]byte, l)
|
||||||
|
copy(aPlusOne, a)
|
||||||
|
for i := l - 1; i >= 0; i-- {
|
||||||
|
if aPlusOne[i] == 0xff {
|
||||||
|
aPlusOne[i] = 0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
aPlusOne[i]++
|
||||||
|
if got := ConstantTimeLessOrEqBytes(a, aPlusOne); got != 1 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, aPlusOne, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(aPlusOne, a); got != 0 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", aPlusOne, a, got)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
shorter := make([]byte, l-1)
|
||||||
|
copy(shorter, a)
|
||||||
|
if got := ConstantTimeLessOrEqBytes(a, shorter); got != 0 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", a, shorter, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(shorter, a); got != 0 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, a, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(b, shorter); got != 0 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", b, shorter, got)
|
||||||
|
}
|
||||||
|
if got := ConstantTimeLessOrEqBytes(shorter, b); got != 0 {
|
||||||
|
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, b, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConstantTimeAllZero(t *testing.T) {
|
func TestConstantTimeAllZero(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
bytes []byte
|
bytes []byte
|
||||||
|
Loading…
x
Reference in New Issue
Block a user