From 5ade794e6be4dcaf10ebbee1f8f692ab25b374ad Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Tue, 4 Mar 2025 08:59:51 +0800 Subject: [PATCH] internal/sm2ec: make SetBytes constant time #309 --- internal/sm2ec/fiat/generate.go | 11 ++-- internal/sm2ec/fiat/sm2p256.go | 11 ++-- internal/subtle/constant_time.go | 37 +++++++++++ internal/subtle/constant_time_test.go | 91 +++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 14 deletions(-) diff --git a/internal/sm2ec/fiat/generate.go b/internal/sm2ec/fiat/generate.go index 2fd525b..0e14902 100644 --- a/internal/sm2ec/fiat/generate.go +++ b/internal/sm2ec/fiat/generate.go @@ -134,6 +134,8 @@ package fiat import ( "crypto/subtle" "errors" + + _subtle "github.com/emmansun/gmsm/internal/subtle" ) // {{ .Element }} is an integer modulo {{ .Prime }}. @@ -202,13 +204,8 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) { // the encoding of -1 mod p, so p - 1, the highest canonical encoding. var minusOneEncoding = new({{ .Element }}).Sub( new({{ .Element }}), new({{ .Element }}).One()).Bytes() - for i := range v { - if v[i] < minusOneEncoding[i] { - break - } - if v[i] > minusOneEncoding[i] { - return nil, errors.New("invalid {{ .Element }} encoding") - } + if _subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 { + return nil, errors.New("invalid {{ .Element }} encoding") } var in [{{ .Prefix }}ElementLen]byte copy(in[:], v) diff --git a/internal/sm2ec/fiat/sm2p256.go b/internal/sm2ec/fiat/sm2p256.go index 9af226f..4c1c9d8 100644 --- a/internal/sm2ec/fiat/sm2p256.go +++ b/internal/sm2ec/fiat/sm2p256.go @@ -9,6 +9,8 @@ package fiat import ( "crypto/subtle" "errors" + + _subtle "github.com/emmansun/gmsm/internal/subtle" ) // SM2P256Element is an integer modulo 2^256 - 2^224 - 2^96 + 2^64 - 1. @@ -78,13 +80,8 @@ func (e *SM2P256Element) SetBytes(v []byte) (*SM2P256Element, error) { // the encoding of -1 mod p, so p - 1, the highest canonical encoding. var minusOneEncoding = new(SM2P256Element).Sub( new(SM2P256Element), new(SM2P256Element).One()).Bytes() - for i := range v { - if v[i] < minusOneEncoding[i] { - break - } - if v[i] > minusOneEncoding[i] { - return nil, errors.New("invalid SM2P256Element encoding") - } + if _subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 { + return nil, errors.New("invalid {{ .Element }} encoding") } var in [sm2p256ElementLen]byte diff --git a/internal/subtle/constant_time.go b/internal/subtle/constant_time.go index ac5cb22..6ac9551 100644 --- a/internal/subtle/constant_time.go +++ b/internal/subtle/constant_time.go @@ -1,5 +1,11 @@ package subtle +import ( + "math/bits" + + "github.com/emmansun/gmsm/internal/byteorder" +) + func ConstantTimeAllZero(bytes []byte) int { var b uint8 for _, v := range bytes { @@ -7,3 +13,34 @@ func ConstantTimeAllZero(bytes []byte) int { } return int((uint32(b) - 1) >> 31) } + +// ConstantTimeLessOrEqBytes returns 1 if x <= y and 0 otherwise. The comparison +// is lexigraphical, or big-endian. The time taken is a function of the length of +// the slices and is independent of the contents. If the lengths of x and y do not +// match it returns 0 immediately. +func ConstantTimeLessOrEqBytes(x, y []byte) int { + if len(x) != len(y) { + return 0 + } + + // Do a constant time subtraction chain y - x. + // If there is no borrow at the end, then x <= y. + var b uint64 + for len(x) > 8 { + x0 := byteorder.BEUint64(x[len(x)-8:]) + y0 := byteorder.BEUint64(y[len(y)-8:]) + _, b = bits.Sub64(y0, x0, b) + x = x[:len(x)-8] + y = y[:len(y)-8] + } + if len(x) > 0 { + xb := make([]byte, 8) + yb := make([]byte, 8) + copy(xb[8-len(x):], x) + copy(yb[8-len(y):], y) + x0 := byteorder.BEUint64(xb) + y0 := byteorder.BEUint64(yb) + _, b = bits.Sub64(y0, x0, b) + } + return int(b ^ 1) +} diff --git a/internal/subtle/constant_time_test.go b/internal/subtle/constant_time_test.go index e3b4d59..e36e267 100644 --- a/internal/subtle/constant_time_test.go +++ b/internal/subtle/constant_time_test.go @@ -1,10 +1,101 @@ package subtle import ( + "bytes" + "crypto/rand" "fmt" "testing" ) +func TestConstantTimeLessOrEqBytes(t *testing.T) { + r := rand.Reader + for l := 0; l < 20; l++ { + a := make([]byte, l) + b := make([]byte, l) + empty := make([]byte, l) + r.Read(a) + r.Read(b) + exp := 0 + if bytes.Compare(a, b) <= 0 { + exp = 1 + } + if got := ConstantTimeLessOrEqBytes(a, b); got != exp { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", a, b, got, exp) + } + exp = 0 + if bytes.Compare(b, a) <= 0 { + exp = 1 + } + if got := ConstantTimeLessOrEqBytes(b, a); got != exp { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", b, a, got, exp) + } + if got := ConstantTimeLessOrEqBytes(empty, a); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, a, got) + } + if got := ConstantTimeLessOrEqBytes(empty, b); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, b, got) + } + if got := ConstantTimeLessOrEqBytes(a, a); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, a, got) + } + if got := ConstantTimeLessOrEqBytes(b, b); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, b, got) + } + if got := ConstantTimeLessOrEqBytes(empty, empty); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, empty, got) + } + if l == 0 { + continue + } + max := make([]byte, l) + for i := range max { + max[i] = 0xff + } + if got := ConstantTimeLessOrEqBytes(a, max); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, max, got) + } + if got := ConstantTimeLessOrEqBytes(b, max); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, max, got) + } + if got := ConstantTimeLessOrEqBytes(empty, max); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, max, got) + } + if got := ConstantTimeLessOrEqBytes(max, max); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", max, max, got) + } + aPlusOne := make([]byte, l) + copy(aPlusOne, a) + for i := l - 1; i >= 0; i-- { + if aPlusOne[i] == 0xff { + aPlusOne[i] = 0 + continue + } + aPlusOne[i]++ + if got := ConstantTimeLessOrEqBytes(a, aPlusOne); got != 1 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, aPlusOne, got) + } + if got := ConstantTimeLessOrEqBytes(aPlusOne, a); got != 0 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", aPlusOne, a, got) + } + break + } + shorter := make([]byte, l-1) + copy(shorter, a) + if got := ConstantTimeLessOrEqBytes(a, shorter); got != 0 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", a, shorter, got) + } + if got := ConstantTimeLessOrEqBytes(shorter, a); got != 0 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, a, got) + } + if got := ConstantTimeLessOrEqBytes(b, shorter); got != 0 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", b, shorter, got) + } + if got := ConstantTimeLessOrEqBytes(shorter, b); got != 0 { + t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, b, got) + } + } +} + func TestConstantTimeAllZero(t *testing.T) { type args struct { bytes []byte