cipher: refactor xts mode #149

This commit is contained in:
Sun Yimin 2023-08-17 12:48:53 +08:00 committed by GitHub
parent 71ab69ef9b
commit 9d6e46cafd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 223 additions and 136 deletions

View File

@ -19,102 +19,170 @@ type concurrentBlocks interface {
DecryptBlocks(dst, src []byte)
}
// A XTSBlockMode represents a block cipher running in a XTS mode
type XTSBlockMode interface {
// BlockSize returns the mode's block size.
BlockSize() int
// Encrypt encrypts or decrypts a number of blocks. The length of
// src must be a multiple of the block size. Dst and src must overlap
// entirely or not at all.
//
Encrypt(dst, src []byte, tweak *[blockSize]byte)
// Decrypt decrypts a number of blocks. The length of
// src must be a multiple of the block size. Dst and src must overlap
// entirely or not at all.
//
Decrypt(dst, src []byte, tweak *[blockSize]byte)
// Encrypt encrypts or decrypts a number of blocks. The length of
// src must be a multiple of the block size. Dst and src must overlap
// entirely or not at all.
//
EncryptSector(dst, src []byte, sectorNum uint64)
// Decrypt decrypts a number of blocks. The length of
// src must be a multiple of the block size. Dst and src must overlap
// entirely or not at all.
//
DecryptSector(dst, src []byte, sectorNum uint64)
}
// Cipher contains an expanded key structure. It is safe for concurrent use if
// the underlying block cipher is safe for concurrent use.
type xts struct {
k1, k2 _cipher.Block
isGB bool // if true, follows GB/T 17964-2021
b _cipher.Block
tweak [blockSize]byte
isGB bool // if true, follows GB/T 17964-2021
}
// blockSize is the block size that the underlying cipher must have. XTS is
// only defined for 16-byte ciphers.
const blockSize = 16
// NewGBXTS creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes). The key must be
// twice the length of the underlying cipher's key.
type xtsEncrypter xts
// xtsEncAble is an interface implemented by ciphers that have a specific
// optimized implementation of XTS encryption, like sm4.
// NewXTSEncrypter will check for this interface and return the specific
// BlockMode if found.
type xtsEncAble interface {
NewXTSEncrypter(encryptedTweak *[blockSize]byte, isGB bool) _cipher.BlockMode
}
// NewXTSEncrypter creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes).
func NewXTSEncrypter(cipherFunc CipherCreator, key, tweakKey, tweak []byte) (_cipher.BlockMode, error) {
return newXTSEncrypter(cipherFunc, key, tweakKey, tweak, false)
}
// NewXTSEncrypterWithSector creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes) with sector number.
func NewXTSEncrypterWithSector(cipherFunc CipherCreator, key, tweakKey []byte, sectorNum uint64) (_cipher.BlockMode, error) {
tweak := make([]byte, blockSize)
binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
return NewXTSEncrypter(cipherFunc, key, tweakKey, tweak)
}
// NewGBXTSEncrypter creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes).
// It follows GB/T 17964-2021.
func NewGBXTS(cipherFunc CipherCreator, key []byte) (XTSBlockMode, error) {
return newXTS(cipherFunc, key, true)
func NewGBXTSEncrypter(cipherFunc CipherCreator, key, tweakKey, tweak []byte) (_cipher.BlockMode, error) {
return newXTSEncrypter(cipherFunc, key, tweakKey, tweak, true)
}
// NewXTS creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes). The key must be
// twice the length of the underlying cipher's key.
func NewXTS(cipherFunc CipherCreator, key []byte) (XTSBlockMode, error) {
return newXTS(cipherFunc, key, false)
// NewGBXTSEncrypterWithSector creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes) with sector number.
// It follows GB/T 17964-2021.
func NewGBXTSEncrypterWithSector(cipherFunc CipherCreator, key, tweakKey []byte, sectorNum uint64) (_cipher.BlockMode, error) {
tweak := make([]byte, blockSize)
binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
return NewGBXTSEncrypter(cipherFunc, key, tweakKey, tweak)
}
func newXTS(cipherFunc CipherCreator, key []byte, isGB bool) (*xts, error) {
k1, err := cipherFunc(key[:len(key)/2])
func newXTSEncrypter(cipherFunc CipherCreator, key, tweakKey, tweak []byte, isGB bool) (_cipher.BlockMode, error) {
if len(tweak) != blockSize {
return nil, errors.New("xts: invalid tweak length")
}
k1, err := cipherFunc(key)
if err != nil {
return nil, err
}
k2, err := cipherFunc(key[len(key)/2:])
if k1.BlockSize() != blockSize {
return nil, errors.New("xts: cipher does not have a block size of 16")
}
k2, err := cipherFunc(tweakKey)
if err != nil {
return nil, err
}
if xtsable, ok := k1.(xtsEncAble); ok {
var encryptedTweak [blockSize]byte
k2.Encrypt(encryptedTweak[:], tweak)
return xtsable.NewXTSEncrypter(&encryptedTweak, isGB), nil
}
c := &xts{
k1,
k2,
isGB,
b: k1,
isGB: isGB,
}
if c.k1.BlockSize() != blockSize {
err = errors.New("xts: cipher does not have a block size of 16")
return nil, err
}
return c, nil
k2.Encrypt(c.tweak[:], tweak)
return (*xtsEncrypter)(c), nil
}
func (c *xts) BlockSize() int {
type xtsDecrypter xts
// xtsDecAble is an interface implemented by ciphers that have a specific
// optimized implementation of XTS encryption, like sm4.
// NewXTSDecrypter will check for this interface and return the specific
// BlockMode if found.
type xtsDecAble interface {
NewXTSDecrypter(encryptedTweak *[blockSize]byte, isGB bool) _cipher.BlockMode
}
// NewXTSDecrypter creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes) for decryption.
func NewXTSDecrypter(cipherFunc CipherCreator, key, tweakKey, tweak []byte) (_cipher.BlockMode, error) {
return newXTSDecrypter(cipherFunc, key, tweakKey, tweak, false)
}
// NewXTSDecrypterWithSector creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes) with sector number for decryption.
func NewXTSDecrypterWithSector(cipherFunc CipherCreator, key, tweakKey []byte, sectorNum uint64) (_cipher.BlockMode, error) {
tweak := make([]byte, blockSize)
binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
return NewXTSDecrypter(cipherFunc, key, tweakKey, tweak)
}
// NewGBXTSDecrypter creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes) for decryption.
// It follows GB/T 17964-2021.
func NewGBXTSDecrypter(cipherFunc CipherCreator, key, tweakKey, tweak []byte) (_cipher.BlockMode, error) {
return newXTSDecrypter(cipherFunc, key, tweakKey, tweak, true)
}
// NewGBXTSDecrypterWithSector creates a Cipher given a function for creating the underlying
// block cipher (which must have a block size of 16 bytes) with sector number for decryption.
// It follows GB/T 17964-2021.
func NewGBXTSDecrypterWithSector(cipherFunc CipherCreator, key, tweakKey []byte, sectorNum uint64) (_cipher.BlockMode, error) {
tweak := make([]byte, blockSize)
binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
return NewGBXTSDecrypter(cipherFunc, key, tweakKey, tweak)
}
func newXTSDecrypter(cipherFunc CipherCreator, key, tweakKey, tweak []byte, isGB bool) (_cipher.BlockMode, error) {
if len(tweak) != blockSize {
return nil, errors.New("xts: invalid tweak length")
}
k1, err := cipherFunc(key)
if err != nil {
return nil, err
}
if k1.BlockSize() != blockSize {
return nil, errors.New("xts: cipher does not have a block size of 16")
}
k2, err := cipherFunc(tweakKey)
if err != nil {
return nil, err
}
if xtsable, ok := k1.(xtsDecAble); ok {
var encryptedTweak [blockSize]byte
k2.Encrypt(encryptedTweak[:], tweak)
return xtsable.NewXTSDecrypter(&encryptedTweak, isGB), nil
}
c := &xts{
b: k1,
isGB: isGB,
}
k2.Encrypt(c.tweak[:], tweak)
return (*xtsDecrypter)(c), nil
}
func (c *xtsEncrypter) BlockSize() int {
return blockSize
}
func (c *xts) fillTweak(tweak *[blockSize]byte, sectorNum uint64) {
for i := range tweak {
tweak[i] = 0
}
binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
}
// Encrypt encrypts a sector of plaintext and puts the result into ciphertext.
// CryptBlocks encrypts a sector of plaintext and puts the result into ciphertext.
// Plaintext and ciphertext must overlap entirely or not at all.
// Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes.
func (c *xts) Encrypt(ciphertext, plaintext []byte, tweak *[blockSize]byte) {
if tweak == nil {
panic("xts: invalid tweak")
}
func (c *xtsEncrypter) CryptBlocks(ciphertext, plaintext []byte) {
if len(ciphertext) < len(plaintext) {
panic("xts: ciphertext is smaller than plaintext")
}
@ -125,18 +193,16 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, tweak *[blockSize]byte) {
panic("xts: invalid buffer overlap")
}
c.k2.Encrypt(tweak[:], tweak[:])
lastCiphertext := ciphertext
if concCipher, ok := c.k1.(concurrentBlocks); ok {
if concCipher, ok := c.b.(concurrentBlocks); ok {
batchSize := concCipher.Concurrency() * blockSize
var tweaks []byte = make([]byte, batchSize)
for len(plaintext) >= batchSize {
for i := 0; i < concCipher.Concurrency(); i++ {
copy(tweaks[blockSize*i:], tweak[:])
mul2(tweak, c.isGB)
copy(tweaks[blockSize*i:], c.tweak[:])
mul2(&c.tweak, c.isGB)
}
subtle.XORBytes(ciphertext, plaintext, tweaks)
concCipher.EncryptBlocks(ciphertext, ciphertext)
@ -147,13 +213,13 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, tweak *[blockSize]byte) {
}
}
for len(plaintext) >= blockSize {
subtle.XORBytes(ciphertext, plaintext, tweak[:])
c.k1.Encrypt(ciphertext, ciphertext)
subtle.XORBytes(ciphertext, ciphertext, tweak[:])
subtle.XORBytes(ciphertext, plaintext, c.tweak[:])
c.b.Encrypt(ciphertext, ciphertext)
subtle.XORBytes(ciphertext, ciphertext, c.tweak[:])
plaintext = plaintext[blockSize:]
lastCiphertext = ciphertext
ciphertext = ciphertext[blockSize:]
mul2(tweak, c.isGB)
mul2(&c.tweak, c.isGB)
}
// is there a final partial block to handle?
if remain := len(plaintext); remain > 0 {
@ -165,30 +231,22 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, tweak *[blockSize]byte) {
//Steal ciphertext to complete the block
copy(x[remain:], lastCiphertext[remain:blockSize])
//Merge the tweak into the input block
subtle.XORBytes(x[:], x[:], tweak[:])
subtle.XORBytes(x[:], x[:], c.tweak[:])
//Encrypt the final block using K1
c.k1.Encrypt(x[:], x[:])
c.b.Encrypt(x[:], x[:])
//Merge the tweak into the output block
subtle.XORBytes(lastCiphertext, x[:], tweak[:])
subtle.XORBytes(lastCiphertext, x[:], c.tweak[:])
}
}
// Encrypt encrypts a sector of plaintext and puts the result into ciphertext.
// Plaintext and ciphertext must overlap entirely or not at all.
// Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes.
func (c *xts) EncryptSector(ciphertext, plaintext []byte, sectorNum uint64) {
var tweak [blockSize]byte
c.fillTweak(&tweak, sectorNum)
c.Encrypt(ciphertext, plaintext, &tweak)
func (c *xtsDecrypter) BlockSize() int {
return blockSize
}
// Decrypt decrypts a sector of ciphertext and puts the result into plaintext.
// CryptBlocks decrypts a sector of ciphertext and puts the result into plaintext.
// Plaintext and ciphertext must overlap entirely or not at all.
// Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes.
func (c *xts) Decrypt(plaintext, ciphertext []byte, tweak *[blockSize]byte) {
if tweak == nil {
panic("xts: invalid tweak")
}
func (c *xtsDecrypter) CryptBlocks(plaintext, ciphertext []byte) {
if len(plaintext) < len(ciphertext) {
panic("xts: plaintext is smaller than ciphertext")
}
@ -199,16 +257,14 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, tweak *[blockSize]byte) {
panic("xts: invalid buffer overlap")
}
c.k2.Encrypt(tweak[:], tweak[:])
if concCipher, ok := c.k1.(concurrentBlocks); ok {
if concCipher, ok := c.b.(concurrentBlocks); ok {
batchSize := concCipher.Concurrency() * blockSize
var tweaks []byte = make([]byte, batchSize)
for len(ciphertext) >= batchSize {
for i := 0; i < concCipher.Concurrency(); i++ {
copy(tweaks[blockSize*i:], tweak[:])
mul2(tweak, c.isGB)
copy(tweaks[blockSize*i:], c.tweak[:])
mul2(&c.tweak, c.isGB)
}
subtle.XORBytes(plaintext, ciphertext, tweaks)
concCipher.DecryptBlocks(plaintext, plaintext)
@ -219,23 +275,23 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, tweak *[blockSize]byte) {
}
for len(ciphertext) >= 2*blockSize {
subtle.XORBytes(plaintext, ciphertext, tweak[:])
c.k1.Decrypt(plaintext, plaintext)
subtle.XORBytes(plaintext, plaintext, tweak[:])
subtle.XORBytes(plaintext, ciphertext, c.tweak[:])
c.b.Decrypt(plaintext, plaintext)
subtle.XORBytes(plaintext, plaintext, c.tweak[:])
plaintext = plaintext[blockSize:]
ciphertext = ciphertext[blockSize:]
mul2(tweak, c.isGB)
mul2(&c.tweak, c.isGB)
}
if remain := len(ciphertext); remain >= blockSize {
var x [blockSize]byte
if remain > blockSize {
var tt [blockSize]byte
copy(tt[:], tweak[:])
copy(tt[:], c.tweak[:])
mul2(&tt, c.isGB)
subtle.XORBytes(x[:], ciphertext, tt[:])
c.k1.Decrypt(x[:], x[:])
c.b.Decrypt(x[:], x[:])
subtle.XORBytes(plaintext, x[:], tt[:])
//Retrieve the length of the final block
@ -251,21 +307,12 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, tweak *[blockSize]byte) {
//The last block contains exactly 128 bits
copy(x[:], ciphertext)
}
subtle.XORBytes(x[:], x[:], tweak[:])
c.k1.Decrypt(x[:], x[:])
subtle.XORBytes(plaintext, x[:], tweak[:])
subtle.XORBytes(x[:], x[:], c.tweak[:])
c.b.Decrypt(x[:], x[:])
subtle.XORBytes(plaintext, x[:], c.tweak[:])
}
}
// Decrypt decrypts a sector of ciphertext and puts the result into plaintext.
// Plaintext and ciphertext must overlap entirely or not at all.
// Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes.
func (c *xts) DecryptSector(plaintext, ciphertext []byte, sectorNum uint64) {
var tweak [blockSize]byte
c.fillTweak(&tweak, sectorNum)
c.Decrypt(plaintext, ciphertext, &tweak)
}
// mul2 multiplies tweak by 2 in GF(2¹²⁸) with an irreducible polynomial of
// x¹²⁸ + x⁷ + x² + x + 1.
func mul2(tweak *[blockSize]byte, isGB bool) {

View File

@ -69,15 +69,22 @@ func fromHex(s string) []byte {
func TestXTS(t *testing.T) {
for i, test := range xtsTestVectors {
c, err := cipher.NewXTS(sm4.NewCipher, fromHex(test.key))
key := fromHex(test.key)
encrypter, err := cipher.NewXTSEncrypterWithSector(sm4.NewCipher, key[:len(key)/2], key[len(key)/2:], test.sector)
if err != nil {
t.Errorf("#%d: failed to create cipher: %s", i, err)
t.Errorf("#%d: failed to create encrypter: %s", i, err)
continue
}
decrypter, err := cipher.NewXTSDecrypterWithSector(sm4.NewCipher, key[:len(key)/2], key[len(key)/2:], test.sector)
if err != nil {
t.Errorf("#%d: failed to create decrypter: %s", i, err)
continue
}
plaintext := fromHex(test.plaintext)
ciphertext := make([]byte, len(plaintext))
c.EncryptSector(ciphertext, plaintext, test.sector)
encrypter.CryptBlocks(ciphertext, plaintext)
expectedCiphertext := fromHex(test.ciphertext)
if !bytes.Equal(ciphertext, expectedCiphertext) {
t.Errorf("#%d: encrypted failed, got: %x, want: %x", i, ciphertext, expectedCiphertext)
@ -85,7 +92,7 @@ func TestXTS(t *testing.T) {
}
decrypted := make([]byte, len(ciphertext))
c.DecryptSector(decrypted, ciphertext, test.sector)
decrypter.CryptBlocks(decrypted, ciphertext)
if !bytes.Equal(decrypted, plaintext) {
t.Errorf("#%d: decryption failed, got: %x, want: %x", i, decrypted, plaintext)
}
@ -109,21 +116,22 @@ var xtsGBTestVectors = []struct {
func TestXTS_GB(t *testing.T) {
for i, test := range xtsGBTestVectors {
c, err := cipher.NewGBXTS(sm4.NewCipher, fromHex(test.key))
key := fromHex(test.key)
tweak := fromHex(test.tweak)
encrypter, err := cipher.NewGBXTSEncrypter(sm4.NewCipher, key[:len(key)/2], key[len(key)/2:], tweak)
if err != nil {
t.Errorf("#%d: failed to create cipher: %s", i, err)
t.Errorf("#%d: failed to create encrypter: %s", i, err)
continue
}
decrypter, err := cipher.NewGBXTSDecrypter(sm4.NewCipher, key[:len(key)/2], key[len(key)/2:], tweak)
if err != nil {
t.Errorf("#%d: failed to create decrypter: %s", i, err)
continue
}
plaintext := fromHex(test.plaintext)
ciphertext := make([]byte, len(plaintext))
var tweak1 [16]byte
var tweak2 [16]byte
tweak := fromHex(test.tweak)
copy(tweak1[:], tweak)
copy(tweak2[:], tweak)
c.Encrypt(ciphertext, plaintext, &tweak1)
encrypter.CryptBlocks(ciphertext, plaintext)
expectedCiphertext := fromHex(test.ciphertext)
if !bytes.Equal(ciphertext, expectedCiphertext) {
t.Errorf("#%d: encrypted failed, got: %x, want: %x", i, ciphertext, expectedCiphertext)
@ -131,7 +139,7 @@ func TestXTS_GB(t *testing.T) {
}
decrypted := make([]byte, len(ciphertext))
c.Decrypt(decrypted, ciphertext, &tweak2)
decrypter.CryptBlocks(decrypted, ciphertext)
if !bytes.Equal(decrypted, plaintext) {
t.Errorf("#%d: decryption failed, got: %x, want: %x", i, decrypted, plaintext)
}

View File

@ -66,15 +66,22 @@ var xtsAesTestVectors = []struct {
func TestXTSWithAES(t *testing.T) {
for i, test := range xtsAesTestVectors {
c, err := cipher.NewXTS(aes.NewCipher, fromHex(test.key))
key := fromHex(test.key)
encrypter, err := cipher.NewXTSEncrypterWithSector(aes.NewCipher, key[:len(key)/2], key[len(key)/2:], test.sector)
if err != nil {
t.Errorf("#%d: failed to create cipher: %s", i, err)
t.Errorf("#%d: failed to create encrypter: %s", i, err)
continue
}
decrypter, err := cipher.NewXTSDecrypterWithSector(aes.NewCipher, key[:len(key)/2], key[len(key)/2:], test.sector)
if err != nil {
t.Errorf("#%d: failed to create decrypter: %s", i, err)
continue
}
plaintext := fromHex(test.plaintext)
ciphertext := make([]byte, len(plaintext))
c.EncryptSector(ciphertext, plaintext, test.sector)
encrypter.CryptBlocks(ciphertext, plaintext)
expectedCiphertext := fromHex(test.ciphertext)
if !bytes.Equal(ciphertext, expectedCiphertext) {
t.Errorf("#%d: encrypted failed, got: %x, want: %x", i, ciphertext, expectedCiphertext)
@ -82,7 +89,7 @@ func TestXTSWithAES(t *testing.T) {
}
decrypted := make([]byte, len(ciphertext))
c.DecryptSector(decrypted, ciphertext, test.sector)
decrypter.CryptBlocks(decrypted, ciphertext)
if !bytes.Equal(decrypted, plaintext) {
t.Errorf("#%d: decryption failed, got: %x, want: %x", i, decrypted, plaintext)
}
@ -90,17 +97,22 @@ func TestXTSWithAES(t *testing.T) {
}
func TestShorterCiphertext(t *testing.T) {
c, err := cipher.NewXTS(aes.NewCipher, make([]byte, 32))
encrypter, err := cipher.NewXTSEncrypterWithSector(aes.NewCipher, make([]byte, 16), make([]byte, 16), 0)
if err != nil {
t.Fatalf("NewCipher failed: %s", err)
t.Fatalf("NewXTSEncrypterWithSector failed: %s", err)
}
decrypter, err := cipher.NewXTSDecrypterWithSector(aes.NewCipher, make([]byte, 16), make([]byte, 16), 0)
if err != nil {
t.Fatalf("NewXTSDecrypterWithSector failed: %s", err)
}
plaintext := make([]byte, 32)
encrypted := make([]byte, 48)
decrypted := make([]byte, 48)
c.EncryptSector(encrypted, plaintext, 0)
c.DecryptSector(decrypted, encrypted[:len(plaintext)], 0)
encrypter.CryptBlocks(encrypted, plaintext)
decrypter.CryptBlocks(decrypted, encrypted[:len(plaintext)])
if !bytes.Equal(plaintext, decrypted[:len(plaintext)]) {
t.Errorf("En/Decryption is not inverse")

20
cipher/xts_tweak_test.go Normal file
View File

@ -0,0 +1,20 @@
package cipher
import (
"crypto/aes"
"testing"
)
func BenchmarkDoubleTweak(b *testing.B) {
var tweak [16]byte
block, err := aes.NewCipher(make([]byte, 16))
if err != nil {
b.Failed()
}
block.Encrypt(tweak[:], tweak[:])
b.ResetTimer()
for i := 0; i < b.N; i++ {
mul2(&tweak, false)
}
}