diff --git a/sm4/gcm_asm_test.go b/sm4/gcm_asm_test.go deleted file mode 100644 index 6506e65..0000000 --- a/sm4/gcm_asm_test.go +++ /dev/null @@ -1,123 +0,0 @@ -//go:build amd64 || arm64 -// +build amd64 arm64 - -package sm4 - -import ( - "encoding/hex" - "testing" -) - -func createGcm() *gcmAsm { - key := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10} - c := sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, 4, 64} - expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0]) - c1 := &sm4CipherGCM{c} - g := &gcmAsm{} - g.cipher = &c1.sm4CipherAsm - g.tagSize = 16 - gcmSm4Init(&g.bytesProductTable, g.cipher.enc) - return g -} - -var sm4GCMTests = []struct { - plaintext string -}{ - { // case 0: < 16 - "abcdefg", - }, - { // case 1: = 16 - "abcdefgabcdefghg", - }, - { // case 2: > 16 , < 64 - "abcdefgabcdefghgabcdefgabcdefghgaaa", - }, - { // case 3: = 64 - "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghg", - }, - { // case 4: > 64, < 128 - "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgaaa", - }, - { // case 5: = 128 - "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghg", - }, - { // case 6: 227 > 128, < 256, 128 + 64 + 35 - "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgaaa", - }, - { // case 7: = 256 - "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghg", - }, - { // case 8: > 256, = 355 - "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgaaa", - }, -} - -func initCounter(i byte, counter *[16]byte) { - copy(counter[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}) - counter[gcmBlockSize-1] = i -} - -func resetTag(tag *[16]byte) { - for j := 0; j < 16; j++ { - tag[j] = 0 - } -} - -func TestGcmSm4Enc(t *testing.T) { - var counter1, counter2 [16]byte - gcm := createGcm() - var tagOut1, tagOut2 [gcmTagSize]byte - - for i, test := range sm4GCMTests { - initCounter(2, &counter1) - initCounter(1, &counter2) - - gcmSm4Data(&gcm.bytesProductTable, []byte("emmansun"), &tagOut1) - out1 := make([]byte, len(test.plaintext)+gcm.tagSize) - gcm.counterCrypt(out1, []byte(test.plaintext), &counter1) - gcmSm4Data(&gcm.bytesProductTable, out1[:len(test.plaintext)], &tagOut1) - - out2 := make([]byte, len(test.plaintext)+gcm.tagSize) - gcmSm4Data(&gcm.bytesProductTable, []byte("emmansun"), &tagOut2) - gcmSm4Enc(&gcm.bytesProductTable, out2, []byte(test.plaintext), &counter2, &tagOut2, gcm.cipher.enc) - if hex.EncodeToString(out1) != hex.EncodeToString(out2) { - t.Errorf("#%d: out expected %s, got %s", i, hex.EncodeToString(out1), hex.EncodeToString(out2)) - } - if hex.EncodeToString(tagOut1[:]) != hex.EncodeToString(tagOut2[:]) { - t.Errorf("#%d: tag expected %s, got %s", i, hex.EncodeToString(tagOut1[:]), hex.EncodeToString(tagOut2[:])) - } - resetTag(&tagOut1) - resetTag(&tagOut2) - } -} - -func TestGcmSm4Dec(t *testing.T) { - var counter1, counter2 [16]byte - gcm := createGcm() - var tagOut1, tagOut2 [gcmTagSize]byte - - for i, test := range sm4GCMTests { - initCounter(2, &counter1) - initCounter(1, &counter2) - - gcmSm4Data(&gcm.bytesProductTable, []byte("emmansun"), &tagOut1) - out1 := make([]byte, len(test.plaintext)+gcm.tagSize) - gcm.counterCrypt(out1, []byte(test.plaintext), &counter1) - gcmSm4Data(&gcm.bytesProductTable, out1[:len(test.plaintext)], &tagOut1) - - out1 = out1[:len(test.plaintext)] - - out2 := make([]byte, len(test.plaintext)+gcm.tagSize) - gcmSm4Data(&gcm.bytesProductTable, []byte("emmansun"), &tagOut2) - gcmSm4Dec(&gcm.bytesProductTable, out2, out1, &counter2, &tagOut2, gcm.cipher.enc) - - if hex.EncodeToString([]byte(test.plaintext)) != hex.EncodeToString(out2[:len(test.plaintext)]) { - t.Errorf("#%d: out expected %s, got %s", i, hex.EncodeToString([]byte(test.plaintext)), hex.EncodeToString(out2[:len(test.plaintext)])) - } - if hex.EncodeToString(tagOut1[:]) != hex.EncodeToString(tagOut2[:]) { - t.Errorf("#%d: tag expected %s, got %s", i, hex.EncodeToString(tagOut1[:]), hex.EncodeToString(tagOut2[:])) - } - resetTag(&tagOut1) - resetTag(&tagOut2) - } -} diff --git a/sm4/sm4_gcm_arm64.go b/sm4/sm4_gcm_arm64.go deleted file mode 100644 index 2573bc9..0000000 --- a/sm4/sm4_gcm_arm64.go +++ /dev/null @@ -1,166 +0,0 @@ -//go:build arm64 -// +build arm64 - -package sm4 - -import ( - "crypto/cipher" - goSubtle "crypto/subtle" - - "github.com/emmansun/gmsm/internal/subtle" -) - -// sm4CipherGCM implements crypto/cipher.gcmAble so that crypto/cipher.NewGCM -// will use the optimised implementation in this file when possible. Instances -// of this type only exist when hasGCMAsm returns true. -type sm4CipherGCM struct { - sm4CipherAsm -} - -// Assert that sm4CipherGCM implements the gcmAble interface. -var _ gcmAble = (*sm4CipherGCM)(nil) - -//go:noescape -func gcmSm4Init(productTable *[256]byte, rk []uint32) - -//go:noescape -func gcmSm4Enc(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) - -//go:noescape -func gcmSm4Dec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk []uint32) - -//go:noescape -func gcmSm4Data(productTable *[256]byte, data []byte, T *[16]byte) - -//go:noescape -func gcmSm4Finish(productTable *[256]byte, tagMask, T *[16]byte, pLen, dLen uint64) - -type gcmAsm struct { - gcm - bytesProductTable [256]byte -} - -// NewGCM returns the SM4 cipher wrapped in Galois Counter Mode. This is only -// called by crypto/cipher.NewGCM via the gcmAble interface. -func (c *sm4CipherGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) { - g := &gcmAsm{} - g.cipher = &c.sm4CipherAsm - g.nonceSize = nonceSize - g.tagSize = tagSize - gcmSm4Init(&g.bytesProductTable, g.cipher.enc) - return g, nil -} - -func (g *gcmAsm) NonceSize() int { - return g.nonceSize -} - -func (g *gcmAsm) Overhead() int { - return g.tagSize -} - -// Seal encrypts and authenticates plaintext. See the cipher.AEAD interface for -// details. -func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte { - if len(nonce) != g.nonceSize { - panic("cipher: incorrect nonce length given to GCM") - } - if uint64(len(plaintext)) > ((1<<32)-2)*BlockSize { - panic("cipher: message too large for GCM") - } - - var counter, tagMask [gcmBlockSize]byte - - if len(nonce) == gcmStandardNonceSize { - // Init counter to nonce||1 - copy(counter[:], nonce) - counter[gcmBlockSize-1] = 1 - } else { - // Otherwise counter = GHASH(nonce) - gcmSm4Data(&g.bytesProductTable, nonce, &counter) - gcmSm4Finish(&g.bytesProductTable, &tagMask, &counter, uint64(len(nonce)), uint64(0)) - } - - g.cipher.Encrypt(tagMask[:], counter[:]) - gcmInc32(&counter) - - var tagOut [gcmTagSize]byte - - gcmSm4Data(&g.bytesProductTable, data, &tagOut) - - ret, out := subtle.SliceForAppend(dst, len(plaintext)+g.tagSize) - if subtle.InexactOverlap(out[:len(plaintext)], plaintext) { - panic("cipher: invalid buffer overlap") - } - - if len(plaintext) > 0 { - g.counterCrypt(out, plaintext, &counter) - gcmSm4Data(&g.bytesProductTable, out[:len(plaintext)], &tagOut) - } - gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data))) - copy(out[len(plaintext):], tagOut[:]) - - return ret -} - -// Open authenticates and decrypts ciphertext. See the cipher.AEAD interface -// for details. -func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { - if len(nonce) != g.nonceSize { - panic("cipher: incorrect nonce length given to GCM") - } - // Sanity check to prevent the authentication from always succeeding if an implementation - // leaves tagSize uninitialized, for example. - if g.tagSize < gcmMinimumTagSize { - panic("cipher: incorrect GCM tag size") - } - - if len(ciphertext) < g.tagSize { - return nil, errOpen - } - if uint64(len(ciphertext)) > ((1<<32)-2)*uint64(BlockSize)+uint64(g.tagSize) { - return nil, errOpen - } - - tag := ciphertext[len(ciphertext)-g.tagSize:] - ciphertext = ciphertext[:len(ciphertext)-g.tagSize] - - // See GCM spec, section 7.1. - var counter, tagMask [gcmBlockSize]byte - - if len(nonce) == gcmStandardNonceSize { - // Init counter to nonce||1 - copy(counter[:], nonce) - counter[gcmBlockSize-1] = 1 - } else { - // Otherwise counter = GHASH(nonce) - gcmSm4Data(&g.bytesProductTable, nonce, &counter) - gcmSm4Finish(&g.bytesProductTable, &tagMask, &counter, uint64(len(nonce)), uint64(0)) - } - - g.cipher.Encrypt(tagMask[:], counter[:]) - gcmInc32(&counter) - - var expectedTag [gcmTagSize]byte - gcmSm4Data(&g.bytesProductTable, data, &expectedTag) - - ret, out := subtle.SliceForAppend(dst, len(ciphertext)) - if subtle.InexactOverlap(out, ciphertext) { - panic("cipher: invalid buffer overlap") - } - if len(ciphertext) > 0 { - gcmSm4Data(&g.bytesProductTable, ciphertext, &expectedTag) - } - gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data))) - - if goSubtle.ConstantTimeCompare(expectedTag[:g.tagSize], tag) != 1 { - for i := range out { - out[i] = 0 - } - return nil, errOpen - } - - g.counterCrypt(out, ciphertext, &counter) - - return ret, nil -} diff --git a/sm4/sm4_gcm_amd64.go b/sm4/sm4_gcm_asm.go similarity index 95% rename from sm4/sm4_gcm_amd64.go rename to sm4/sm4_gcm_asm.go index 0704989..4a2b12a 100644 --- a/sm4/sm4_gcm_amd64.go +++ b/sm4/sm4_gcm_asm.go @@ -1,5 +1,5 @@ -//go:build amd64 -// +build amd64 +//go:build amd64 || arm64 +// +build amd64 arm64 package sm4 diff --git a/sm4/sm4_gcm_test.go b/sm4/sm4_gcm_test.go index 75acca9..7014b7b 100644 --- a/sm4/sm4_gcm_test.go +++ b/sm4/sm4_gcm_test.go @@ -4,6 +4,7 @@ package sm4 import ( + "encoding/hex" "fmt" "testing" ) @@ -141,3 +142,117 @@ func TestBothDataPlaintext(t *testing.T) { } fmt.Println() } + +func createGcm() *gcmAsm { + key := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10} + c := sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, 4, 64} + expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0]) + c1 := &sm4CipherGCM{c} + g := &gcmAsm{} + g.cipher = &c1.sm4CipherAsm + g.tagSize = 16 + gcmSm4Init(&g.bytesProductTable, g.cipher.enc) + return g +} + +var sm4GCMTests = []struct { + plaintext string +}{ + { // case 0: < 16 + "abcdefg", + }, + { // case 1: = 16 + "abcdefgabcdefghg", + }, + { // case 2: > 16 , < 64 + "abcdefgabcdefghgabcdefgabcdefghgaaa", + }, + { // case 3: = 64 + "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghg", + }, + { // case 4: > 64, < 128 + "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgaaa", + }, + { // case 5: = 128 + "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghg", + }, + { // case 6: 227 > 128, < 256, 128 + 64 + 35 + "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgaaa", + }, + { // case 7: = 256 + "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghg", + }, + { // case 8: > 256, = 355 + "abcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgabcdefgabcdefghgaaa", + }, +} + +func initCounter(i byte, counter *[16]byte) { + copy(counter[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}) + counter[gcmBlockSize-1] = i +} + +func resetTag(tag *[16]byte) { + for j := 0; j < 16; j++ { + tag[j] = 0 + } +} + +func TestGcmSm4Enc(t *testing.T) { + var counter1, counter2 [16]byte + gcm := createGcm() + var tagOut1, tagOut2 [gcmTagSize]byte + + for i, test := range sm4GCMTests { + initCounter(2, &counter1) + initCounter(1, &counter2) + + gcmSm4Data(&gcm.bytesProductTable, []byte("emmansun"), &tagOut1) + out1 := make([]byte, len(test.plaintext)+gcm.tagSize) + gcm.counterCrypt(out1, []byte(test.plaintext), &counter1) + gcmSm4Data(&gcm.bytesProductTable, out1[:len(test.plaintext)], &tagOut1) + + out2 := make([]byte, len(test.plaintext)+gcm.tagSize) + gcmSm4Data(&gcm.bytesProductTable, []byte("emmansun"), &tagOut2) + gcmSm4Enc(&gcm.bytesProductTable, out2, []byte(test.plaintext), &counter2, &tagOut2, gcm.cipher.enc) + if hex.EncodeToString(out1) != hex.EncodeToString(out2) { + t.Errorf("#%d: out expected %s, got %s", i, hex.EncodeToString(out1), hex.EncodeToString(out2)) + } + if hex.EncodeToString(tagOut1[:]) != hex.EncodeToString(tagOut2[:]) { + t.Errorf("#%d: tag expected %s, got %s", i, hex.EncodeToString(tagOut1[:]), hex.EncodeToString(tagOut2[:])) + } + resetTag(&tagOut1) + resetTag(&tagOut2) + } +} + +func TestGcmSm4Dec(t *testing.T) { + var counter1, counter2 [16]byte + gcm := createGcm() + var tagOut1, tagOut2 [gcmTagSize]byte + + for i, test := range sm4GCMTests { + initCounter(2, &counter1) + initCounter(1, &counter2) + + gcmSm4Data(&gcm.bytesProductTable, []byte("emmansun"), &tagOut1) + out1 := make([]byte, len(test.plaintext)+gcm.tagSize) + gcm.counterCrypt(out1, []byte(test.plaintext), &counter1) + gcmSm4Data(&gcm.bytesProductTable, out1[:len(test.plaintext)], &tagOut1) + + out1 = out1[:len(test.plaintext)] + + out2 := make([]byte, len(test.plaintext)+gcm.tagSize) + gcmSm4Data(&gcm.bytesProductTable, []byte("emmansun"), &tagOut2) + gcmSm4Dec(&gcm.bytesProductTable, out2, out1, &counter2, &tagOut2, gcm.cipher.enc) + + if hex.EncodeToString([]byte(test.plaintext)) != hex.EncodeToString(out2[:len(test.plaintext)]) { + t.Errorf("#%d: out expected %s, got %s", i, hex.EncodeToString([]byte(test.plaintext)), hex.EncodeToString(out2[:len(test.plaintext)])) + } + if hex.EncodeToString(tagOut1[:]) != hex.EncodeToString(tagOut2[:]) { + t.Errorf("#%d: tag expected %s, got %s", i, hex.EncodeToString(tagOut1[:]), hex.EncodeToString(tagOut2[:])) + } + resetTag(&tagOut1) + resetTag(&tagOut2) + } +}