MAGIC - refactor

This commit is contained in:
emmansun 2021-02-15 10:36:28 +08:00
parent e3b5c05ec0
commit a1cb0a2616
3 changed files with 102 additions and 37 deletions

View File

@ -47,6 +47,35 @@ type ecdsaSignature struct {
R, S *big.Int
}
type pointMarshalMode byte
const (
//MarshalUncompressed uncompressed mashal mode
MarshalUncompressed pointMarshalMode = iota
//MarshalCompressed compressed mashal mode
MarshalCompressed
//MarshalMixed mixed mashal mode
MarshalMixed
)
// EncrypterOpts encryption options
type EncrypterOpts struct {
PointMarshalMode pointMarshalMode
}
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
switch mode {
case MarshalCompressed:
return point2CompressedBytes(curve, x, y)
case MarshalMixed:
return point2MixedBytes(curve, x, y)
default:
return point2UncompressedBytes(curve, x, y)
}
}
var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed}
// Sign signs digest with priv, reading randomness from rand. The opts argument
// is not currently used but, in keeping with the crypto.Signer interface,
// should be the hash function used to digest the message.
@ -73,6 +102,12 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
return asn1.Marshal(ecdsaSignature{r, s})
}
// Decrypt decrypts msg. The opts argument should be appropriate for
// the primitive used.
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
return Decrypt(priv, msg)
}
var (
one = new(big.Int).SetInt64(1)
initonce sync.Once
@ -106,17 +141,17 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error)
///////////////////////////////////////////////////////////////////////////////////
func kdf(z []byte, len int) ([]byte, bool) {
limit := (len + sm3.Size - 1) >> sm3.SizeBitSize
sm3Hasher := sm3.New()
md := sm3.New()
var countBytes [4]byte
var ct uint32 = 1
k := make([]byte, len+sm3.Size-1)
for i := 0; i < limit; i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
sm3Hasher.Write(z)
sm3Hasher.Write(countBytes[:])
copy(k[i*sm3.Size:], sm3Hasher.Sum(nil))
md.Write(z)
md.Write(countBytes[:])
copy(k[i*sm3.Size:], md.Sum(nil))
ct++
sm3Hasher.Reset()
md.Reset()
}
for i := 0; i < len; i++ {
if k[i] != 0 {
@ -127,17 +162,20 @@ func kdf(z []byte, len int) ([]byte, bool) {
}
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
hasher := sm3.New()
hasher.Write(toBytes(curve, x2))
hasher.Write(msg)
hasher.Write(toBytes(curve, y2))
return hasher.Sum(nil)
md := sm3.New()
md.Write(toBytes(curve, x2))
md.Write(msg)
md.Write(toBytes(curve, y2))
return md.Sum(nil)
}
// Encrypt sm2 encrypt implementation
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) {
curve := pub.Curve
msgLen := len(msg)
if opts == nil {
opts = &defaultEncrypterOpts
}
for {
//A1, generate random k
k, err := randFieldElement(curve, random)
@ -147,7 +185,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error)
//A2, calculate C1 = k * G
x1, y1 := curve.ScalarBaseMult(k.Bytes())
c1 := point2UncompressedBytes(curve, x1, y1)
c1 := opts.PointMarshalMode.mashal(curve, x1, y1)
//A3, skipped
//A4, calculate k * P (point of Public Key)
@ -362,26 +400,26 @@ func Sign(rand io.Reader, priv *ecdsa.PrivateKey, hash []byte) (r, s *big.Int, e
var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
// CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
// calculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
func calculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
uidLen := len(uid)
if uidLen >= 0x2000 {
return nil, errors.New("the uid is too long")
}
entla := uint16(uidLen) << 3
hasher := sm3.New()
hasher.Write([]byte{byte(entla >> 8), byte(entla)})
md := sm3.New()
md.Write([]byte{byte(entla >> 8), byte(entla)})
if uidLen > 0 {
hasher.Write(uid)
md.Write(uid)
}
a := new(big.Int).Sub(pub.Params().P, big.NewInt(3))
hasher.Write(toBytes(pub.Curve, a))
hasher.Write(toBytes(pub.Curve, pub.Params().B))
hasher.Write(toBytes(pub.Curve, pub.Params().Gx))
hasher.Write(toBytes(pub.Curve, pub.Params().Gy))
hasher.Write(toBytes(pub.Curve, pub.X))
hasher.Write(toBytes(pub.Curve, pub.Y))
return hasher.Sum(nil), nil
md.Write(toBytes(pub.Curve, a))
md.Write(toBytes(pub.Curve, pub.Params().B))
md.Write(toBytes(pub.Curve, pub.Params().Gx))
md.Write(toBytes(pub.Curve, pub.Params().Gy))
md.Write(toBytes(pub.Curve, pub.X))
md.Write(toBytes(pub.Curve, pub.Y))
return md.Sum(nil), nil
}
// SignWithSM2 follow sm2 dsa standards for hash part
@ -389,15 +427,15 @@ func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s
if len(uid) == 0 {
uid = defaultUID
}
za, err := CalculateZA(&priv.PublicKey, uid)
za, err := calculateZA(&priv.PublicKey, uid)
if err != nil {
return nil, nil, err
}
hasher := sm3.New()
hasher.Write(za)
hasher.Write(msg)
md := sm3.New()
md.Write(za)
md.Write(msg)
return Sign(rand, priv, hasher.Sum(nil))
return Sign(rand, priv, md.Sum(nil))
}
// Verify verifies the signature in r, s of hash using the public key, pub. Its
@ -442,14 +480,14 @@ func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool {
if len(uid) == 0 {
uid = defaultUID
}
za, err := CalculateZA(pub, uid)
za, err := calculateZA(pub, uid)
if err != nil {
return false
}
hasher := sm3.New()
hasher.Write(za)
hasher.Write(msg)
return Verify(pub, hasher.Sum(nil), r, s)
md := sm3.New()
md.Write(za)
md.Write(msg)
return Verify(pub, md.Sum(nil), r, s)
}
type zr struct {

View File

@ -43,7 +43,7 @@ func Test_encryptDecrypt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
@ -54,6 +54,33 @@ func Test_encryptDecrypt(t *testing.T) {
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
// compress mode
encrypterOpts := EncrypterOpts{MarshalCompressed}
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err = Decrypt(priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
// mixed mode
encrypterOpts = EncrypterOpts{MarshalMixed}
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err = Decrypt(priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
@ -87,7 +114,7 @@ func Test_signVerify(t *testing.T) {
func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext string) {
for i := 0; i < b.N; i++ {
priv, _ := ecdsa.GenerateKey(curve, rand.Reader)
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext))
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil)
}
}

View File

@ -118,7 +118,7 @@ func TestParsePKIXPublicKey(t *testing.T) {
t.Fatal(err)
}
pub1 := pub.(*ecdsa.PublicKey)
encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile"))
encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile"), nil)
if err != nil {
t.Fatal(err)
}