mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-27 04:36:19 +08:00
MAGIC - refactor
This commit is contained in:
parent
e3b5c05ec0
commit
a1cb0a2616
106
sm2/sm2.go
106
sm2/sm2.go
@ -47,6 +47,35 @@ type ecdsaSignature struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
|
||||
type pointMarshalMode byte
|
||||
|
||||
const (
|
||||
//MarshalUncompressed uncompressed mashal mode
|
||||
MarshalUncompressed pointMarshalMode = iota
|
||||
//MarshalCompressed compressed mashal mode
|
||||
MarshalCompressed
|
||||
//MarshalMixed mixed mashal mode
|
||||
MarshalMixed
|
||||
)
|
||||
|
||||
// EncrypterOpts encryption options
|
||||
type EncrypterOpts struct {
|
||||
PointMarshalMode pointMarshalMode
|
||||
}
|
||||
|
||||
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||
switch mode {
|
||||
case MarshalCompressed:
|
||||
return point2CompressedBytes(curve, x, y)
|
||||
case MarshalMixed:
|
||||
return point2MixedBytes(curve, x, y)
|
||||
default:
|
||||
return point2UncompressedBytes(curve, x, y)
|
||||
}
|
||||
}
|
||||
|
||||
var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed}
|
||||
|
||||
// Sign signs digest with priv, reading randomness from rand. The opts argument
|
||||
// is not currently used but, in keeping with the crypto.Signer interface,
|
||||
// should be the hash function used to digest the message.
|
||||
@ -73,6 +102,12 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
|
||||
return asn1.Marshal(ecdsaSignature{r, s})
|
||||
}
|
||||
|
||||
// Decrypt decrypts msg. The opts argument should be appropriate for
|
||||
// the primitive used.
|
||||
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
||||
return Decrypt(priv, msg)
|
||||
}
|
||||
|
||||
var (
|
||||
one = new(big.Int).SetInt64(1)
|
||||
initonce sync.Once
|
||||
@ -106,17 +141,17 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error)
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
func kdf(z []byte, len int) ([]byte, bool) {
|
||||
limit := (len + sm3.Size - 1) >> sm3.SizeBitSize
|
||||
sm3Hasher := sm3.New()
|
||||
md := sm3.New()
|
||||
var countBytes [4]byte
|
||||
var ct uint32 = 1
|
||||
k := make([]byte, len+sm3.Size-1)
|
||||
for i := 0; i < limit; i++ {
|
||||
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||
sm3Hasher.Write(z)
|
||||
sm3Hasher.Write(countBytes[:])
|
||||
copy(k[i*sm3.Size:], sm3Hasher.Sum(nil))
|
||||
md.Write(z)
|
||||
md.Write(countBytes[:])
|
||||
copy(k[i*sm3.Size:], md.Sum(nil))
|
||||
ct++
|
||||
sm3Hasher.Reset()
|
||||
md.Reset()
|
||||
}
|
||||
for i := 0; i < len; i++ {
|
||||
if k[i] != 0 {
|
||||
@ -127,17 +162,20 @@ func kdf(z []byte, len int) ([]byte, bool) {
|
||||
}
|
||||
|
||||
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
|
||||
hasher := sm3.New()
|
||||
hasher.Write(toBytes(curve, x2))
|
||||
hasher.Write(msg)
|
||||
hasher.Write(toBytes(curve, y2))
|
||||
return hasher.Sum(nil)
|
||||
md := sm3.New()
|
||||
md.Write(toBytes(curve, x2))
|
||||
md.Write(msg)
|
||||
md.Write(toBytes(curve, y2))
|
||||
return md.Sum(nil)
|
||||
}
|
||||
|
||||
// Encrypt sm2 encrypt implementation
|
||||
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
|
||||
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) {
|
||||
curve := pub.Curve
|
||||
msgLen := len(msg)
|
||||
if opts == nil {
|
||||
opts = &defaultEncrypterOpts
|
||||
}
|
||||
for {
|
||||
//A1, generate random k
|
||||
k, err := randFieldElement(curve, random)
|
||||
@ -147,7 +185,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error)
|
||||
|
||||
//A2, calculate C1 = k * G
|
||||
x1, y1 := curve.ScalarBaseMult(k.Bytes())
|
||||
c1 := point2UncompressedBytes(curve, x1, y1)
|
||||
c1 := opts.PointMarshalMode.mashal(curve, x1, y1)
|
||||
|
||||
//A3, skipped
|
||||
//A4, calculate k * P (point of Public Key)
|
||||
@ -362,26 +400,26 @@ func Sign(rand io.Reader, priv *ecdsa.PrivateKey, hash []byte) (r, s *big.Int, e
|
||||
|
||||
var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
|
||||
|
||||
// CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
|
||||
func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
|
||||
// calculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
|
||||
func calculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
|
||||
uidLen := len(uid)
|
||||
if uidLen >= 0x2000 {
|
||||
return nil, errors.New("the uid is too long")
|
||||
}
|
||||
entla := uint16(uidLen) << 3
|
||||
hasher := sm3.New()
|
||||
hasher.Write([]byte{byte(entla >> 8), byte(entla)})
|
||||
md := sm3.New()
|
||||
md.Write([]byte{byte(entla >> 8), byte(entla)})
|
||||
if uidLen > 0 {
|
||||
hasher.Write(uid)
|
||||
md.Write(uid)
|
||||
}
|
||||
a := new(big.Int).Sub(pub.Params().P, big.NewInt(3))
|
||||
hasher.Write(toBytes(pub.Curve, a))
|
||||
hasher.Write(toBytes(pub.Curve, pub.Params().B))
|
||||
hasher.Write(toBytes(pub.Curve, pub.Params().Gx))
|
||||
hasher.Write(toBytes(pub.Curve, pub.Params().Gy))
|
||||
hasher.Write(toBytes(pub.Curve, pub.X))
|
||||
hasher.Write(toBytes(pub.Curve, pub.Y))
|
||||
return hasher.Sum(nil), nil
|
||||
md.Write(toBytes(pub.Curve, a))
|
||||
md.Write(toBytes(pub.Curve, pub.Params().B))
|
||||
md.Write(toBytes(pub.Curve, pub.Params().Gx))
|
||||
md.Write(toBytes(pub.Curve, pub.Params().Gy))
|
||||
md.Write(toBytes(pub.Curve, pub.X))
|
||||
md.Write(toBytes(pub.Curve, pub.Y))
|
||||
return md.Sum(nil), nil
|
||||
}
|
||||
|
||||
// SignWithSM2 follow sm2 dsa standards for hash part
|
||||
@ -389,15 +427,15 @@ func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s
|
||||
if len(uid) == 0 {
|
||||
uid = defaultUID
|
||||
}
|
||||
za, err := CalculateZA(&priv.PublicKey, uid)
|
||||
za, err := calculateZA(&priv.PublicKey, uid)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
hasher := sm3.New()
|
||||
hasher.Write(za)
|
||||
hasher.Write(msg)
|
||||
md := sm3.New()
|
||||
md.Write(za)
|
||||
md.Write(msg)
|
||||
|
||||
return Sign(rand, priv, hasher.Sum(nil))
|
||||
return Sign(rand, priv, md.Sum(nil))
|
||||
}
|
||||
|
||||
// Verify verifies the signature in r, s of hash using the public key, pub. Its
|
||||
@ -442,14 +480,14 @@ func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool {
|
||||
if len(uid) == 0 {
|
||||
uid = defaultUID
|
||||
}
|
||||
za, err := CalculateZA(pub, uid)
|
||||
za, err := calculateZA(pub, uid)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
hasher := sm3.New()
|
||||
hasher.Write(za)
|
||||
hasher.Write(msg)
|
||||
return Verify(pub, hasher.Sum(nil), r, s)
|
||||
md := sm3.New()
|
||||
md.Write(za)
|
||||
md.Write(msg)
|
||||
return Verify(pub, md.Sum(nil), r, s)
|
||||
}
|
||||
|
||||
type zr struct {
|
||||
|
@ -43,7 +43,7 @@ func Test_encryptDecrypt(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
|
||||
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed %v", err)
|
||||
}
|
||||
@ -54,6 +54,33 @@ func Test_encryptDecrypt(t *testing.T) {
|
||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||
}
|
||||
// compress mode
|
||||
encrypterOpts := EncrypterOpts{MarshalCompressed}
|
||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed %v", err)
|
||||
}
|
||||
plaintext, err = Decrypt(priv, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt failed %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||
}
|
||||
|
||||
// mixed mode
|
||||
encrypterOpts = EncrypterOpts{MarshalMixed}
|
||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed %v", err)
|
||||
}
|
||||
plaintext, err = Decrypt(priv, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt failed %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -87,7 +114,7 @@ func Test_signVerify(t *testing.T) {
|
||||
func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext string) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
priv, _ := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext))
|
||||
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,7 +118,7 @@ func TestParsePKIXPublicKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub1 := pub.(*ecdsa.PublicKey)
|
||||
encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile"))
|
||||
encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile"), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user