From f9bd2f002a1e30bcd1d76a94ad34f32ac1b336f3 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Tue, 8 Aug 2023 17:26:08 +0800 Subject: [PATCH] cipher: xts supports GB/T 17964-2021 --- cipher/xts.go | 138 ++++++++++++++++++++++++++++------------- cipher/xts_sm4_test.go | 50 ++++++++++++++- cipher/xts_test.go | 8 +-- 3 files changed, 148 insertions(+), 48 deletions(-) diff --git a/cipher/xts.go b/cipher/xts.go index c4a60e7..871972b 100644 --- a/cipher/xts.go +++ b/cipher/xts.go @@ -4,7 +4,6 @@ import ( _cipher "crypto/cipher" "encoding/binary" "errors" - "sync" "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/subtle" @@ -29,43 +28,66 @@ type XTSBlockMode interface { // src must be a multiple of the block size. Dst and src must overlap // entirely or not at all. // - Encrypt(dst, src []byte, sectorNum uint64) + Encrypt(dst, src []byte, tweak *[blockSize]byte) // Decrypt decrypts a number of blocks. The length of // src must be a multiple of the block size. Dst and src must overlap // entirely or not at all. // - Decrypt(dst, src []byte, sectorNum uint64) + Decrypt(dst, src []byte, tweak *[blockSize]byte) + + // Encrypt encrypts or decrypts a number of blocks. The length of + // src must be a multiple of the block size. Dst and src must overlap + // entirely or not at all. + // + EncryptSector(dst, src []byte, sectorNum uint64) + + // Decrypt decrypts a number of blocks. The length of + // src must be a multiple of the block size. Dst and src must overlap + // entirely or not at all. + // + DecryptSector(dst, src []byte, sectorNum uint64) } // Cipher contains an expanded key structure. It is safe for concurrent use if // the underlying block cipher is safe for concurrent use. type xts struct { k1, k2 _cipher.Block + isGB bool // if true, follows GB/T 17964-2021 } // blockSize is the block size that the underlying cipher must have. XTS is // only defined for 16-byte ciphers. const blockSize = 16 -var tweakPool = sync.Pool{ - New: func() interface{} { - return new([blockSize]byte) - }, +// NewGBXTS creates a Cipher given a function for creating the underlying +// block cipher (which must have a block size of 16 bytes). The key must be +// twice the length of the underlying cipher's key. +// It follows GB/T 17964-2021. +func NewGBXTS(cipherFunc CipherCreator, key []byte) (XTSBlockMode, error) { + return newXTS(cipherFunc, key, true) } // NewXTS creates a Cipher given a function for creating the underlying // block cipher (which must have a block size of 16 bytes). The key must be // twice the length of the underlying cipher's key. func NewXTS(cipherFunc CipherCreator, key []byte) (XTSBlockMode, error) { + return newXTS(cipherFunc, key, false) +} + +func newXTS(cipherFunc CipherCreator, key []byte, isGB bool) (*xts, error) { k1, err := cipherFunc(key[:len(key)/2]) if err != nil { return nil, err } k2, err := cipherFunc(key[len(key)/2:]) + if err != nil { + return nil, err + } c := &xts{ k1, k2, + isGB, } if c.k1.BlockSize() != blockSize { @@ -79,10 +101,20 @@ func (c *xts) BlockSize() int { return blockSize } +func (c *xts) fillTweak(tweak *[blockSize]byte, sectorNum uint64) { + for i := range tweak { + tweak[i] = 0 + } + binary.LittleEndian.PutUint64(tweak[:8], sectorNum) +} + // Encrypt encrypts a sector of plaintext and puts the result into ciphertext. // Plaintext and ciphertext must overlap entirely or not at all. // Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes. -func (c *xts) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { +func (c *xts) Encrypt(ciphertext, plaintext []byte, tweak *[blockSize]byte) { + if tweak == nil { + panic("xts: invalid tweak") + } if len(ciphertext) < len(plaintext) { panic("xts: ciphertext is smaller than plaintext") } @@ -93,12 +125,6 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { panic("xts: invalid buffer overlap") } - tweak := tweakPool.Get().(*[blockSize]byte) - - for i := range tweak { - tweak[i] = 0 - } - binary.LittleEndian.PutUint64(tweak[:8], sectorNum) c.k2.Encrypt(tweak[:], tweak[:]) lastCiphertext := ciphertext @@ -110,7 +136,7 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { for len(plaintext) >= batchSize { for i := 0; i < concCipher.Concurrency(); i++ { copy(tweaks[blockSize*i:], tweak[:]) - mul2(tweak) + mul2(tweak, c.isGB) } subtle.XORBytes(ciphertext, plaintext, tweaks) concCipher.EncryptBlocks(ciphertext, ciphertext) @@ -127,7 +153,7 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { plaintext = plaintext[blockSize:] lastCiphertext = ciphertext ciphertext = ciphertext[blockSize:] - mul2(tweak) + mul2(tweak, c.isGB) } // is there a final partial block to handle? if remain := len(plaintext); remain > 0 { @@ -145,13 +171,24 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { //Merge the tweak into the output block subtle.XORBytes(lastCiphertext, x[:], tweak[:]) } - tweakPool.Put(tweak) +} + +// Encrypt encrypts a sector of plaintext and puts the result into ciphertext. +// Plaintext and ciphertext must overlap entirely or not at all. +// Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes. +func (c *xts) EncryptSector(ciphertext, plaintext []byte, sectorNum uint64) { + var tweak [blockSize]byte + c.fillTweak(&tweak, sectorNum) + c.Encrypt(ciphertext, plaintext, &tweak) } // Decrypt decrypts a sector of ciphertext and puts the result into plaintext. // Plaintext and ciphertext must overlap entirely or not at all. // Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes. -func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { +func (c *xts) Decrypt(plaintext, ciphertext []byte, tweak *[blockSize]byte) { + if tweak == nil { + panic("xts: invalid tweak") + } if len(plaintext) < len(ciphertext) { panic("xts: plaintext is smaller than ciphertext") } @@ -162,12 +199,6 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { panic("xts: invalid buffer overlap") } - tweak := tweakPool.Get().(*[blockSize]byte) - for i := range tweak { - tweak[i] = 0 - } - binary.LittleEndian.PutUint64(tweak[:8], sectorNum) - c.k2.Encrypt(tweak[:], tweak[:]) if concCipher, ok := c.k1.(concurrentBlocks); ok { @@ -177,7 +208,7 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { for len(ciphertext) >= batchSize { for i := 0; i < concCipher.Concurrency(); i++ { copy(tweaks[blockSize*i:], tweak[:]) - mul2(tweak) + mul2(tweak, c.isGB) } subtle.XORBytes(plaintext, ciphertext, tweaks) concCipher.DecryptBlocks(plaintext, plaintext) @@ -194,7 +225,7 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { plaintext = plaintext[blockSize:] ciphertext = ciphertext[blockSize:] - mul2(tweak) + mul2(tweak, c.isGB) } if remain := len(ciphertext); remain >= blockSize { @@ -202,7 +233,7 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { if remain > blockSize { var tt [blockSize]byte copy(tt[:], tweak[:]) - mul2(&tt) + mul2(&tt, c.isGB) subtle.XORBytes(x[:], ciphertext, tt[:]) c.k1.Decrypt(x[:], x[:]) subtle.XORBytes(plaintext, x[:], tt[:]) @@ -224,26 +255,49 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { c.k1.Decrypt(x[:], x[:]) subtle.XORBytes(plaintext, x[:], tweak[:]) } +} - tweakPool.Put(tweak) +// Decrypt decrypts a sector of ciphertext and puts the result into plaintext. +// Plaintext and ciphertext must overlap entirely or not at all. +// Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes. +func (c *xts) DecryptSector(plaintext, ciphertext []byte, sectorNum uint64) { + var tweak [blockSize]byte + c.fillTweak(&tweak, sectorNum) + c.Decrypt(plaintext, ciphertext, &tweak) } // mul2 multiplies tweak by 2 in GF(2¹²⁸) with an irreducible polynomial of // x¹²⁸ + x⁷ + x² + x + 1. -func mul2(tweak *[blockSize]byte) { +func mul2(tweak *[blockSize]byte, isGB bool) { var carryIn byte - for j := range tweak { - carryOut := tweak[j] >> 7 - tweak[j] = (tweak[j] << 1) + carryIn - carryIn = carryOut - } - if carryIn != 0 { - // If we have a carry bit then we need to subtract a multiple - // of the irreducible polynomial (x¹²⁸ + x⁷ + x² + x + 1). - // By dropping the carry bit, we're subtracting the x^128 term - // so all that remains is to subtract x⁷ + x² + x + 1. - // Subtraction (and addition) in this representation is just - // XOR. - tweak[0] ^= GF128_FDBK // 1<<7 | 1<<2 | 1<<1 | 1 + if !isGB { + // tweak[0] represents the coefficients of {x^7, x^6, ..., x^0} + // tweak[15] represents the coefficients of {x^127, x^126, ..., x^120} + for j := range tweak { + carryOut := tweak[j] >> 7 + tweak[j] = (tweak[j] << 1) + carryIn + carryIn = carryOut + } + if carryIn != 0 { + // If we have a carry bit then we need to subtract a multiple + // of the irreducible polynomial (x¹²⁸ + x⁷ + x² + x + 1). + // By dropping the carry bit, we're subtracting the x^128 term + // so all that remains is to subtract x⁷ + x² + x + 1. + // Subtraction (and addition) in this representation is just + // XOR. + tweak[0] ^= GF128_FDBK // 1<<7 | 1<<2 | 1<<1 | 1 + } + } else { + // GB/T 17964-2021, because of the bit-ordering, doubling is actually a right shift. + // tweak[0] represents the coefficients of {x^0, x^1, ..., x^7} + // tweak[15] represents the coefficients of {x^120, x^121, ..., x^127} + for j := range tweak { + carryOut := (tweak[j] << 7) & 0x80 + tweak[j] = (tweak[j] >> 1) + carryIn + carryIn = carryOut + } + if carryIn != 0 { + tweak[0] ^= 0xE1 // 1<<7 | 1<<6 | 1<<5 | 1 + } } } diff --git a/cipher/xts_sm4_test.go b/cipher/xts_sm4_test.go index aa6f027..bbf66aa 100644 --- a/cipher/xts_sm4_test.go +++ b/cipher/xts_sm4_test.go @@ -77,7 +77,7 @@ func TestXTS(t *testing.T) { plaintext := fromHex(test.plaintext) ciphertext := make([]byte, len(plaintext)) - c.Encrypt(ciphertext, plaintext, test.sector) + c.EncryptSector(ciphertext, plaintext, test.sector) expectedCiphertext := fromHex(test.ciphertext) if !bytes.Equal(ciphertext, expectedCiphertext) { t.Errorf("#%d: encrypted failed, got: %x, want: %x", i, ciphertext, expectedCiphertext) @@ -85,7 +85,53 @@ func TestXTS(t *testing.T) { } decrypted := make([]byte, len(ciphertext)) - c.Decrypt(decrypted, ciphertext, test.sector) + c.DecryptSector(decrypted, ciphertext, test.sector) + if !bytes.Equal(decrypted, plaintext) { + t.Errorf("#%d: decryption failed, got: %x, want: %x", i, decrypted, plaintext) + } + } +} + +// Test data is from GB/T 17964-2021 B.7 +var xtsGBTestVectors = []struct { + key string + tweak string + plaintext string + ciphertext string +}{ + { + "2B7E151628AED2A6ABF7158809CF4F3C000102030405060708090A0B0C0D0E0F", + "F0F1F2F3F4F5F6F7F8F9FAFBFCFDFEFF", + "6BC1BEE22E409F96E93D7E117393172AAE2D8A571E03AC9C9EB76FAC45AF8E5130C81C46A35CE411E5FBC1191A0A52EFF69F2445DF4F9B17", + "E9538251C71D7B80BBE4483FEF497BD12C5C581BD6242FC51E08964FB4F60FDB0BA42F63499279213D318D2C11F6886E903BE7F93A1B3479", + }, +} + +func TestXTS_GB(t *testing.T) { + for i, test := range xtsGBTestVectors { + c, err := cipher.NewGBXTS(sm4.NewCipher, fromHex(test.key)) + if err != nil { + t.Errorf("#%d: failed to create cipher: %s", i, err) + continue + } + plaintext := fromHex(test.plaintext) + ciphertext := make([]byte, len(plaintext)) + var tweak1 [16]byte + var tweak2 [16]byte + + tweak := fromHex(test.tweak) + copy(tweak1[:], tweak) + copy(tweak2[:], tweak) + + c.Encrypt(ciphertext, plaintext, &tweak1) + expectedCiphertext := fromHex(test.ciphertext) + if !bytes.Equal(ciphertext, expectedCiphertext) { + t.Errorf("#%d: encrypted failed, got: %x, want: %x", i, ciphertext, expectedCiphertext) + continue + } + + decrypted := make([]byte, len(ciphertext)) + c.Decrypt(decrypted, ciphertext, &tweak2) if !bytes.Equal(decrypted, plaintext) { t.Errorf("#%d: decryption failed, got: %x, want: %x", i, decrypted, plaintext) } diff --git a/cipher/xts_test.go b/cipher/xts_test.go index 170ef32..3cedf0d 100644 --- a/cipher/xts_test.go +++ b/cipher/xts_test.go @@ -74,7 +74,7 @@ func TestXTSWithAES(t *testing.T) { plaintext := fromHex(test.plaintext) ciphertext := make([]byte, len(plaintext)) - c.Encrypt(ciphertext, plaintext, test.sector) + c.EncryptSector(ciphertext, plaintext, test.sector) expectedCiphertext := fromHex(test.ciphertext) if !bytes.Equal(ciphertext, expectedCiphertext) { t.Errorf("#%d: encrypted failed, got: %x, want: %x", i, ciphertext, expectedCiphertext) @@ -82,7 +82,7 @@ func TestXTSWithAES(t *testing.T) { } decrypted := make([]byte, len(ciphertext)) - c.Decrypt(decrypted, ciphertext, test.sector) + c.DecryptSector(decrypted, ciphertext, test.sector) if !bytes.Equal(decrypted, plaintext) { t.Errorf("#%d: decryption failed, got: %x, want: %x", i, decrypted, plaintext) } @@ -99,8 +99,8 @@ func TestShorterCiphertext(t *testing.T) { encrypted := make([]byte, 48) decrypted := make([]byte, 48) - c.Encrypt(encrypted, plaintext, 0) - c.Decrypt(decrypted, encrypted[:len(plaintext)], 0) + c.EncryptSector(encrypted, plaintext, 0) + c.DecryptSector(decrypted, encrypted[:len(plaintext)], 0) if !bytes.Equal(plaintext, decrypted[:len(plaintext)]) { t.Errorf("En/Decryption is not inverse")