From a19280d02c0a9c9f75d673cf064f101933a476b9 Mon Sep 17 00:00:00 2001 From: maru Date: Sun, 31 Dec 2023 16:12:20 -0500 Subject: [PATCH] Add cloud save data --- api/account-helper.go | 53 +++++++++++ api/account.go | 18 +--- api/savedata-defs.go | 74 +++++++++++++++ api/savedata.go | 211 ++++++++++++++++++++++++++++++++++++++++-- db/account.go | 2 +- go.mod | 1 + go.sum | 2 + 7 files changed, 333 insertions(+), 28 deletions(-) create mode 100644 api/account-helper.go create mode 100644 api/savedata-defs.go diff --git a/api/account-helper.go b/api/account-helper.go new file mode 100644 index 0000000..eeaaae6 --- /dev/null +++ b/api/account-helper.go @@ -0,0 +1,53 @@ +package api + +import ( + "encoding/base64" + "fmt" + "net/http" + + "github.com/Flashfyre/pokerogue-server/db" +) + +func GetUsernameFromRequest(request *http.Request) (string, error) { + if request.Header.Get("Authorization") == "" { + return "", fmt.Errorf("missing token") + } + + token, err := base64.StdEncoding.DecodeString(request.Header.Get("Authorization")) + if err != nil { + return "", fmt.Errorf("failed to decode token: %s", err) + } + + if len(token) != 32 { + return "", fmt.Errorf("invalid token length: got %d, expected 32", len(token)) + } + + username, err := db.GetUsernameFromToken(token) + if err != nil { + return "", fmt.Errorf("failed to validate token: %s", err) + } + + return username, nil +} + +func GetUuidFromRequest(request *http.Request) ([]byte, error) { + if request.Header.Get("Authorization") == "" { + return nil, fmt.Errorf("missing token") + } + + token, err := base64.StdEncoding.DecodeString(request.Header.Get("Authorization")) + if err != nil { + return nil, fmt.Errorf("failed to decode token: %s", err) + } + + if len(token) != 32 { + return nil, fmt.Errorf("invalid token length: got %d, expected 32", len(token)) + } + + uuid, err := db.GetUuidFromToken(token) + if err != nil { + return nil, fmt.Errorf("failed to validate token: %s", err) + } + + return uuid, nil +} diff --git a/api/account.go b/api/account.go index 7b2de36..3e68320 100644 --- a/api/account.go +++ b/api/account.go @@ -30,23 +30,7 @@ type AccountInfoResponse struct{ } func (s *Server) HandleAccountInfo(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return - } - - token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) - if err != nil { - http.Error(w, fmt.Sprintf("failed to decode token: %s", err), http.StatusBadRequest) - return - } - - if len(token) != 32 { - http.Error(w, "invalid token", http.StatusBadRequest) - return - } - - username, err := db.GetUsernameFromToken(token) + username, err := GetUsernameFromRequest(r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/api/savedata-defs.go b/api/savedata-defs.go new file mode 100644 index 0000000..41d61ee --- /dev/null +++ b/api/savedata-defs.go @@ -0,0 +1,74 @@ +package api + +type SystemSaveData struct { + TrainerId int `json:"trainerId"` + SecretId int `json:"secretId"` + DexData DexData `json:"dexData"` + Unlocks Unlocks `json:"unlocks"` + AchvUnlocks AchvUnlocks `json:"achvUnlocks"` + VoucherUnlocks VoucherUnlocks `json:"voucherUnlocks"` + VoucherCounts VoucherCounts `json:"voucherCounts"` + Eggs []EggData `json:"eggs"` + GameVersion string `json:"gameVersion"` + Timestamp int `json:"timestamp"` +} + +type DexData map[int]DexEntry + +type DexEntry struct { + SeenAttr interface{} `json:"seenAttr"` // integer or string + CaughtAttr interface{} `json:"caughtAttr"` // integer or string + SeenCount int `json:"seenCount"` + CaughtCount int `json:"caughtCount"` + HatchedCount int `json:"hatchedCount"` + Ivs []int `json:"ivs"` +} + +type Unlocks map[int]bool + +type AchvUnlocks map[string]int + +type VoucherUnlocks map[string]int + +type VoucherCounts map[string]int + +type EggData struct { + Id int `json:"id"` + GachaType GachaType `json:"gachaType"` + HatchWaves int `json:"hatchWaves"` + Timestamp int `json:"timestamp"` +} + +type GachaType int + +type SessionSaveData struct { + Seed string `json:"seed"` + GameMode GameMode `json:"gameMode"` + Party []PokemonData `json:"party"` + EnemyParty []PokemonData `json:"enemyParty"` + EnemyField []PokemonData `json:"enemyField"` + Modifiers []PersistentModifierData `json:"modifiers"` + EnemyModifiers []PersistentModifierData `json:"enemyModifiers"` + Arena ArenaData `json:"arena"` + PokeballCounts PokeballCounts `json:"pokeballCounts"` + Money int `json:"money"` + WaveIndex int `json:"waveIndex"` + BattleType BattleType `json:"battleType"` + Trainer TrainerData `json:"trainer"` + GameVersion string `json:"gameVersion"` + Timestamp int `json:"timestamp"` +} + +type GameMode int + +type PokemonData interface{} + +type PersistentModifierData interface{} + +type ArenaData interface{} + +type PokeballCounts map[string]int + +type BattleType int + +type TrainerData interface{} diff --git a/api/savedata.go b/api/savedata.go index e709dac..d4a3f07 100644 --- a/api/savedata.go +++ b/api/savedata.go @@ -1,30 +1,221 @@ package api -import "net/http" +import ( + "bytes" + "encoding/gob" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "os" -// /savedata/get - get save data + "github.com/klauspost/compress/zstd" +) -type SavedataGetRequest struct{} -type SavedataGetResponse struct{} +// /savedata/get - get save data func (s *Server) HandleSavedataGet(w http.ResponseWriter, r *http.Request) { + uuid, err := GetUuidFromRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + hexUuid := hex.EncodeToString(uuid) + + switch r.URL.Query().Get("datatype") { + case "0": // System + save, err := os.ReadFile("userdata/" + hexUuid + "/system.pzs") + if err != nil { + http.Error(w, fmt.Sprintf("failed to read save file: %s", err), http.StatusInternalServerError) + return + } + + zstdReader, err := zstd.NewReader(nil) + if err != nil { + http.Error(w, fmt.Sprintf("failed to create zstd reader: %s", err), http.StatusInternalServerError) + return + } + + decompressed, err := zstdReader.DecodeAll(save, nil) + if err != nil { + http.Error(w, fmt.Sprintf("failed to decompress save file: %s", err), http.StatusInternalServerError) + return + } + + gobDecoderBuf := bytes.NewBuffer(decompressed) + + var system SystemSaveData + err = gob.NewDecoder(gobDecoderBuf).Decode(&system) + if err != nil { + http.Error(w, fmt.Sprintf("failed to deserialize save: %s", err), http.StatusInternalServerError) + return + } + saveJson, err := json.Marshal(system) + if err != nil { + http.Error(w, fmt.Sprintf("failed to marshal save to json: %s", err), http.StatusInternalServerError) + return + } + + w.Write(saveJson) + case "1": // Session + save, err := os.ReadFile("userdata/" + hexUuid + "/session.pzs") + if err != nil { + http.Error(w, fmt.Sprintf("failed to read save file: %s", err), http.StatusInternalServerError) + return + } + + zstdReader, err := zstd.NewReader(nil) + if err != nil { + http.Error(w, fmt.Sprintf("failed to create zstd reader: %s", err), http.StatusInternalServerError) + return + } + + decompressed, err := zstdReader.DecodeAll(save, nil) + if err != nil { + http.Error(w, fmt.Sprintf("failed to decompress save file: %s", err), http.StatusInternalServerError) + return + } + + gobDecoderBuf := bytes.NewBuffer(decompressed) + + var session SessionSaveData + err = gob.NewDecoder(gobDecoderBuf).Decode(&session) + if err != nil { + http.Error(w, fmt.Sprintf("failed to deserialize save: %s", err), http.StatusInternalServerError) + return + } + + saveJson, err := json.Marshal(session) + if err != nil { + http.Error(w, fmt.Sprintf("failed to marshal save to json: %s", err), http.StatusInternalServerError) + return + } + + w.Write(saveJson) + default: + http.Error(w, "invalid data type", http.StatusBadRequest) + return + } } // /savedata/update - update save data -type SavedataUpdateRequest struct{} -type SavedataUpdateResponse struct{} - func (s *Server) HandleSavedataUpdate(w http.ResponseWriter, r *http.Request) { + uuid, err := GetUuidFromRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + hexUuid := hex.EncodeToString(uuid) + + switch r.URL.Query().Get("datatype") { + case "0": // System + var system SystemSaveData + err = json.NewDecoder(r.Body).Decode(&system) + if err != nil { + http.Error(w, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + var gobBuffer bytes.Buffer + err = gob.NewEncoder(&gobBuffer).Encode(system) + if err != nil { + http.Error(w, fmt.Sprintf("failed to serialize save: %s", err), http.StatusInternalServerError) + return + } + + zstdWriter, err := zstd.NewWriter(nil) + if err != nil { + http.Error(w, fmt.Sprintf("failed to create zstd writer, %s", err), http.StatusInternalServerError) + return + } + + compressed := zstdWriter.EncodeAll(gobBuffer.Bytes(), nil) + + err = os.MkdirAll("userdata/"+hexUuid, 0755) + if !os.IsExist(err) { + http.Error(w, fmt.Sprintf("failed to create userdata folder: %s", err), http.StatusInternalServerError) + return + } + + err = os.WriteFile("userdata/"+hexUuid+"/system.pzs", compressed, 0644) + if err != nil { + http.Error(w, fmt.Sprintf("failed to write save file: %s", err), http.StatusInternalServerError) + return + } + case "1": // Session + var session SessionSaveData + err = json.NewDecoder(r.Body).Decode(&session) + if err != nil { + http.Error(w, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + var gobBuffer bytes.Buffer + err = gob.NewEncoder(&gobBuffer).Encode(session) + if err != nil { + http.Error(w, fmt.Sprintf("failed to serialize save: %s", err), http.StatusInternalServerError) + return + } + + zstdWriter, err := zstd.NewWriter(nil) + if err != nil { + http.Error(w, fmt.Sprintf("failed to create zstd writer, %s", err), http.StatusInternalServerError) + return + } + + compressed := zstdWriter.EncodeAll(gobBuffer.Bytes(), nil) + + err = os.MkdirAll("userdata/"+hexUuid, 0755) + if !os.IsExist(err) { + http.Error(w, fmt.Sprintf("failed to create userdata folder: %s", err), http.StatusInternalServerError) + return + } + + err = os.WriteFile("userdata/"+hexUuid+"/session.pzs", compressed, 0644) + if err != nil { + http.Error(w, fmt.Sprintf("failed to write save file: %s", err), http.StatusInternalServerError) + return + } + default: + http.Error(w, "invalid data type", http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) } // /savedata/delete - delete save date -type SavedataDeleteRequest struct{} -type SavedataDeleteResponse struct{} - func (s *Server) HandleSavedataDelete(w http.ResponseWriter, r *http.Request) { + uuid, err := GetUuidFromRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + hexUuid := hex.EncodeToString(uuid) + + switch r.URL.Query().Get("datatype") { + case "0": // System + err := os.Remove("userdata/"+hexUuid+"/system.pzs") + if !os.IsNotExist(err) { + http.Error(w, fmt.Sprintf("failed to delete save file: %s", err), http.StatusInternalServerError) + return + } + case "1": // Session + err := os.Remove("userdata/"+hexUuid+"/session.pzs") + if !os.IsNotExist(err) { + http.Error(w, fmt.Sprintf("failed to delete save file: %s", err), http.StatusInternalServerError) + return + } + default: + http.Error(w, "invalid data type", http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) } diff --git a/db/account.go b/db/account.go index bd0abd8..b3912db 100644 --- a/db/account.go +++ b/db/account.go @@ -49,7 +49,7 @@ func GetAccountKeySaltFromUsername(username string) ([]byte, []byte, error) { return key, salt, nil } -func GetUUIDFromToken(token []byte) ([]byte, error) { +func GetUuidFromToken(token []byte) ([]byte, error) { var uuid []byte err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ? AND expire > UTC_TIMESTAMP()", token).Scan(&uuid) if err != nil { diff --git a/go.mod b/go.mod index 413b799..9d13e8f 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21.4 require ( github.com/go-sql-driver/mysql v1.7.1 + github.com/klauspost/compress v1.17.4 golang.org/x/crypto v0.16.0 ) diff --git a/go.sum b/go.sum index 4c397c3..d95a3d3 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=