mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-27 04:36:19 +08:00
MAGIC - sm2, basic implementation
This commit is contained in:
parent
4d7305a6f6
commit
be62e3a042
156
sm2/sm2.go
Normal file
156
sm2/sm2.go
Normal file
@ -0,0 +1,156 @@
|
||||
package sm2
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gmsm/sm3"
|
||||
"io"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
const (
|
||||
Uncompressed byte = 0x04
|
||||
Compressed_02 byte = 0x02
|
||||
Compressed_03 byte = 0x03
|
||||
Mixed_06 byte = 0x06
|
||||
Mixed_07 byte = 0x07
|
||||
)
|
||||
|
||||
///////////////// below code ship from golan crypto/ecdsa ////////////////////
|
||||
var one = new(big.Int).SetInt64(1)
|
||||
|
||||
// randFieldElement returns a random element of the field underlying the given
|
||||
// curve using the procedure given in [NSA] A.2.1.
|
||||
func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
|
||||
params := c.Params()
|
||||
b := make([]byte, params.BitSize/8+8)
|
||||
_, err = io.ReadFull(rand, b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
k = new(big.Int).SetBytes(b)
|
||||
n := new(big.Int).Sub(params.N, one)
|
||||
k.Mod(k, n)
|
||||
k.Add(k, one)
|
||||
return
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
func kdf(z []byte, len int) ([]byte, bool) {
|
||||
limit := (len + sm3.Size - 1) / sm3.Size
|
||||
sm3Hasher := sm3.New()
|
||||
var countBytes [4]byte
|
||||
var ct uint32 = 1
|
||||
k := make([]byte, len+sm3.Size-1)
|
||||
for i := 0; i < limit; i++ {
|
||||
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||
sm3Hasher.Write(z)
|
||||
sm3Hasher.Write(countBytes[:])
|
||||
copy(k[i*sm3.Size:], sm3Hasher.Sum(nil))
|
||||
ct++
|
||||
sm3Hasher.Reset()
|
||||
}
|
||||
for i := 0; i < len; i++ {
|
||||
if k[i] != 0 {
|
||||
return k[:len], true
|
||||
}
|
||||
}
|
||||
return k, false
|
||||
}
|
||||
|
||||
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
|
||||
hasher := sm3.New()
|
||||
hasher.Write(toBytes(curve, x2))
|
||||
hasher.Write(msg)
|
||||
hasher.Write(toBytes(curve, y2))
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
// Encrypt sm2 encrypt implementation
|
||||
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
|
||||
curve := pub.Curve
|
||||
msgLen := len(msg)
|
||||
for {
|
||||
//A1, generate random k
|
||||
k, err := randFieldElement(curve, random)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//A2, calculate C1 = k * G
|
||||
x1, y1 := curve.ScalarBaseMult(k.Bytes())
|
||||
c1 := point2CompressedBytes(curve, x1, y1)
|
||||
|
||||
//A3, skipped
|
||||
//A4, calculate k * P (point of Public Key)
|
||||
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
||||
|
||||
//A5, calculate t=KDF(x2||y2, klen)
|
||||
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||
if !success {
|
||||
fmt.Println("A5, failed to get valid t")
|
||||
continue
|
||||
}
|
||||
|
||||
//A6, C2 = M + t;
|
||||
c2 := make([]byte, msgLen)
|
||||
for i := 0; i < msgLen; i++ {
|
||||
c2[i] = msg[i] ^ t[i]
|
||||
}
|
||||
|
||||
//A7, C3 = hash(x2||M||y2)
|
||||
c3 := calculateC3(curve, x2, y2, msg)
|
||||
|
||||
return append(append(c1, c2...), c3...), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Decrypt sm2 decrypt implementation
|
||||
func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||
ciphertextLen := len(ciphertext)
|
||||
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
||||
return nil, errors.New("invalid ciphertext length")
|
||||
}
|
||||
curve := priv.Curve
|
||||
// B1, get C1, and check C1
|
||||
x1, y1, c2Start, err := bytes2Point(curve, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !curve.IsOnCurve(x1, y1) {
|
||||
return nil, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
|
||||
}
|
||||
|
||||
//B2 is ignored
|
||||
//B3, calculate x2, y2
|
||||
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
||||
|
||||
//B4, calculate t=KDF(x2||y2, klen)
|
||||
c2 := ciphertext[c2Start : ciphertextLen-sm3.Size]
|
||||
msgLen := len(c2)
|
||||
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||
if !success {
|
||||
return nil, errors.New("invalid cipher text")
|
||||
}
|
||||
|
||||
//B5, calculate msg = c2 ^ t
|
||||
msg := make([]byte, msgLen)
|
||||
for i := 0; i < msgLen; i++ {
|
||||
msg[i] = c2[i] ^ t[i]
|
||||
}
|
||||
|
||||
//B6, calculate hash and compare it
|
||||
c3 := ciphertext[ciphertextLen-sm3.Size:]
|
||||
u := calculateC3(curve, x2, y2, msg)
|
||||
for i := 0; i < sm3.Size; i++ {
|
||||
if c3[i] != u[i] {
|
||||
return nil, errors.New("invalid hash value")
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
57
sm2/sm2_test.go
Normal file
57
sm2/sm2_test.go
Normal file
@ -0,0 +1,57 @@
|
||||
package sm2
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_kdf(t *testing.T) {
|
||||
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
|
||||
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
|
||||
|
||||
expected := "006e30dae231b071dfad8aa379e90264491603"
|
||||
|
||||
result, success := kdf(append(x2.Bytes(), y2.Bytes()...), 19)
|
||||
if !success {
|
||||
t.Fatalf("failed")
|
||||
}
|
||||
|
||||
resultStr := hex.EncodeToString(result)
|
||||
|
||||
if expected != resultStr {
|
||||
t.Fatalf("expected %s, real value %s", expected, resultStr)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_encryptDecrypt(t *testing.T) {
|
||||
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
tests := []struct {
|
||||
name string
|
||||
plainText string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{"less than 32", "encryption standard"},
|
||||
{"equals 32", "encryption standard encryption "},
|
||||
{"long than 32", "encryption standard encryption standard"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed %v", err)
|
||||
}
|
||||
plaintext, err := Decrypt(priv, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt failed %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
120
sm2/util.go
Normal file
120
sm2/util.go
Normal file
@ -0,0 +1,120 @@
|
||||
package sm2
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var zero = new(big.Int).SetInt64(0)
|
||||
|
||||
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
|
||||
bytes := value.Bytes()
|
||||
byteLen := (curve.Params().BitSize + 7) >> 3
|
||||
if byteLen == len(bytes) {
|
||||
return bytes
|
||||
}
|
||||
result := make([]byte, byteLen)
|
||||
copy(result[byteLen-len(bytes):], bytes)
|
||||
return result
|
||||
}
|
||||
|
||||
func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||
return elliptic.Marshal(curve, x, y)
|
||||
}
|
||||
|
||||
func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||
buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
|
||||
copy(buffer[1:], toBytes(curve, x))
|
||||
if getLastBitOfY(x, y) > 0 {
|
||||
buffer[0] = Compressed_03
|
||||
} else {
|
||||
buffer[0] = Compressed_02
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||
buffer := elliptic.Marshal(curve, x, y)
|
||||
if getLastBitOfY(x, y) > 0 {
|
||||
buffer[0] = Mixed_07
|
||||
} else {
|
||||
buffer[0] = Mixed_06
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
func getLastBitOfY(x, y *big.Int) uint {
|
||||
if x.Cmp(zero) == 0 {
|
||||
return 0
|
||||
}
|
||||
return y.Bit(0)
|
||||
}
|
||||
|
||||
func toPointXY(bytes []byte) *big.Int {
|
||||
return new(big.Int).SetBytes(bytes)
|
||||
}
|
||||
|
||||
func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) {
|
||||
x3 := new(big.Int).Mul(x, x)
|
||||
x3.Mul(x3, x)
|
||||
|
||||
threeX := new(big.Int).Lsh(x, 1)
|
||||
threeX.Add(threeX, x)
|
||||
|
||||
x3.Sub(x3, threeX)
|
||||
x3.Add(x3, curve.Params().B)
|
||||
x3.Mod(x3, curve.Params().P)
|
||||
y := x3.ModSqrt(x3, curve.Params().P)
|
||||
|
||||
if y == nil {
|
||||
return nil, errors.New("can't calculate y based on x")
|
||||
}
|
||||
return y, nil
|
||||
}
|
||||
|
||||
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
|
||||
if len(bytes) < 1+(curve.Params().BitSize/8) {
|
||||
return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
|
||||
}
|
||||
format := bytes[0]
|
||||
byteLen := (curve.Params().BitSize + 7) >> 3
|
||||
switch format {
|
||||
case Uncompressed:
|
||||
if len(bytes) < 1+byteLen*2 {
|
||||
return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
|
||||
}
|
||||
x := toPointXY(bytes[1 : 1+byteLen])
|
||||
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
|
||||
return x, y, 1 + byteLen*2, nil
|
||||
case Compressed_02, Compressed_03:
|
||||
if len(bytes) < 1+byteLen {
|
||||
return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
|
||||
}
|
||||
if strings.HasPrefix(curve.Params().Name, "P-") {
|
||||
// y² = x³ - 3x + b
|
||||
x := toPointXY(bytes[1 : 1+byteLen])
|
||||
y, err := calculatePrimeCurveY(curve, x)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
if (getLastBitOfY(x, y) > 0 && format == Compressed_02) || (getLastBitOfY(x, y) == 0 && format == Compressed_03) {
|
||||
y.Sub(curve.Params().P, y)
|
||||
}
|
||||
return x, y, 1 + byteLen, nil
|
||||
}
|
||||
return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name)
|
||||
case Mixed_06, Mixed_07:
|
||||
// what's the mixed format purpose?
|
||||
if len(bytes) < 1+byteLen*2 {
|
||||
return nil, nil, 0, fmt.Errorf("invalid mixed bytes length %d", len(bytes))
|
||||
}
|
||||
x := toPointXY(bytes[1 : 1+byteLen])
|
||||
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
|
||||
return x, y, 1 + byteLen*2, nil
|
||||
}
|
||||
return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format)
|
||||
}
|
79
sm2/util_test.go
Normal file
79
sm2/util_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package sm2
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_toBytes(t *testing.T) {
|
||||
type args struct {
|
||||
value string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{"less than 32", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
||||
{"equals 32", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v, _ := new(big.Int).SetString(tt.args.value, 16)
|
||||
if got := toBytes(elliptic.P256(), v); !reflect.DeepEqual(hex.EncodeToString(got), tt.want) {
|
||||
t.Errorf("toBytes() = %v, want %v", hex.EncodeToString(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getLastBitOfY(t *testing.T) {
|
||||
type args struct {
|
||||
y string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want uint
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{"0", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, 0},
|
||||
{"1", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865ff"}, 1},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
y, _ := new(big.Int).SetString(tt.args.y, 16)
|
||||
if got := getLastBitOfY(y, y); got != tt.want {
|
||||
t.Errorf("getLastBitOfY() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_toPointXY(t *testing.T) {
|
||||
type args struct {
|
||||
bytes string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{"has zero padding", args{"00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
||||
{"no zero padding", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
bytes, _ := hex.DecodeString(tt.args.bytes)
|
||||
expectedInt, _ := new(big.Int).SetString(tt.want, 16)
|
||||
if got := toPointXY(bytes); !reflect.DeepEqual(got, expectedInt) {
|
||||
t.Errorf("toPointXY() = %v, want %v", got, expectedInt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -7,10 +7,10 @@ import (
|
||||
)
|
||||
|
||||
// Size the size of a SM3 checksum in bytes.
|
||||
const Size = 32
|
||||
const Size int = 32
|
||||
|
||||
// BlockSize the blocksize of SM3 in bytes.
|
||||
const BlockSize = 64
|
||||
const BlockSize int = 64
|
||||
|
||||
const (
|
||||
chunk = 64
|
||||
|
Loading…
x
Reference in New Issue
Block a user