internal/cpu,internal/sm9: refactor and fix

This commit is contained in:
Sun Yimin 2025-03-17 17:18:58 +08:00 committed by GitHub
parent 82ccb95527
commit 5734e67634
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 79 additions and 46 deletions

View File

@ -75,17 +75,17 @@ func archInit() {
_, _, ecx1, edx1 := cpuid(1, 0) _, _, ecx1, edx1 := cpuid(1, 0)
X86.HasSSE2 = isSet(26, edx1) X86.HasSSE2 = isSet(26, edx1)
X86.HasSSE3 = isSet(0, ecx1) X86.HasSSE3 = isSet(0, ecx1) // Check presence of SSE3 - bit 0 of ECX
X86.HasPCLMULQDQ = isSet(1, ecx1) X86.HasPCLMULQDQ = isSet(1, ecx1) // Check presence of PCLMULQDQ - bit 1 of ECX
X86.HasSSSE3 = isSet(9, ecx1) X86.HasSSSE3 = isSet(9, ecx1) // Check presence of SSSE3 - bit 9 of ECX
X86.HasFMA = isSet(12, ecx1) X86.HasFMA = isSet(12, ecx1) // Check presence of FMA - bit 12 of ECX
X86.HasCX16 = isSet(13, ecx1) X86.HasCX16 = isSet(13, ecx1) // Check presence of CX16 - bit 13 of ECX
X86.HasSSE41 = isSet(19, ecx1) X86.HasSSE41 = isSet(19, ecx1) // Check presence of SSE4.1 - bit 19 of ECX
X86.HasSSE42 = isSet(20, ecx1) X86.HasSSE42 = isSet(20, ecx1) // Check presence of SSE4.2 - bit 20 of ECX
X86.HasPOPCNT = isSet(23, ecx1) X86.HasPOPCNT = isSet(23, ecx1) // Check presence of POPCNT - bit 23 of ECX
X86.HasAES = isSet(25, ecx1) X86.HasAES = isSet(25, ecx1) // Check presence of AESNI - bit 25 of ECX
X86.HasOSXSAVE = isSet(27, ecx1) X86.HasOSXSAVE = isSet(27, ecx1) // Check presence of OSXSAVE - bit 27 of ECX
X86.HasRDRAND = isSet(30, ecx1) X86.HasRDRAND = isSet(30, ecx1) // Check presence of RDRAND - bit 30 of ECX
var osSupportsAVX, osSupportsAVX512 bool var osSupportsAVX, osSupportsAVX512 bool
// For XGETBV, OSXSAVE bit is required and sufficient. // For XGETBV, OSXSAVE bit is required and sufficient.
@ -110,9 +110,9 @@ func archInit() {
} }
eax7, ebx7, ecx7, edx7 := cpuid(7, 0) eax7, ebx7, ecx7, edx7 := cpuid(7, 0)
X86.HasBMI1 = isSet(3, ebx7) X86.HasBMI1 = isSet(3, ebx7) // Check presence of BMI1 - bit 3 of EBX
X86.HasAVX2 = isSet(5, ebx7) && osSupportsAVX X86.HasAVX2 = isSet(5, ebx7) && osSupportsAVX // Check presence of AVX2 - bit 5 of EBX
X86.HasBMI2 = isSet(8, ebx7) X86.HasBMI2 = isSet(8, ebx7) // Check presence of BMI2 - bit 8 of EBX
X86.HasERMS = isSet(9, ebx7) X86.HasERMS = isSet(9, ebx7)
X86.HasRDSEED = isSet(18, ebx7) X86.HasRDSEED = isSet(18, ebx7)
X86.HasADX = isSet(19, ebx7) X86.HasADX = isSet(19, ebx7)
@ -120,23 +120,23 @@ func archInit() {
X86.HasAVX512 = isSet(16, ebx7) && osSupportsAVX512 // Because avx-512 foundation is the core required extension X86.HasAVX512 = isSet(16, ebx7) && osSupportsAVX512 // Because avx-512 foundation is the core required extension
if X86.HasAVX512 { if X86.HasAVX512 {
X86.HasAVX512F = true X86.HasAVX512F = true
X86.HasAVX512CD = isSet(28, ebx7) X86.HasAVX512CD = isSet(28, ebx7) // Check presence of AVX512CD - bit 28 of EBX
X86.HasAVX512ER = isSet(27, ebx7) X86.HasAVX512ER = isSet(27, ebx7) // Check presence of AVX512ER - bit 27 of EBX
X86.HasAVX512PF = isSet(26, ebx7) X86.HasAVX512PF = isSet(26, ebx7) // Check presence of AVX512PF - bit 26 of EBX
X86.HasAVX512VL = isSet(31, ebx7) X86.HasAVX512VL = isSet(31, ebx7) // Check presence of AVX512VL - bit 31 of EBX
X86.HasAVX512BW = isSet(30, ebx7) X86.HasAVX512BW = isSet(30, ebx7) // Check presence of AVX512BW - bit 30 of EBX
X86.HasAVX512DQ = isSet(17, ebx7) X86.HasAVX512DQ = isSet(17, ebx7) // Check presence of AVX512F - bit 16 of EBX
X86.HasAVX512IFMA = isSet(21, ebx7) X86.HasAVX512IFMA = isSet(21, ebx7) // Check presence of AVX512IFMA - bit 21 of EBX
X86.HasAVX512VBMI = isSet(1, ecx7) X86.HasAVX512VBMI = isSet(1, ecx7) // Check presence of AVX512VBMI - bit 1 of ECX
X86.HasAVX5124VNNIW = isSet(2, edx7) X86.HasAVX5124VNNIW = isSet(2, edx7) // Check presence of AVX5124VNNIW - bit 2 of EDX
X86.HasAVX5124FMAPS = isSet(3, edx7) X86.HasAVX5124FMAPS = isSet(3, edx7) // Check presence of AVX5124FMAPS - bit 3 of EDX
X86.HasAVX512VPOPCNTDQ = isSet(14, ecx7) X86.HasAVX512VPOPCNTDQ = isSet(14, ecx7) // Check presence of AVX512VPOPCNTDQ - bit 14 of ECX
X86.HasAVX512VPCLMULQDQ = isSet(10, ecx7) X86.HasAVX512VPCLMULQDQ = isSet(10, ecx7) // Check presence of VPCLMULQDQ - bit 10 of ECX
X86.HasAVX512VNNI = isSet(11, ecx7) X86.HasAVX512VNNI = isSet(11, ecx7) // Check presence of AVX512VNNI - bit 11 of ECX
X86.HasAVX512GFNI = isSet(8, ecx7) X86.HasAVX512GFNI = isSet(8, ecx7) // Check presence of AVX512GFNI - bit 8 of ECX
X86.HasAVX512VAES = isSet(9, ecx7) X86.HasAVX512VAES = isSet(9, ecx7) // Check presence of AVX512VAES - bit 9 of ECX
X86.HasAVX512VBMI2 = isSet(6, ecx7) X86.HasAVX512VBMI2 = isSet(6, ecx7) // Check presence of AVX512VBMI2 - bit 6 of ECX
X86.HasAVX512BITALG = isSet(12, ecx7) X86.HasAVX512BITALG = isSet(12, ecx7) // Check presence of AVX512BITALG - bit 12 of ECX
} }
X86.HasAMXTile = isSet(24, edx7) X86.HasAMXTile = isSet(24, edx7)
@ -150,7 +150,7 @@ func archInit() {
X86.HasAVX512BF16 = isSet(5, eax71) X86.HasAVX512BF16 = isSet(5, eax71)
} }
if X86.HasAVX { if X86.HasAVX {
X86.HasAVXIFMA = isSet(23, eax71) X86.HasAVXIFMA = isSet(23, eax71) // Check presence of AVXIFMA - bit 23 of EAX
X86.HasAVXVNNI = isSet(4, eax71) X86.HasAVXVNNI = isSet(4, eax71)
X86.HasAVXVNNIInt8 = isSet(4, edx71) X86.HasAVXVNNIInt8 = isSet(4, edx71)
} }

View File

@ -0,0 +1,37 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build 386 || amd64
package cpu_test
import (
"testing"
. "github.com/emmansun/gmsm/internal/cpu"
)
func TestX86ifAVX2hasAVX(t *testing.T) {
if X86.HasAVX2 && !X86.HasAVX {
t.Fatalf("HasAVX expected true when HasAVX2 is true, got false")
}
}
func TestX86ifAVX512FhasAVX2(t *testing.T) {
if X86.HasAVX512F && !X86.HasAVX2 {
t.Fatalf("HasAVX2 expected true when HasAVX512F is true, got false")
}
}
func TestX86ifAVX512BWhasAVX512F(t *testing.T) {
if X86.HasAVX512BW && !X86.HasAVX512F {
t.Fatalf("HasAVX512F expected true when HasAVX512BW is true, got false")
}
}
func TestX86ifAVX512VLhasAVX512F(t *testing.T) {
if X86.HasAVX512VL && !X86.HasAVX512F {
t.Fatalf("HasAVX512F expected true when HasAVX512VL is true, got false")
}
}

View File

@ -126,8 +126,7 @@ func (priv *SignPrivateKey) Sign(rand io.Reader, hash []byte, opts crypto.Signer
r.Sub(hNat, orderNat) r.Sub(hNat, orderNat)
if r.IsZero() == 0 { // r != 0 if r.IsZero() == 0 { // r != 0
s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat)) if s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat)); err != nil {
if err != nil {
return nil, nil, err return nil, nil, err
} }
break break
@ -141,7 +140,8 @@ func (priv *SignPrivateKey) Sign(rand io.Reader, hash []byte, opts crypto.Signer
// Verify checks the validity of a signature using the provided parameters. // Verify checks the validity of a signature using the provided parameters.
func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, h, S []byte) bool { func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, h, S []byte) bool {
sPoint := new(bn256.G1) sPoint := new(bn256.G1)
if len(S) == len(bn256.OrderMinus1Bytes)+1 && S[0] != 0x04 { numBytes := 2 * len(bn256.OrderBytes)
if len(S) != numBytes+1 || S[0] != 4 {
return false return false
} }
_, err := sPoint.Unmarshal(S[1:]) _, err := sPoint.Unmarshal(S[1:])
@ -223,18 +223,14 @@ func (pub *EncryptMasterPublicKey) WrapKey(rand io.Reader, uid []byte, hid byte,
// It returns the decrypted key of the specified length (kLen) or an error if decryption fails. // It returns the decrypted key of the specified length (kLen) or an error if decryption fails.
func (priv *EncryptPrivateKey) UnwrapKey(uid, cipher []byte, kLen int) (key []byte, err error) { func (priv *EncryptPrivateKey) UnwrapKey(uid, cipher []byte, kLen int) (key []byte, err error) {
numBytes := 2 * len(bn256.OrderBytes) numBytes := 2 * len(bn256.OrderBytes)
if len(cipher) == numBytes+1 { if len(cipher) == numBytes+1 && cipher[0] == 4 {
if cipher[0] != 0x04 {
return nil, ErrDecryption
}
cipher = cipher[1:] cipher = cipher[1:]
} }
if len(cipher) != numBytes { if len(cipher) != numBytes {
return nil, ErrDecryption return nil, ErrDecryption
} }
p := new(bn256.G1) p := new(bn256.G1)
_, err = p.Unmarshal(cipher) if _, err = p.Unmarshal(cipher); err != nil || !p.IsOnCurve() {
if err != nil || !p.IsOnCurve() {
return nil, ErrDecryption return nil, ErrDecryption
} }
@ -366,7 +362,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA []byte) ([]byte, []byte, error) { func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA []byte) ([]byte, []byte, error) {
numBytes := 2 * len(bn256.OrderBytes) numBytes := 2 * len(bn256.OrderBytes)
if len(rA) != numBytes+1 || rA[0] != 0x04 { if len(rA) != numBytes+1 || rA[0] != 4 {
return nil, nil, errors.New("sm9: invalid initiator's ephemeral public key") return nil, nil, errors.New("sm9: invalid initiator's ephemeral public key")
} }
rP := new(bn256.G1) rP := new(bn256.G1)
@ -417,7 +413,7 @@ func (ke *KeyExchange) RespondKeyExchange(rand io.Reader, hid byte, rA []byte) (
// ConfirmResponder for initiator's step A5-A7 // ConfirmResponder for initiator's step A5-A7
func (ke *KeyExchange) ConfirmResponder(rB, sB []byte) ([]byte, []byte, error) { func (ke *KeyExchange) ConfirmResponder(rB, sB []byte) ([]byte, []byte, error) {
numBytes := 2 * len(bn256.OrderBytes) numBytes := 2 * len(bn256.OrderBytes)
if len(rB) != numBytes+1 || rB[0] != 0x04 { if len(rB) != numBytes+1 || rB[0] != 4 {
return nil, nil, errors.New("sm9: invalid responder's ephemeral public key") return nil, nil, errors.New("sm9: invalid responder's ephemeral public key")
} }
pB := new(bn256.G1) pB := new(bn256.G1)

View File

@ -275,7 +275,7 @@ func unmarshalG1(bytes []byte) (*bn256.G1, error) {
return nil, err return nil, err
} }
default: default:
return nil, errors.New("sm9: invalid point identity byte") return nil, errors.New("sm9: invalid point encoding")
} }
return g, nil return g, nil
} }
@ -426,7 +426,7 @@ func (pub *EncryptMasterPublicKey) ScalarBaseMult(scalar []byte) (*bn256.GT, err
return bn256.ScalarBaseMultGT(tables, scalar) return bn256.ScalarBaseMultGT(tables, scalar)
} }
// GenerateUserPublicKey generate an encrypt public key for the given user. // GenerateUserPublicKey generate an encryption public key for the given user.
func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn256.G1 { func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn256.G1 {
var buffer []byte var buffer []byte
buffer = append(append(buffer, uid...), hid) buffer = append(append(buffer, uid...), hid)
@ -456,7 +456,7 @@ func (priv *EncryptPrivateKey) MasterPublic() *EncryptMasterPublicKey {
return priv.EncryptMasterPublicKey return priv.EncryptMasterPublicKey
} }
// SetMasterPublicKey bind the encrypt master public key to it. // SetMasterPublicKey bind the encryption master public key to it.
func (priv *EncryptPrivateKey) SetMasterPublicKey(pub *EncryptMasterPublicKey) { func (priv *EncryptPrivateKey) SetMasterPublicKey(pub *EncryptMasterPublicKey) {
if priv.EncryptMasterPublicKey == nil || priv.EncryptMasterPublicKey.MasterPublicKey == nil { if priv.EncryptMasterPublicKey == nil || priv.EncryptMasterPublicKey.MasterPublicKey == nil {
priv.EncryptMasterPublicKey = pub priv.EncryptMasterPublicKey = pub