MAGIC - sm2, basic implementation

This commit is contained in:
Emman 2020-12-16 16:27:36 +08:00
parent 4d7305a6f6
commit be62e3a042
5 changed files with 414 additions and 2 deletions

156
sm2/sm2.go Normal file
View File

@ -0,0 +1,156 @@
package sm2
import (
"crypto/ecdsa"
"crypto/elliptic"
"encoding/binary"
"errors"
"fmt"
"gmsm/sm3"
"io"
"math/big"
)
const (
Uncompressed byte = 0x04
Compressed_02 byte = 0x02
Compressed_03 byte = 0x03
Mixed_06 byte = 0x06
Mixed_07 byte = 0x07
)
///////////////// below code ship from golan crypto/ecdsa ////////////////////
var one = new(big.Int).SetInt64(1)
// randFieldElement returns a random element of the field underlying the given
// curve using the procedure given in [NSA] A.2.1.
func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
params := c.Params()
b := make([]byte, params.BitSize/8+8)
_, err = io.ReadFull(rand, b)
if err != nil {
return
}
k = new(big.Int).SetBytes(b)
n := new(big.Int).Sub(params.N, one)
k.Mod(k, n)
k.Add(k, one)
return
}
///////////////////////////////////////////////////////////////////////////////////
func kdf(z []byte, len int) ([]byte, bool) {
limit := (len + sm3.Size - 1) / sm3.Size
sm3Hasher := sm3.New()
var countBytes [4]byte
var ct uint32 = 1
k := make([]byte, len+sm3.Size-1)
for i := 0; i < limit; i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
sm3Hasher.Write(z)
sm3Hasher.Write(countBytes[:])
copy(k[i*sm3.Size:], sm3Hasher.Sum(nil))
ct++
sm3Hasher.Reset()
}
for i := 0; i < len; i++ {
if k[i] != 0 {
return k[:len], true
}
}
return k, false
}
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
hasher := sm3.New()
hasher.Write(toBytes(curve, x2))
hasher.Write(msg)
hasher.Write(toBytes(curve, y2))
return hasher.Sum(nil)
}
// Encrypt sm2 encrypt implementation
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
curve := pub.Curve
msgLen := len(msg)
for {
//A1, generate random k
k, err := randFieldElement(curve, random)
if err != nil {
return nil, err
}
//A2, calculate C1 = k * G
x1, y1 := curve.ScalarBaseMult(k.Bytes())
c1 := point2CompressedBytes(curve, x1, y1)
//A3, skipped
//A4, calculate k * P (point of Public Key)
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
//A5, calculate t=KDF(x2||y2, klen)
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if !success {
fmt.Println("A5, failed to get valid t")
continue
}
//A6, C2 = M + t;
c2 := make([]byte, msgLen)
for i := 0; i < msgLen; i++ {
c2[i] = msg[i] ^ t[i]
}
//A7, C3 = hash(x2||M||y2)
c3 := calculateC3(curve, x2, y2, msg)
return append(append(c1, c2...), c3...), nil
}
}
// Decrypt sm2 decrypt implementation
func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) {
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
return nil, errors.New("invalid ciphertext length")
}
curve := priv.Curve
// B1, get C1, and check C1
x1, y1, c2Start, err := bytes2Point(curve, ciphertext)
if err != nil {
return nil, err
}
if !curve.IsOnCurve(x1, y1) {
return nil, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
}
//B2 is ignored
//B3, calculate x2, y2
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
//B4, calculate t=KDF(x2||y2, klen)
c2 := ciphertext[c2Start : ciphertextLen-sm3.Size]
msgLen := len(c2)
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if !success {
return nil, errors.New("invalid cipher text")
}
//B5, calculate msg = c2 ^ t
msg := make([]byte, msgLen)
for i := 0; i < msgLen; i++ {
msg[i] = c2[i] ^ t[i]
}
//B6, calculate hash and compare it
c3 := ciphertext[ciphertextLen-sm3.Size:]
u := calculateC3(curve, x2, y2, msg)
for i := 0; i < sm3.Size; i++ {
if c3[i] != u[i] {
return nil, errors.New("invalid hash value")
}
}
return msg, nil
}

57
sm2/sm2_test.go Normal file
View File

@ -0,0 +1,57 @@
package sm2
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/hex"
"math/big"
"reflect"
"testing"
)
func Test_kdf(t *testing.T) {
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
expected := "006e30dae231b071dfad8aa379e90264491603"
result, success := kdf(append(x2.Bytes(), y2.Bytes()...), 19)
if !success {
t.Fatalf("failed")
}
resultStr := hex.EncodeToString(result)
if expected != resultStr {
t.Fatalf("expected %s, real value %s", expected, resultStr)
}
}
func Test_encryptDecrypt(t *testing.T) {
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
tests := []struct {
name string
plainText string
}{
// TODO: Add test cases.
{"less than 32", "encryption standard"},
{"equals 32", "encryption standard encryption "},
{"long than 32", "encryption standard encryption standard"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err := Decrypt(priv, ciphertext)
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}

120
sm2/util.go Normal file
View File

@ -0,0 +1,120 @@
package sm2
import (
"crypto/elliptic"
"errors"
"fmt"
"math/big"
"strings"
)
var zero = new(big.Int).SetInt64(0)
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
bytes := value.Bytes()
byteLen := (curve.Params().BitSize + 7) >> 3
if byteLen == len(bytes) {
return bytes
}
result := make([]byte, byteLen)
copy(result[byteLen-len(bytes):], bytes)
return result
}
func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
return elliptic.Marshal(curve, x, y)
}
func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
copy(buffer[1:], toBytes(curve, x))
if getLastBitOfY(x, y) > 0 {
buffer[0] = Compressed_03
} else {
buffer[0] = Compressed_02
}
return buffer
}
func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
buffer := elliptic.Marshal(curve, x, y)
if getLastBitOfY(x, y) > 0 {
buffer[0] = Mixed_07
} else {
buffer[0] = Mixed_06
}
return buffer
}
func getLastBitOfY(x, y *big.Int) uint {
if x.Cmp(zero) == 0 {
return 0
}
return y.Bit(0)
}
func toPointXY(bytes []byte) *big.Int {
return new(big.Int).SetBytes(bytes)
}
func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) {
x3 := new(big.Int).Mul(x, x)
x3.Mul(x3, x)
threeX := new(big.Int).Lsh(x, 1)
threeX.Add(threeX, x)
x3.Sub(x3, threeX)
x3.Add(x3, curve.Params().B)
x3.Mod(x3, curve.Params().P)
y := x3.ModSqrt(x3, curve.Params().P)
if y == nil {
return nil, errors.New("can't calculate y based on x")
}
return y, nil
}
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
if len(bytes) < 1+(curve.Params().BitSize/8) {
return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
}
format := bytes[0]
byteLen := (curve.Params().BitSize + 7) >> 3
switch format {
case Uncompressed:
if len(bytes) < 1+byteLen*2 {
return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
}
x := toPointXY(bytes[1 : 1+byteLen])
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
return x, y, 1 + byteLen*2, nil
case Compressed_02, Compressed_03:
if len(bytes) < 1+byteLen {
return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
}
if strings.HasPrefix(curve.Params().Name, "P-") {
// y² = x³ - 3x + b
x := toPointXY(bytes[1 : 1+byteLen])
y, err := calculatePrimeCurveY(curve, x)
if err != nil {
return nil, nil, 0, err
}
if (getLastBitOfY(x, y) > 0 && format == Compressed_02) || (getLastBitOfY(x, y) == 0 && format == Compressed_03) {
y.Sub(curve.Params().P, y)
}
return x, y, 1 + byteLen, nil
}
return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name)
case Mixed_06, Mixed_07:
// what's the mixed format purpose?
if len(bytes) < 1+byteLen*2 {
return nil, nil, 0, fmt.Errorf("invalid mixed bytes length %d", len(bytes))
}
x := toPointXY(bytes[1 : 1+byteLen])
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
return x, y, 1 + byteLen*2, nil
}
return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format)
}

79
sm2/util_test.go Normal file
View File

@ -0,0 +1,79 @@
package sm2
import (
"crypto/elliptic"
"encoding/hex"
"math/big"
"reflect"
"testing"
)
func Test_toBytes(t *testing.T) {
type args struct {
value string
}
tests := []struct {
name string
args args
want string
}{
// TODO: Add test cases.
{"less than 32", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
{"equals 32", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v, _ := new(big.Int).SetString(tt.args.value, 16)
if got := toBytes(elliptic.P256(), v); !reflect.DeepEqual(hex.EncodeToString(got), tt.want) {
t.Errorf("toBytes() = %v, want %v", hex.EncodeToString(got), tt.want)
}
})
}
}
func Test_getLastBitOfY(t *testing.T) {
type args struct {
y string
}
tests := []struct {
name string
args args
want uint
}{
// TODO: Add test cases.
{"0", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, 0},
{"1", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865ff"}, 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
y, _ := new(big.Int).SetString(tt.args.y, 16)
if got := getLastBitOfY(y, y); got != tt.want {
t.Errorf("getLastBitOfY() = %v, want %v", got, tt.want)
}
})
}
}
func Test_toPointXY(t *testing.T) {
type args struct {
bytes string
}
tests := []struct {
name string
args args
want string
}{
// TODO: Add test cases.
{"has zero padding", args{"00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
{"no zero padding", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bytes, _ := hex.DecodeString(tt.args.bytes)
expectedInt, _ := new(big.Int).SetString(tt.want, 16)
if got := toPointXY(bytes); !reflect.DeepEqual(got, expectedInt) {
t.Errorf("toPointXY() = %v, want %v", got, expectedInt)
}
})
}
}

View File

@ -7,10 +7,10 @@ import (
)
// Size the size of a SM3 checksum in bytes.
const Size = 32
const Size int = 32
// BlockSize the blocksize of SM3 in bytes.
const BlockSize = 64
const BlockSize int = 64
const (
chunk = 64