mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
internal/sm2ec: make SetBytes constant time #309
This commit is contained in:
parent
89962cf1e3
commit
5ade794e6b
@ -134,6 +134,8 @@ package fiat
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
|
||||
_subtle "github.com/emmansun/gmsm/internal/subtle"
|
||||
)
|
||||
|
||||
// {{ .Element }} is an integer modulo {{ .Prime }}.
|
||||
@ -202,13 +204,8 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) {
|
||||
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
|
||||
var minusOneEncoding = new({{ .Element }}).Sub(
|
||||
new({{ .Element }}), new({{ .Element }}).One()).Bytes()
|
||||
for i := range v {
|
||||
if v[i] < minusOneEncoding[i] {
|
||||
break
|
||||
}
|
||||
if v[i] > minusOneEncoding[i] {
|
||||
return nil, errors.New("invalid {{ .Element }} encoding")
|
||||
}
|
||||
if _subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
|
||||
return nil, errors.New("invalid {{ .Element }} encoding")
|
||||
}
|
||||
var in [{{ .Prefix }}ElementLen]byte
|
||||
copy(in[:], v)
|
||||
|
@ -9,6 +9,8 @@ package fiat
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
|
||||
_subtle "github.com/emmansun/gmsm/internal/subtle"
|
||||
)
|
||||
|
||||
// SM2P256Element is an integer modulo 2^256 - 2^224 - 2^96 + 2^64 - 1.
|
||||
@ -78,13 +80,8 @@ func (e *SM2P256Element) SetBytes(v []byte) (*SM2P256Element, error) {
|
||||
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
|
||||
var minusOneEncoding = new(SM2P256Element).Sub(
|
||||
new(SM2P256Element), new(SM2P256Element).One()).Bytes()
|
||||
for i := range v {
|
||||
if v[i] < minusOneEncoding[i] {
|
||||
break
|
||||
}
|
||||
if v[i] > minusOneEncoding[i] {
|
||||
return nil, errors.New("invalid SM2P256Element encoding")
|
||||
}
|
||||
if _subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
|
||||
return nil, errors.New("invalid {{ .Element }} encoding")
|
||||
}
|
||||
|
||||
var in [sm2p256ElementLen]byte
|
||||
|
@ -1,5 +1,11 @@
|
||||
package subtle
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
|
||||
"github.com/emmansun/gmsm/internal/byteorder"
|
||||
)
|
||||
|
||||
func ConstantTimeAllZero(bytes []byte) int {
|
||||
var b uint8
|
||||
for _, v := range bytes {
|
||||
@ -7,3 +13,34 @@ func ConstantTimeAllZero(bytes []byte) int {
|
||||
}
|
||||
return int((uint32(b) - 1) >> 31)
|
||||
}
|
||||
|
||||
// ConstantTimeLessOrEqBytes returns 1 if x <= y and 0 otherwise. The comparison
|
||||
// is lexigraphical, or big-endian. The time taken is a function of the length of
|
||||
// the slices and is independent of the contents. If the lengths of x and y do not
|
||||
// match it returns 0 immediately.
|
||||
func ConstantTimeLessOrEqBytes(x, y []byte) int {
|
||||
if len(x) != len(y) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Do a constant time subtraction chain y - x.
|
||||
// If there is no borrow at the end, then x <= y.
|
||||
var b uint64
|
||||
for len(x) > 8 {
|
||||
x0 := byteorder.BEUint64(x[len(x)-8:])
|
||||
y0 := byteorder.BEUint64(y[len(y)-8:])
|
||||
_, b = bits.Sub64(y0, x0, b)
|
||||
x = x[:len(x)-8]
|
||||
y = y[:len(y)-8]
|
||||
}
|
||||
if len(x) > 0 {
|
||||
xb := make([]byte, 8)
|
||||
yb := make([]byte, 8)
|
||||
copy(xb[8-len(x):], x)
|
||||
copy(yb[8-len(y):], y)
|
||||
x0 := byteorder.BEUint64(xb)
|
||||
y0 := byteorder.BEUint64(yb)
|
||||
_, b = bits.Sub64(y0, x0, b)
|
||||
}
|
||||
return int(b ^ 1)
|
||||
}
|
||||
|
@ -1,10 +1,101 @@
|
||||
package subtle
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConstantTimeLessOrEqBytes(t *testing.T) {
|
||||
r := rand.Reader
|
||||
for l := 0; l < 20; l++ {
|
||||
a := make([]byte, l)
|
||||
b := make([]byte, l)
|
||||
empty := make([]byte, l)
|
||||
r.Read(a)
|
||||
r.Read(b)
|
||||
exp := 0
|
||||
if bytes.Compare(a, b) <= 0 {
|
||||
exp = 1
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(a, b); got != exp {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", a, b, got, exp)
|
||||
}
|
||||
exp = 0
|
||||
if bytes.Compare(b, a) <= 0 {
|
||||
exp = 1
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(b, a); got != exp {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", b, a, got, exp)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(empty, a); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, a, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(empty, b); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, b, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(a, a); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, a, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(b, b); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, b, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(empty, empty); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, empty, got)
|
||||
}
|
||||
if l == 0 {
|
||||
continue
|
||||
}
|
||||
max := make([]byte, l)
|
||||
for i := range max {
|
||||
max[i] = 0xff
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(a, max); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, max, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(b, max); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, max, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(empty, max); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, max, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(max, max); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", max, max, got)
|
||||
}
|
||||
aPlusOne := make([]byte, l)
|
||||
copy(aPlusOne, a)
|
||||
for i := l - 1; i >= 0; i-- {
|
||||
if aPlusOne[i] == 0xff {
|
||||
aPlusOne[i] = 0
|
||||
continue
|
||||
}
|
||||
aPlusOne[i]++
|
||||
if got := ConstantTimeLessOrEqBytes(a, aPlusOne); got != 1 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, aPlusOne, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(aPlusOne, a); got != 0 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", aPlusOne, a, got)
|
||||
}
|
||||
break
|
||||
}
|
||||
shorter := make([]byte, l-1)
|
||||
copy(shorter, a)
|
||||
if got := ConstantTimeLessOrEqBytes(a, shorter); got != 0 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", a, shorter, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(shorter, a); got != 0 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, a, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(b, shorter); got != 0 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", b, shorter, got)
|
||||
}
|
||||
if got := ConstantTimeLessOrEqBytes(shorter, b); got != 0 {
|
||||
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, b, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstantTimeAllZero(t *testing.T) {
|
||||
type args struct {
|
||||
bytes []byte
|
||||
|
Loading…
x
Reference in New Issue
Block a user