diff --git a/cipher/stream.go b/cipher/stream.go new file mode 100644 index 0000000..719c9e1 --- /dev/null +++ b/cipher/stream.go @@ -0,0 +1,13 @@ +// Copyright 2024 Sun Yimin. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package cipher + +import "crypto/cipher" + +type SeekableStream interface { + cipher.Stream + // XORKeyStreamAt XORs the given src with the key stream at the given offset and writes the result to dst. + XORKeyStreamAt(dst, src []byte, offset uint64) +} diff --git a/zuc/eea.go b/zuc/eea.go index ea01b58..d679212 100644 --- a/zuc/eea.go +++ b/zuc/eea.go @@ -1,8 +1,7 @@ package zuc import ( - "crypto/cipher" - + "github.com/emmansun/gmsm/cipher" "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/byteorder" "github.com/emmansun/gmsm/internal/subtle" @@ -12,41 +11,39 @@ const RoundWords = 32 type eea struct { zucState32 - x [4]byte // remaining bytes buffer - xLen int // number of remaining bytes + x [4]byte // remaining bytes buffer + xLen int // number of remaining bytes + initState zucState32 + used uint64 } // NewCipher create a stream cipher based on key and iv aguments. // The key must be 16 bytes long and iv must be 16 bytes long for zuc 128; // or the key must be 32 bytes long and iv must be 23 bytes long for zuc 256; // otherwise, an error will be returned. -func NewCipher(key, iv []byte) (cipher.Stream, error) { +func NewCipher(key, iv []byte) (cipher.SeekableStream, error) { s, err := newZUCState(key, iv) if err != nil { return nil, err } c := new(eea) c.zucState32 = *s + c.initState = *s + c.used = 0 return c, nil } // NewEEACipher create a stream cipher based on key, count, bearer and direction arguments according specification. // The key must be 16 bytes long and iv must be 16 bytes long, otherwise, an error will be returned. -// The count is the 32-bit counter value, the bearer is the 5-bit bearer identity and the direction is the 1-bit +// The count is the 32-bit counter value, the bearer is the 5-bit bearer identity and the direction is the 1-bit // transmission direction flag. -func NewEEACipher(key []byte, count, bearer, direction uint32) (cipher.Stream, error) { +func NewEEACipher(key []byte, count, bearer, direction uint32) (cipher.SeekableStream, error) { iv := make([]byte, 16) byteorder.BEPutUint32(iv, count) copy(iv[8:12], iv[:4]) iv[4] = byte(((bearer << 1) | (direction & 1)) << 2) iv[12] = iv[4] - s, err := newZUCState(key, iv) - if err != nil { - return nil, err - } - c := new(eea) - c.zucState32 = *s - return c, nil + return NewCipher(key, iv) } func genKeyStreamRev32Generic(keyStream []byte, pState *zucState32) { @@ -64,6 +61,7 @@ func (c *eea) XORKeyStream(dst, src []byte) { if alias.InexactOverlap(dst[:len(src)], src) { panic("zuc: invalid buffer overlap") } + used := len(src) if c.xLen > 0 { // handle remaining key bytes n := subtle.XORBytes(dst, src, c.x[:c.xLen]) @@ -72,6 +70,7 @@ func (c *eea) XORKeyStream(dst, src []byte) { src = src[n:] if c.xLen > 0 { copy(c.x[:], c.x[n:c.xLen+n]) + c.used += uint64(used) return } } @@ -94,4 +93,61 @@ func (c *eea) XORKeyStream(dst, src []byte) { copy(c.x[:], keyBytes[n:byteLen]) } } + c.used += uint64(used) +} + +func (c *eea) reset() { + c.zucState32 = c.initState + c.xLen = 0 + c.used = 0 +} + +func (c *eea) XORKeyStreamAt(dst, src []byte, offset uint64) { + if len(dst) < len(src) { + panic("zuc: output smaller than input") + } + if alias.InexactOverlap(dst[:len(src)], src) { + panic("zuc: invalid buffer overlap") + } + if offset < c.used { + c.reset() + } else if offset == c.used { + c.XORKeyStream(dst, src) + return + } + + diff := offset - c.used + if diff <= uint64(c.xLen) { + c.xLen -= int(diff) + c.used += diff + c.XORKeyStream(dst, src) + return + } + + // forward the state to the offset + // this part can be optimized by a little bit + stepLen := uint64(RoundWords * 4) + var keys [RoundWords]uint32 + for ; diff >= uint64(stepLen); diff -= stepLen { + genKeyStream(keys[:], &c.zucState32) + c.used += stepLen + } + + // handle remaining key bytes + if diff > 0 { + limit := (diff + 3) / 4 + remaining := int(diff % 4) + genKeyStream(keys[:limit], &c.zucState32) + c.used += limit * 4 + if remaining > 0 { + var keyBytes [4]byte + c.used -= 4 + c.xLen = 4 - remaining + if c.xLen > 0 { + byteorder.BEPutUint32(keyBytes[:], keys[limit-1]) + copy(c.x[:], keyBytes[remaining:]) + } + } + } + c.XORKeyStream(dst, src) } diff --git a/zuc/eea_test.go b/zuc/eea_test.go index 514abde..fb747e6 100644 --- a/zuc/eea_test.go +++ b/zuc/eea_test.go @@ -1,6 +1,7 @@ package zuc import ( + "bytes" "crypto/cipher" "encoding/hex" "testing" @@ -73,6 +74,53 @@ func TestEEAStream(t *testing.T) { }) } +func TestXORStreamAt(t *testing.T) { + key, err := hex.DecodeString(zucEEATests[0].key) + if err != nil { + t.Error(err) + } + c, err := NewEEACipher(key, zucEEATests[0].count, zucEEATests[0].bearer, zucEEATests[0].direction) + if err != nil { + t.Error(err) + } + src1 := make([]byte, 1000) + dst1 := make([]byte, 1000) + src2 := make([]byte, 1000) + dst2 := make([]byte, 1000) + + c.XORKeyStream(dst1, src1) + for i := 0; i < 65; i++ { + c.XORKeyStreamAt(dst2[i:], src2[i:], uint64(i)) + if !bytes.Equal(dst1[i:], dst2[i:]) { + t.Errorf("At %d, expected=%x, result=%x\n", i, dst1[i:], dst2[i:]) + } + } + + // test used == offset case + c.XORKeyStreamAt(dst2[:16], src2[:16], 0) + c.XORKeyStreamAt(dst2[16:32], src2[16:32], 16) + if !bytes.Equal(dst2[:32], dst1[:32]) { + t.Errorf("expected=%x, result=%x\n", dst1[:32], dst2[:32]) + } + + // test offset - used > 128 bytes case + c.XORKeyStreamAt(dst2[:16], src2[:16], 0) + offset := 700 + c.XORKeyStreamAt(dst2[offset:], src2[offset:], uint64(offset)) + if !bytes.Equal(dst2[offset:], dst1[offset:]) { + t.Errorf("expected=%x, result=%x\n", dst1[offset:], dst2[offset:]) + } + + // XORKeyStreamAt with XORKeyStream + c.XORKeyStreamAt(dst2[:16], src2[:16], 0) + c.XORKeyStream(dst2[16:32], src2[16:32]) + c.XORKeyStreamAt(dst2[32:64], src2[32:64], 32) + c.XORKeyStream(dst2[64:128], src2[64:128]) + if !bytes.Equal(dst2[:128], dst1[:128]) { + t.Errorf("expected=%x, result=%x\n", dst1[:128], dst2[:128]) + } +} + func benchmarkStream(b *testing.B, buf []byte) { b.SetBytes(int64(len(buf)))