From 5734e67634459b6e6d81f4a8dd97ea8e1868f72c Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Mon, 17 Mar 2025 17:18:58 +0800 Subject: [PATCH] internal/cpu,internal/sm9: refactor and fix --- internal/cpu/cpu_x86.go | 64 ++++++++++++++++++------------------ internal/cpu/cpu_x86_test.go | 37 +++++++++++++++++++++ internal/sm9/sm9.go | 18 ++++------ internal/sm9/sm9_key.go | 6 ++-- 4 files changed, 79 insertions(+), 46 deletions(-) create mode 100644 internal/cpu/cpu_x86_test.go diff --git a/internal/cpu/cpu_x86.go b/internal/cpu/cpu_x86.go index 1e642f3..8831aa5 100644 --- a/internal/cpu/cpu_x86.go +++ b/internal/cpu/cpu_x86.go @@ -75,17 +75,17 @@ func archInit() { _, _, ecx1, edx1 := cpuid(1, 0) X86.HasSSE2 = isSet(26, edx1) - X86.HasSSE3 = isSet(0, ecx1) - X86.HasPCLMULQDQ = isSet(1, ecx1) - X86.HasSSSE3 = isSet(9, ecx1) - X86.HasFMA = isSet(12, ecx1) - X86.HasCX16 = isSet(13, ecx1) - X86.HasSSE41 = isSet(19, ecx1) - X86.HasSSE42 = isSet(20, ecx1) - X86.HasPOPCNT = isSet(23, ecx1) - X86.HasAES = isSet(25, ecx1) - X86.HasOSXSAVE = isSet(27, ecx1) - X86.HasRDRAND = isSet(30, ecx1) + X86.HasSSE3 = isSet(0, ecx1) // Check presence of SSE3 - bit 0 of ECX + X86.HasPCLMULQDQ = isSet(1, ecx1) // Check presence of PCLMULQDQ - bit 1 of ECX + X86.HasSSSE3 = isSet(9, ecx1) // Check presence of SSSE3 - bit 9 of ECX + X86.HasFMA = isSet(12, ecx1) // Check presence of FMA - bit 12 of ECX + X86.HasCX16 = isSet(13, ecx1) // Check presence of CX16 - bit 13 of ECX + X86.HasSSE41 = isSet(19, ecx1) // Check presence of SSE4.1 - bit 19 of ECX + X86.HasSSE42 = isSet(20, ecx1) // Check presence of SSE4.2 - bit 20 of ECX + X86.HasPOPCNT = isSet(23, ecx1) // Check presence of POPCNT - bit 23 of ECX + X86.HasAES = isSet(25, ecx1) // Check presence of AESNI - bit 25 of ECX + X86.HasOSXSAVE = isSet(27, ecx1) // Check presence of OSXSAVE - bit 27 of ECX + X86.HasRDRAND = isSet(30, ecx1) // Check presence of RDRAND - bit 30 of ECX var osSupportsAVX, osSupportsAVX512 bool // For XGETBV, OSXSAVE bit is required and sufficient. @@ -110,9 +110,9 @@ func archInit() { } eax7, ebx7, ecx7, edx7 := cpuid(7, 0) - X86.HasBMI1 = isSet(3, ebx7) - X86.HasAVX2 = isSet(5, ebx7) && osSupportsAVX - X86.HasBMI2 = isSet(8, ebx7) + X86.HasBMI1 = isSet(3, ebx7) // Check presence of BMI1 - bit 3 of EBX + X86.HasAVX2 = isSet(5, ebx7) && osSupportsAVX // Check presence of AVX2 - bit 5 of EBX + X86.HasBMI2 = isSet(8, ebx7) // Check presence of BMI2 - bit 8 of EBX X86.HasERMS = isSet(9, ebx7) X86.HasRDSEED = isSet(18, ebx7) X86.HasADX = isSet(19, ebx7) @@ -120,23 +120,23 @@ func archInit() { X86.HasAVX512 = isSet(16, ebx7) && osSupportsAVX512 // Because avx-512 foundation is the core required extension if X86.HasAVX512 { X86.HasAVX512F = true - X86.HasAVX512CD = isSet(28, ebx7) - X86.HasAVX512ER = isSet(27, ebx7) - X86.HasAVX512PF = isSet(26, ebx7) - X86.HasAVX512VL = isSet(31, ebx7) - X86.HasAVX512BW = isSet(30, ebx7) - X86.HasAVX512DQ = isSet(17, ebx7) - X86.HasAVX512IFMA = isSet(21, ebx7) - X86.HasAVX512VBMI = isSet(1, ecx7) - X86.HasAVX5124VNNIW = isSet(2, edx7) - X86.HasAVX5124FMAPS = isSet(3, edx7) - X86.HasAVX512VPOPCNTDQ = isSet(14, ecx7) - X86.HasAVX512VPCLMULQDQ = isSet(10, ecx7) - X86.HasAVX512VNNI = isSet(11, ecx7) - X86.HasAVX512GFNI = isSet(8, ecx7) - X86.HasAVX512VAES = isSet(9, ecx7) - X86.HasAVX512VBMI2 = isSet(6, ecx7) - X86.HasAVX512BITALG = isSet(12, ecx7) + X86.HasAVX512CD = isSet(28, ebx7) // Check presence of AVX512CD - bit 28 of EBX + X86.HasAVX512ER = isSet(27, ebx7) // Check presence of AVX512ER - bit 27 of EBX + X86.HasAVX512PF = isSet(26, ebx7) // Check presence of AVX512PF - bit 26 of EBX + X86.HasAVX512VL = isSet(31, ebx7) // Check presence of AVX512VL - bit 31 of EBX + X86.HasAVX512BW = isSet(30, ebx7) // Check presence of AVX512BW - bit 30 of EBX + X86.HasAVX512DQ = isSet(17, ebx7) // Check presence of AVX512F - bit 16 of EBX + X86.HasAVX512IFMA = isSet(21, ebx7) // Check presence of AVX512IFMA - bit 21 of EBX + X86.HasAVX512VBMI = isSet(1, ecx7) // Check presence of AVX512VBMI - bit 1 of ECX + X86.HasAVX5124VNNIW = isSet(2, edx7) // Check presence of AVX5124VNNIW - bit 2 of EDX + X86.HasAVX5124FMAPS = isSet(3, edx7) // Check presence of AVX5124FMAPS - bit 3 of EDX + X86.HasAVX512VPOPCNTDQ = isSet(14, ecx7) // Check presence of AVX512VPOPCNTDQ - bit 14 of ECX + X86.HasAVX512VPCLMULQDQ = isSet(10, ecx7) // Check presence of VPCLMULQDQ - bit 10 of ECX + X86.HasAVX512VNNI = isSet(11, ecx7) // Check presence of AVX512VNNI - bit 11 of ECX + X86.HasAVX512GFNI = isSet(8, ecx7) // Check presence of AVX512GFNI - bit 8 of ECX + X86.HasAVX512VAES = isSet(9, ecx7) // Check presence of AVX512VAES - bit 9 of ECX + X86.HasAVX512VBMI2 = isSet(6, ecx7) // Check presence of AVX512VBMI2 - bit 6 of ECX + X86.HasAVX512BITALG = isSet(12, ecx7) // Check presence of AVX512BITALG - bit 12 of ECX } X86.HasAMXTile = isSet(24, edx7) @@ -150,7 +150,7 @@ func archInit() { X86.HasAVX512BF16 = isSet(5, eax71) } if X86.HasAVX { - X86.HasAVXIFMA = isSet(23, eax71) + X86.HasAVXIFMA = isSet(23, eax71) // Check presence of AVXIFMA - bit 23 of EAX X86.HasAVXVNNI = isSet(4, eax71) X86.HasAVXVNNIInt8 = isSet(4, edx71) } diff --git a/internal/cpu/cpu_x86_test.go b/internal/cpu/cpu_x86_test.go new file mode 100644 index 0000000..53ee623 --- /dev/null +++ b/internal/cpu/cpu_x86_test.go @@ -0,0 +1,37 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build 386 || amd64 + +package cpu_test + +import ( + "testing" + + . "github.com/emmansun/gmsm/internal/cpu" +) + +func TestX86ifAVX2hasAVX(t *testing.T) { + if X86.HasAVX2 && !X86.HasAVX { + t.Fatalf("HasAVX expected true when HasAVX2 is true, got false") + } +} + +func TestX86ifAVX512FhasAVX2(t *testing.T) { + if X86.HasAVX512F && !X86.HasAVX2 { + t.Fatalf("HasAVX2 expected true when HasAVX512F is true, got false") + } +} + +func TestX86ifAVX512BWhasAVX512F(t *testing.T) { + if X86.HasAVX512BW && !X86.HasAVX512F { + t.Fatalf("HasAVX512F expected true when HasAVX512BW is true, got false") + } +} + +func TestX86ifAVX512VLhasAVX512F(t *testing.T) { + if X86.HasAVX512VL && !X86.HasAVX512F { + t.Fatalf("HasAVX512F expected true when HasAVX512VL is true, got false") + } +} diff --git a/internal/sm9/sm9.go b/internal/sm9/sm9.go index 46f6d33..6915c92 100644 --- a/internal/sm9/sm9.go +++ b/internal/sm9/sm9.go @@ -126,8 +126,7 @@ func (priv *SignPrivateKey) Sign(rand io.Reader, hash []byte, opts crypto.Signer r.Sub(hNat, orderNat) if r.IsZero() == 0 { // r != 0 - s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat)) - if err != nil { + if s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat)); err != nil { return nil, nil, err } break @@ -141,7 +140,8 @@ func (priv *SignPrivateKey) Sign(rand io.Reader, hash []byte, opts crypto.Signer // Verify checks the validity of a signature using the provided parameters. func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, h, S []byte) bool { sPoint := new(bn256.G1) - if len(S) == len(bn256.OrderMinus1Bytes)+1 && S[0] != 0x04 { + numBytes := 2 * len(bn256.OrderBytes) + if len(S) != numBytes+1 || S[0] != 4 { return false } _, err := sPoint.Unmarshal(S[1:]) @@ -223,18 +223,14 @@ func (pub *EncryptMasterPublicKey) WrapKey(rand io.Reader, uid []byte, hid byte, // It returns the decrypted key of the specified length (kLen) or an error if decryption fails. func (priv *EncryptPrivateKey) UnwrapKey(uid, cipher []byte, kLen int) (key []byte, err error) { numBytes := 2 * len(bn256.OrderBytes) - if len(cipher) == numBytes+1 { - if cipher[0] != 0x04 { - return nil, ErrDecryption - } + if len(cipher) == numBytes+1 && cipher[0] == 4 { cipher = cipher[1:] } if len(cipher) != numBytes { return nil, ErrDecryption } p := new(bn256.G1) - _, err = p.Unmarshal(cipher) - if err != nil || !p.IsOnCurve() { + if _, err = p.Unmarshal(cipher); err != nil || !p.IsOnCurve() { return nil, ErrDecryption } @@ -366,7 +362,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA []byte) ([]byte, []byte, error) { numBytes := 2 * len(bn256.OrderBytes) - if len(rA) != numBytes+1 || rA[0] != 0x04 { + if len(rA) != numBytes+1 || rA[0] != 4 { return nil, nil, errors.New("sm9: invalid initiator's ephemeral public key") } rP := new(bn256.G1) @@ -417,7 +413,7 @@ func (ke *KeyExchange) RespondKeyExchange(rand io.Reader, hid byte, rA []byte) ( // ConfirmResponder for initiator's step A5-A7 func (ke *KeyExchange) ConfirmResponder(rB, sB []byte) ([]byte, []byte, error) { numBytes := 2 * len(bn256.OrderBytes) - if len(rB) != numBytes+1 || rB[0] != 0x04 { + if len(rB) != numBytes+1 || rB[0] != 4 { return nil, nil, errors.New("sm9: invalid responder's ephemeral public key") } pB := new(bn256.G1) diff --git a/internal/sm9/sm9_key.go b/internal/sm9/sm9_key.go index 9143afc..32d3663 100644 --- a/internal/sm9/sm9_key.go +++ b/internal/sm9/sm9_key.go @@ -275,7 +275,7 @@ func unmarshalG1(bytes []byte) (*bn256.G1, error) { return nil, err } default: - return nil, errors.New("sm9: invalid point identity byte") + return nil, errors.New("sm9: invalid point encoding") } return g, nil } @@ -426,7 +426,7 @@ func (pub *EncryptMasterPublicKey) ScalarBaseMult(scalar []byte) (*bn256.GT, err return bn256.ScalarBaseMultGT(tables, scalar) } -// GenerateUserPublicKey generate an encrypt public key for the given user. +// GenerateUserPublicKey generate an encryption public key for the given user. func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn256.G1 { var buffer []byte buffer = append(append(buffer, uid...), hid) @@ -456,7 +456,7 @@ func (priv *EncryptPrivateKey) MasterPublic() *EncryptMasterPublicKey { return priv.EncryptMasterPublicKey } -// SetMasterPublicKey bind the encrypt master public key to it. +// SetMasterPublicKey bind the encryption master public key to it. func (priv *EncryptPrivateKey) SetMasterPublicKey(pub *EncryptMasterPublicKey) { if priv.EncryptMasterPublicKey == nil || priv.EncryptMasterPublicKey.MasterPublicKey == nil { priv.EncryptMasterPublicKey = pub