internal/sm2ec: make SetBytes constant time #309

This commit is contained in:
Sun Yimin 2025-03-04 08:59:51 +08:00 committed by GitHub
parent 89962cf1e3
commit 5ade794e6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 136 additions and 14 deletions

View File

@ -134,6 +134,8 @@ package fiat
import (
"crypto/subtle"
"errors"
_subtle "github.com/emmansun/gmsm/internal/subtle"
)
// {{ .Element }} is an integer modulo {{ .Prime }}.
@ -202,13 +204,8 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) {
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new({{ .Element }}).Sub(
new({{ .Element }}), new({{ .Element }}).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid {{ .Element }} encoding")
}
if _subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
return nil, errors.New("invalid {{ .Element }} encoding")
}
var in [{{ .Prefix }}ElementLen]byte
copy(in[:], v)

View File

@ -9,6 +9,8 @@ package fiat
import (
"crypto/subtle"
"errors"
_subtle "github.com/emmansun/gmsm/internal/subtle"
)
// SM2P256Element is an integer modulo 2^256 - 2^224 - 2^96 + 2^64 - 1.
@ -78,13 +80,8 @@ func (e *SM2P256Element) SetBytes(v []byte) (*SM2P256Element, error) {
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(SM2P256Element).Sub(
new(SM2P256Element), new(SM2P256Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid SM2P256Element encoding")
}
if _subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
return nil, errors.New("invalid {{ .Element }} encoding")
}
var in [sm2p256ElementLen]byte

View File

@ -1,5 +1,11 @@
package subtle
import (
"math/bits"
"github.com/emmansun/gmsm/internal/byteorder"
)
func ConstantTimeAllZero(bytes []byte) int {
var b uint8
for _, v := range bytes {
@ -7,3 +13,34 @@ func ConstantTimeAllZero(bytes []byte) int {
}
return int((uint32(b) - 1) >> 31)
}
// ConstantTimeLessOrEqBytes returns 1 if x <= y and 0 otherwise. The comparison
// is lexigraphical, or big-endian. The time taken is a function of the length of
// the slices and is independent of the contents. If the lengths of x and y do not
// match it returns 0 immediately.
func ConstantTimeLessOrEqBytes(x, y []byte) int {
if len(x) != len(y) {
return 0
}
// Do a constant time subtraction chain y - x.
// If there is no borrow at the end, then x <= y.
var b uint64
for len(x) > 8 {
x0 := byteorder.BEUint64(x[len(x)-8:])
y0 := byteorder.BEUint64(y[len(y)-8:])
_, b = bits.Sub64(y0, x0, b)
x = x[:len(x)-8]
y = y[:len(y)-8]
}
if len(x) > 0 {
xb := make([]byte, 8)
yb := make([]byte, 8)
copy(xb[8-len(x):], x)
copy(yb[8-len(y):], y)
x0 := byteorder.BEUint64(xb)
y0 := byteorder.BEUint64(yb)
_, b = bits.Sub64(y0, x0, b)
}
return int(b ^ 1)
}

View File

@ -1,10 +1,101 @@
package subtle
import (
"bytes"
"crypto/rand"
"fmt"
"testing"
)
func TestConstantTimeLessOrEqBytes(t *testing.T) {
r := rand.Reader
for l := 0; l < 20; l++ {
a := make([]byte, l)
b := make([]byte, l)
empty := make([]byte, l)
r.Read(a)
r.Read(b)
exp := 0
if bytes.Compare(a, b) <= 0 {
exp = 1
}
if got := ConstantTimeLessOrEqBytes(a, b); got != exp {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", a, b, got, exp)
}
exp = 0
if bytes.Compare(b, a) <= 0 {
exp = 1
}
if got := ConstantTimeLessOrEqBytes(b, a); got != exp {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", b, a, got, exp)
}
if got := ConstantTimeLessOrEqBytes(empty, a); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, a, got)
}
if got := ConstantTimeLessOrEqBytes(empty, b); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, b, got)
}
if got := ConstantTimeLessOrEqBytes(a, a); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, a, got)
}
if got := ConstantTimeLessOrEqBytes(b, b); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, b, got)
}
if got := ConstantTimeLessOrEqBytes(empty, empty); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, empty, got)
}
if l == 0 {
continue
}
max := make([]byte, l)
for i := range max {
max[i] = 0xff
}
if got := ConstantTimeLessOrEqBytes(a, max); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, max, got)
}
if got := ConstantTimeLessOrEqBytes(b, max); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, max, got)
}
if got := ConstantTimeLessOrEqBytes(empty, max); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, max, got)
}
if got := ConstantTimeLessOrEqBytes(max, max); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", max, max, got)
}
aPlusOne := make([]byte, l)
copy(aPlusOne, a)
for i := l - 1; i >= 0; i-- {
if aPlusOne[i] == 0xff {
aPlusOne[i] = 0
continue
}
aPlusOne[i]++
if got := ConstantTimeLessOrEqBytes(a, aPlusOne); got != 1 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, aPlusOne, got)
}
if got := ConstantTimeLessOrEqBytes(aPlusOne, a); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", aPlusOne, a, got)
}
break
}
shorter := make([]byte, l-1)
copy(shorter, a)
if got := ConstantTimeLessOrEqBytes(a, shorter); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", a, shorter, got)
}
if got := ConstantTimeLessOrEqBytes(shorter, a); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, a, got)
}
if got := ConstantTimeLessOrEqBytes(b, shorter); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", b, shorter, got)
}
if got := ConstantTimeLessOrEqBytes(shorter, b); got != 0 {
t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, b, got)
}
}
}
func TestConstantTimeAllZero(t *testing.T) {
type args struct {
bytes []byte