mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-27 20:56:18 +08:00
MAGIC - refactor
This commit is contained in:
parent
e3b5c05ec0
commit
a1cb0a2616
106
sm2/sm2.go
106
sm2/sm2.go
@ -47,6 +47,35 @@ type ecdsaSignature struct {
|
|||||||
R, S *big.Int
|
R, S *big.Int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type pointMarshalMode byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
//MarshalUncompressed uncompressed mashal mode
|
||||||
|
MarshalUncompressed pointMarshalMode = iota
|
||||||
|
//MarshalCompressed compressed mashal mode
|
||||||
|
MarshalCompressed
|
||||||
|
//MarshalMixed mixed mashal mode
|
||||||
|
MarshalMixed
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncrypterOpts encryption options
|
||||||
|
type EncrypterOpts struct {
|
||||||
|
PointMarshalMode pointMarshalMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||||
|
switch mode {
|
||||||
|
case MarshalCompressed:
|
||||||
|
return point2CompressedBytes(curve, x, y)
|
||||||
|
case MarshalMixed:
|
||||||
|
return point2MixedBytes(curve, x, y)
|
||||||
|
default:
|
||||||
|
return point2UncompressedBytes(curve, x, y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed}
|
||||||
|
|
||||||
// Sign signs digest with priv, reading randomness from rand. The opts argument
|
// Sign signs digest with priv, reading randomness from rand. The opts argument
|
||||||
// is not currently used but, in keeping with the crypto.Signer interface,
|
// is not currently used but, in keeping with the crypto.Signer interface,
|
||||||
// should be the hash function used to digest the message.
|
// should be the hash function used to digest the message.
|
||||||
@ -73,6 +102,12 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
|
|||||||
return asn1.Marshal(ecdsaSignature{r, s})
|
return asn1.Marshal(ecdsaSignature{r, s})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts msg. The opts argument should be appropriate for
|
||||||
|
// the primitive used.
|
||||||
|
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
||||||
|
return Decrypt(priv, msg)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
one = new(big.Int).SetInt64(1)
|
one = new(big.Int).SetInt64(1)
|
||||||
initonce sync.Once
|
initonce sync.Once
|
||||||
@ -106,17 +141,17 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error)
|
|||||||
///////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////
|
||||||
func kdf(z []byte, len int) ([]byte, bool) {
|
func kdf(z []byte, len int) ([]byte, bool) {
|
||||||
limit := (len + sm3.Size - 1) >> sm3.SizeBitSize
|
limit := (len + sm3.Size - 1) >> sm3.SizeBitSize
|
||||||
sm3Hasher := sm3.New()
|
md := sm3.New()
|
||||||
var countBytes [4]byte
|
var countBytes [4]byte
|
||||||
var ct uint32 = 1
|
var ct uint32 = 1
|
||||||
k := make([]byte, len+sm3.Size-1)
|
k := make([]byte, len+sm3.Size-1)
|
||||||
for i := 0; i < limit; i++ {
|
for i := 0; i < limit; i++ {
|
||||||
binary.BigEndian.PutUint32(countBytes[:], ct)
|
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||||
sm3Hasher.Write(z)
|
md.Write(z)
|
||||||
sm3Hasher.Write(countBytes[:])
|
md.Write(countBytes[:])
|
||||||
copy(k[i*sm3.Size:], sm3Hasher.Sum(nil))
|
copy(k[i*sm3.Size:], md.Sum(nil))
|
||||||
ct++
|
ct++
|
||||||
sm3Hasher.Reset()
|
md.Reset()
|
||||||
}
|
}
|
||||||
for i := 0; i < len; i++ {
|
for i := 0; i < len; i++ {
|
||||||
if k[i] != 0 {
|
if k[i] != 0 {
|
||||||
@ -127,17 +162,20 @@ func kdf(z []byte, len int) ([]byte, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
|
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
|
||||||
hasher := sm3.New()
|
md := sm3.New()
|
||||||
hasher.Write(toBytes(curve, x2))
|
md.Write(toBytes(curve, x2))
|
||||||
hasher.Write(msg)
|
md.Write(msg)
|
||||||
hasher.Write(toBytes(curve, y2))
|
md.Write(toBytes(curve, y2))
|
||||||
return hasher.Sum(nil)
|
return md.Sum(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt sm2 encrypt implementation
|
// Encrypt sm2 encrypt implementation
|
||||||
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
|
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) {
|
||||||
curve := pub.Curve
|
curve := pub.Curve
|
||||||
msgLen := len(msg)
|
msgLen := len(msg)
|
||||||
|
if opts == nil {
|
||||||
|
opts = &defaultEncrypterOpts
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
//A1, generate random k
|
//A1, generate random k
|
||||||
k, err := randFieldElement(curve, random)
|
k, err := randFieldElement(curve, random)
|
||||||
@ -147,7 +185,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error)
|
|||||||
|
|
||||||
//A2, calculate C1 = k * G
|
//A2, calculate C1 = k * G
|
||||||
x1, y1 := curve.ScalarBaseMult(k.Bytes())
|
x1, y1 := curve.ScalarBaseMult(k.Bytes())
|
||||||
c1 := point2UncompressedBytes(curve, x1, y1)
|
c1 := opts.PointMarshalMode.mashal(curve, x1, y1)
|
||||||
|
|
||||||
//A3, skipped
|
//A3, skipped
|
||||||
//A4, calculate k * P (point of Public Key)
|
//A4, calculate k * P (point of Public Key)
|
||||||
@ -362,26 +400,26 @@ func Sign(rand io.Reader, priv *ecdsa.PrivateKey, hash []byte) (r, s *big.Int, e
|
|||||||
|
|
||||||
var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
|
var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
|
||||||
|
|
||||||
// CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
|
// calculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
|
||||||
func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
|
func calculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
|
||||||
uidLen := len(uid)
|
uidLen := len(uid)
|
||||||
if uidLen >= 0x2000 {
|
if uidLen >= 0x2000 {
|
||||||
return nil, errors.New("the uid is too long")
|
return nil, errors.New("the uid is too long")
|
||||||
}
|
}
|
||||||
entla := uint16(uidLen) << 3
|
entla := uint16(uidLen) << 3
|
||||||
hasher := sm3.New()
|
md := sm3.New()
|
||||||
hasher.Write([]byte{byte(entla >> 8), byte(entla)})
|
md.Write([]byte{byte(entla >> 8), byte(entla)})
|
||||||
if uidLen > 0 {
|
if uidLen > 0 {
|
||||||
hasher.Write(uid)
|
md.Write(uid)
|
||||||
}
|
}
|
||||||
a := new(big.Int).Sub(pub.Params().P, big.NewInt(3))
|
a := new(big.Int).Sub(pub.Params().P, big.NewInt(3))
|
||||||
hasher.Write(toBytes(pub.Curve, a))
|
md.Write(toBytes(pub.Curve, a))
|
||||||
hasher.Write(toBytes(pub.Curve, pub.Params().B))
|
md.Write(toBytes(pub.Curve, pub.Params().B))
|
||||||
hasher.Write(toBytes(pub.Curve, pub.Params().Gx))
|
md.Write(toBytes(pub.Curve, pub.Params().Gx))
|
||||||
hasher.Write(toBytes(pub.Curve, pub.Params().Gy))
|
md.Write(toBytes(pub.Curve, pub.Params().Gy))
|
||||||
hasher.Write(toBytes(pub.Curve, pub.X))
|
md.Write(toBytes(pub.Curve, pub.X))
|
||||||
hasher.Write(toBytes(pub.Curve, pub.Y))
|
md.Write(toBytes(pub.Curve, pub.Y))
|
||||||
return hasher.Sum(nil), nil
|
return md.Sum(nil), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignWithSM2 follow sm2 dsa standards for hash part
|
// SignWithSM2 follow sm2 dsa standards for hash part
|
||||||
@ -389,15 +427,15 @@ func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s
|
|||||||
if len(uid) == 0 {
|
if len(uid) == 0 {
|
||||||
uid = defaultUID
|
uid = defaultUID
|
||||||
}
|
}
|
||||||
za, err := CalculateZA(&priv.PublicKey, uid)
|
za, err := calculateZA(&priv.PublicKey, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
hasher := sm3.New()
|
md := sm3.New()
|
||||||
hasher.Write(za)
|
md.Write(za)
|
||||||
hasher.Write(msg)
|
md.Write(msg)
|
||||||
|
|
||||||
return Sign(rand, priv, hasher.Sum(nil))
|
return Sign(rand, priv, md.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify verifies the signature in r, s of hash using the public key, pub. Its
|
// Verify verifies the signature in r, s of hash using the public key, pub. Its
|
||||||
@ -442,14 +480,14 @@ func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool {
|
|||||||
if len(uid) == 0 {
|
if len(uid) == 0 {
|
||||||
uid = defaultUID
|
uid = defaultUID
|
||||||
}
|
}
|
||||||
za, err := CalculateZA(pub, uid)
|
za, err := calculateZA(pub, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
hasher := sm3.New()
|
md := sm3.New()
|
||||||
hasher.Write(za)
|
md.Write(za)
|
||||||
hasher.Write(msg)
|
md.Write(msg)
|
||||||
return Verify(pub, hasher.Sum(nil), r, s)
|
return Verify(pub, md.Sum(nil), r, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
type zr struct {
|
type zr struct {
|
||||||
|
@ -43,7 +43,7 @@ func Test_encryptDecrypt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
|
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("encrypt failed %v", err)
|
t.Fatalf("encrypt failed %v", err)
|
||||||
}
|
}
|
||||||
@ -54,6 +54,33 @@ func Test_encryptDecrypt(t *testing.T) {
|
|||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
}
|
}
|
||||||
|
// compress mode
|
||||||
|
encrypterOpts := EncrypterOpts{MarshalCompressed}
|
||||||
|
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err = Decrypt(priv, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mixed mode
|
||||||
|
encrypterOpts = EncrypterOpts{MarshalMixed}
|
||||||
|
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err = Decrypt(priv, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -87,7 +114,7 @@ func Test_signVerify(t *testing.T) {
|
|||||||
func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext string) {
|
func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext string) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
priv, _ := ecdsa.GenerateKey(curve, rand.Reader)
|
priv, _ := ecdsa.GenerateKey(curve, rand.Reader)
|
||||||
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext))
|
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ func TestParsePKIXPublicKey(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
pub1 := pub.(*ecdsa.PublicKey)
|
pub1 := pub.(*ecdsa.PublicKey)
|
||||||
encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile"))
|
encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user