sm4: reduce allocations

This commit is contained in:
Sun Yimin 2024-03-27 08:38:25 +08:00 committed by GitHub
parent 178241aa0f
commit e4909bed2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 36 additions and 54 deletions

View File

@ -7,9 +7,8 @@ import (
)
// Encrypt one block from src into dst, using the expanded key xk.
func encryptBlockGo(xk []uint32, dst, src []byte) {
func encryptBlockGo(xk *[rounds]uint32, dst, src []byte) {
_ = src[15] // early bounds check
_ = xk[31] // bounds check elimination hint
var b0, b1, b2, b3 uint32
b0 = binary.BigEndian.Uint32(src[0:4])
@ -68,10 +67,8 @@ func encryptBlockGo(xk []uint32, dst, src []byte) {
}
// Key expansion algorithm.
func expandKeyGo(key []byte, enc, dec []uint32) {
func expandKeyGo(key []byte, enc, dec *[rounds]uint32) {
// Encryption key setup.
enc = enc[:rounds]
dec = dec[:rounds]
key = key[:KeySize]
var b0, b1, b2, b3 uint32
b0 = binary.BigEndian.Uint32(key[:4]) ^ fk[0]

View File

@ -15,8 +15,8 @@ const rounds = 32
// A cipher is an instance of SM4 encryption using a particular key.
type sm4Cipher struct {
enc []uint32
dec []uint32
enc [rounds]uint32
dec [rounds]uint32
}
// NewCipher creates and returns a new cipher.Block.
@ -35,9 +35,9 @@ func NewCipher(key []byte) (cipher.Block, error) {
// newCipherGeneric creates and returns a new cipher.Block
// implemented in pure Go.
func newCipherGeneric(key []byte) (cipher.Block, error) {
c := sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}
expandKeyGo(key, c.enc, c.dec)
return &c, nil
c := &sm4Cipher{}
expandKeyGo(key, &c.enc, &c.dec)
return c, nil
}
func (c *sm4Cipher) BlockSize() int { return BlockSize }
@ -52,7 +52,7 @@ func (c *sm4Cipher) Encrypt(dst, src []byte) {
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("sm4: invalid buffer overlap")
}
encryptBlockGo(c.enc, dst, src)
encryptBlockGo(&c.enc, dst, src)
}
func (c *sm4Cipher) Decrypt(dst, src []byte) {
@ -65,5 +65,5 @@ func (c *sm4Cipher) Decrypt(dst, src []byte) {
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("sm4: invalid buffer overlap")
}
encryptBlockGo(c.dec, dst, src)
encryptBlockGo(&c.dec, dst, src)
}

View File

@ -51,12 +51,12 @@ func newCipher(key []byte) (cipher.Block, error) {
if useAVX2 {
blocks = 8
}
c := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize}
c := &sm4CipherGCM{sm4CipherAsm{sm4Cipher{}, blocks, blocks * BlockSize}}
expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], INST_AES)
if supportsGFMUL {
return &sm4CipherGCM{c}, nil
return c, nil
}
return c, nil
return &c.sm4CipherAsm, nil
}
func (c *sm4CipherAsm) Concurrency() int { return c.batchBlocks }
@ -74,7 +74,7 @@ func (c *sm4CipherAsm) Encrypt(dst, src []byte) {
if useAESNI4SingleBlock {
encryptBlockAsm(&c.enc[0], &dst[0], &src[0], INST_AES)
} else {
encryptBlockGo(c.enc, dst, src)
encryptBlockGo(&c.enc, dst, src)
}
}
@ -91,7 +91,7 @@ func (c *sm4CipherAsm) Decrypt(dst, src []byte) {
if useAESNI4SingleBlock {
encryptBlockAsm(&c.dec[0], &dst[0], &src[0], INST_AES)
} else {
encryptBlockGo(c.dec, dst, src)
encryptBlockGo(&c.dec, dst, src)
}
}
@ -129,6 +129,6 @@ func expandKey(key []byte, enc, dec []uint32) {
} else if supportsAES {
expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], INST_AES)
} else {
expandKeyGo(key, enc, dec)
expandKeyGo(key, (*[rounds]uint32)(enc), (*[rounds]uint32)(dec))
}
}

View File

@ -13,8 +13,8 @@ import (
func TestExpandKey(t *testing.T) {
key := make([]byte, 16)
encRes1 := make([]uint32, 32)
decRes1 := make([]uint32, 32)
var encRes1 [rounds]uint32
var decRes1 [rounds]uint32
encRes2 := make([]uint32, 32)
decRes2 := make([]uint32, 32)
var timeout *time.Timer
@ -32,13 +32,13 @@ func TestExpandKey(t *testing.T) {
default:
}
io.ReadFull(rand.Reader, key)
expandKeyGo(key, encRes1, decRes1)
expandKeyGo(key, &encRes1, &decRes1)
expandKey(key, encRes2, decRes2)
if !reflect.DeepEqual(encRes1, encRes2) {
t.Errorf("expected=%x, result=%x\n", encRes1, encRes2)
if !reflect.DeepEqual(encRes1[:], encRes2) {
t.Errorf("expected=%x, result=%x\n", encRes1[:], encRes2)
}
if !reflect.DeepEqual(decRes1, decRes2) {
t.Errorf("expected=%x, result=%x\n", encRes1, encRes2)
if !reflect.DeepEqual(decRes1[:], decRes2) {
t.Errorf("expected=%x, result=%x\n", decRes1[:], decRes2)
}
}
}

View File

@ -25,7 +25,7 @@ func TestWithoutGFMUL(t *testing.T) {
if useAVX2 {
blocks = 8
}
c1 := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize}
c1 := &sm4CipherAsm{sm4Cipher{}, blocks, blocks * BlockSize}
expandKeyAsm(&key[0], &ck[0], &c1.enc[0], &c1.dec[0], INST_AES)
c = c1
}

View File

@ -12,9 +12,3 @@ import "crypto/cipher"
func newCipher(key []byte) (cipher.Block, error) {
return newCipherGeneric(key)
}
// expandKey is used by BenchmarkExpand and should
// call an assembly implementation if one is available.
func expandKey(key []byte, enc, dec []uint32) {
expandKeyGo(key, enc, dec)
}

View File

@ -13,12 +13,12 @@ type sm4CipherNI struct {
}
func newCipherNI(key []byte) (cipher.Block, error) {
c := &sm4CipherNI{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}}
c := &sm4CipherNIGCM{sm4CipherNI{sm4Cipher{}}}
expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], INST_SM4)
if supportsGFMUL {
return &sm4CipherNIGCM{c}, nil
return c, nil
}
return c, nil
return &c.sm4CipherNI, nil
}
func (c *sm4CipherNI) Encrypt(dst, src []byte) {

View File

@ -114,12 +114,3 @@ func BenchmarkDecrypt(b *testing.B) {
c.Decrypt(out, tt.out)
}
}
func BenchmarkExpand(b *testing.B) {
tt := encryptTests[0]
c := &sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}
b.ResetTimer()
for i := 0; i < b.N; i++ {
expandKey(tt.key, c.enc, c.dec)
}
}

View File

@ -13,7 +13,7 @@ import (
// will use the optimised implementation in this file when possible. Instances
// of this type only exist when hasGCMAsm and hasAES return true.
type sm4CipherGCM struct {
*sm4CipherAsm
sm4CipherAsm
}
// Assert that sm4CipherGCM implements the gcmAble interface.
@ -43,10 +43,10 @@ type gcmAsm struct {
// called by crypto/cipher.NewGCM via the gcmAble interface.
func (c *sm4CipherGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) {
g := &gcmAsm{}
g.cipher = c.sm4CipherAsm
g.cipher = &c.sm4CipherAsm
g.nonceSize = nonceSize
g.tagSize = tagSize
gcmSm4Init(&g.bytesProductTable, g.cipher.enc, INST_AES)
gcmSm4Init(&g.bytesProductTable, g.cipher.enc[:], INST_AES)
return g, nil
}
@ -91,7 +91,7 @@ func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte {
}
if len(plaintext) > 0 {
gcmSm4Enc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc)
gcmSm4Enc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc[:])
}
gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data)))
copy(out[len(plaintext):], tagOut[:])
@ -144,7 +144,7 @@ func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
panic("cipher: invalid buffer overlap")
}
if len(ciphertext) > 0 {
gcmSm4Dec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc)
gcmSm4Dec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc[:])
}
gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data)))

View File

@ -19,7 +19,7 @@ func gcmSm4niDec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk
// will use the optimised implementation in this file when possible. Instances
// of this type only exist when hasGCMAsm and hasSM4 return true.
type sm4CipherNIGCM struct {
*sm4CipherNI
sm4CipherNI
}
// Assert that sm4CipherNIGCM implements the gcmAble interface.
@ -44,10 +44,10 @@ func (g *gcmNI) Overhead() int {
// called by crypto/cipher.NewGCM via the gcmAble interface.
func (c *sm4CipherNIGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) {
g := &gcmNI{}
g.cipher = c.sm4CipherNI
g.cipher = &c.sm4CipherNI
g.nonceSize = nonceSize
g.tagSize = tagSize
gcmSm4Init(&g.bytesProductTable, g.cipher.enc, INST_SM4)
gcmSm4Init(&g.bytesProductTable, g.cipher.enc[:], INST_SM4)
return g, nil
}
@ -84,7 +84,7 @@ func (g *gcmNI) Seal(dst, nonce, plaintext, data []byte) []byte {
}
if len(plaintext) > 0 {
gcmSm4niEnc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc)
gcmSm4niEnc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc[:])
}
gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data)))
copy(out[len(plaintext):], tagOut[:])
@ -137,7 +137,7 @@ func (g *gcmNI) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
panic("cipher: invalid buffer overlap")
}
if len(ciphertext) > 0 {
gcmSm4niDec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc)
gcmSm4niDec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc[:])
}
gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data)))