diff --git a/client.go b/client.go index 969595e..a7b3d5f 100644 --- a/client.go +++ b/client.go @@ -497,7 +497,11 @@ func (c *ClientCommon) Reply(m Message, value MsgVal) error { } func (c *ClientCommon) ExchangeKey(newKey []byte) error { - newSendKey, err := starcrypto.RSAEncrypt(newKey, c.handshakeRsaPubKey) + pubKey, err := starcrypto.DecodePublicKey(c.handshakeRsaPubKey) + if err != nil { + return err + } + newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey) if err != nil { return err } diff --git a/msg.go b/msg.go index 745a2eb..d21e5be 100644 --- a/msg.go +++ b/msg.go @@ -146,8 +146,13 @@ func (c *ClientConn) readTUMessage() { } func (c *ClientConn) rsaDecode(message Message) { - unknownKey := message.Value - data, err := starcrypto.RSADecrypt(unknownKey, c.handshakeRsaKey, "") + privKey, err := starcrypto.DecodePrivateKey(c.handshakeRsaKey, "") + if err != nil { + fmt.Println(err) + message.Reply([]byte("failed")) + return + } + data, err := starcrypto.RSADecrypt(privKey, message.Value) if err != nil { fmt.Println(err) message.Reply([]byte("failed")) diff --git a/v2cs_test.go b/v2cs_test.go index 1bb00c6..3760e80 100644 --- a/v2cs_test.go +++ b/v2cs_test.go @@ -14,22 +14,24 @@ func Test_ServerTuAndClientCommon(t *testing.T) { noEn := func(key, bn []byte) []byte { return bn } + _ = noEn server := NewServer() - server.SetDefaultCommDecode(noEn) - server.SetDefaultCommEncode(noEn) + //server.SetDefaultCommDecode(noEn) + //server.SetDefaultCommEncode(noEn) err := server.Listen("tcp", "127.0.0.1:12345") if err != nil { panic(err) } server.SetLink("notify", notify) - for i := 1; i <= 5000; i++ { + for i := 1; i <= 100; i++ { go func() { client := NewClient() - client.SetMsgEn(noEn) - client.SetMsgDe(noEn) - client.SetSkipExchangeKey(true) + //client.SetMsgEn(noEn) + //client.SetMsgDe(noEn) + //client.SetSkipExchangeKey(true) err = client.Connect("tcp", "127.0.0.1:12345") if err != nil { + t.Fatal(err) time.Sleep(time.Second * 2) return } @@ -37,7 +39,8 @@ func Test_ServerTuAndClientCommon(t *testing.T) { for { //nowd = time.Now().UnixNano() - client.SendWait("notify", []byte("client hello"),time.Second*15) + client.SendWait("notify", []byte("client hello"), time.Second*15) + //client.Send("notify", []byte("client hello")) //time.Sleep(time.Millisecond) //fmt.Println("finished:", float64(time.Now().UnixNano()-nowd)/1000000) //client.Send("notify", []byte("client")) @@ -65,7 +68,10 @@ func notify(msg *Message) { } func Test_normal(t *testing.T) { - server, _ := net.Listen("udp", "127.0.0.1:12345") + server, err := net.Listen("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatal(err) + } go func() { for { conn, err := server.Accept() @@ -87,7 +93,7 @@ func Test_normal(t *testing.T) { time.Sleep(time.Second * 5) for i := 1; i <= 100; i++ { go func() { - conn, err := net.Dial("udp", "127.0.0.1:12345") + conn, err := net.Dial("tcp", "127.0.0.1:12345") if err != nil { panic(err) }