diff --git a/api/account.go b/api/account.go index 654c683..a88b077 100644 --- a/api/account.go +++ b/api/account.go @@ -5,9 +5,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" - "encoding/json" "fmt" - "net/http" "os" "regexp" "strconv" @@ -34,19 +32,7 @@ type AccountInfoResponse struct { } // /account/info - get account info -func (s *Server) handleAccountInfo(w http.ResponseWriter, r *http.Request) { - username, err := getUsernameFromRequest(r) - if err != nil { - httpError(w, r, err.Error(), http.StatusBadRequest) - return - } - - uuid, err := getUUIDFromRequest(r) // lazy - if err != nil { - httpError(w, r, err.Error(), http.StatusBadRequest) - return - } - +func handleAccountInfo(username string, uuid []byte) (AccountInfoResponse, error) { var latestSave time.Time latestSaveID := -1 for id := range sessionSlotCount { @@ -66,142 +52,95 @@ func (s *Server) handleAccountInfo(w http.ResponseWriter, r *http.Request) { } } - response, err := json.Marshal(AccountInfoResponse{Username: username, LastSessionSlot: latestSaveID}) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) - return - } - - w.Write(response) + return AccountInfoResponse{Username: username, LastSessionSlot: latestSaveID}, nil } type AccountRegisterRequest GenericAuthRequest // /account/register - register account -func (s *Server) handleAccountRegister(w http.ResponseWriter, r *http.Request) { - var request AccountRegisterRequest - err := json.NewDecoder(r.Body).Decode(&request) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest) - return +func handleAccountRegister(username, password string) error { + if !isValidUsername(username) { + return fmt.Errorf("invalid username") } - if !isValidUsername(request.Username) { - httpError(w, r, "invalid username", http.StatusBadRequest) - return - } - - if len(request.Password) < 6 { - httpError(w, r, "invalid password", http.StatusBadRequest) - return + if len(password) < 6 { + return fmt.Errorf("invalid password") } uuid := make([]byte, UUIDSize) - _, err = rand.Read(uuid) + _, err := rand.Read(uuid) if err != nil { - httpError(w, r, fmt.Sprintf("failed to generate uuid: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to generate uuid: %s", err) } salt := make([]byte, ArgonSaltSize) _, err = rand.Read(salt) if err != nil { - httpError(w, r, fmt.Sprintf("failed to generate salt: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf(fmt.Sprintf("failed to generate salt: %s", err)) } - err = db.AddAccountRecord(uuid, request.Username, argon2.IDKey([]byte(request.Password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize), salt) + err = db.AddAccountRecord(uuid, username, argon2.IDKey([]byte(password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize), salt) if err != nil { - httpError(w, r, fmt.Sprintf("failed to add account record: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to add account record: %s", err) } - w.WriteHeader(http.StatusOK) + return nil } type AccountLoginRequest GenericAuthRequest type AccountLoginResponse GenericAuthResponse // /account/login - log into account -func (s *Server) handleAccountLogin(w http.ResponseWriter, r *http.Request) { - var request AccountLoginRequest - err := json.NewDecoder(r.Body).Decode(&request) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest) - return +func handleAccountLogin(username, password string) (AccountLoginResponse, error) { + if !isValidUsername(username) { + return AccountLoginResponse{}, fmt.Errorf("invalid username") } - if !isValidUsername(request.Username) { - httpError(w, r, "invalid username", http.StatusBadRequest) - return + if len(password) < 6 { + return AccountLoginResponse{}, fmt.Errorf("invalid password") } - if len(request.Password) < 6 { - httpError(w, r, "invalid password", http.StatusBadRequest) - return - } - - key, salt, err := db.FetchAccountKeySaltFromUsername(request.Username) + key, salt, err := db.FetchAccountKeySaltFromUsername(username) if err != nil { if err == sql.ErrNoRows { - httpError(w, r, "account doesn't exist", http.StatusBadRequest) - return + return AccountLoginResponse{}, fmt.Errorf("account doesn't exist") } - httpError(w, r, err.Error(), http.StatusInternalServerError) - return + return AccountLoginResponse{}, err } - if !bytes.Equal(key, argon2.IDKey([]byte(request.Password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize)) { - httpError(w, r, "password doesn't match", http.StatusBadRequest) - return + if !bytes.Equal(key, argon2.IDKey([]byte(password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize)) { + return AccountLoginResponse{}, fmt.Errorf("password doesn't match") } token := make([]byte, 32) _, err = rand.Read(token) if err != nil { - httpError(w, r, fmt.Sprintf("failed to generate token: %s", err), http.StatusInternalServerError) - return - } - - err = db.AddAccountSession(request.Username, token) - if err != nil { - httpError(w, r, "failed to add account session", http.StatusInternalServerError) - return + return AccountLoginResponse{}, fmt.Errorf("failed to generate token: %s", err) } - response, err := json.Marshal(AccountLoginResponse{Token: base64.StdEncoding.EncodeToString(token)}) + err = db.AddAccountSession(username, token) if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) - return + return AccountLoginResponse{}, fmt.Errorf("failed to add account session") } - w.Write(response) + return AccountLoginResponse{Token: base64.StdEncoding.EncodeToString(token)}, nil } // /account/logout - log out of account -func (s *Server) handleAccountLogout(w http.ResponseWriter, r *http.Request) { - token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to decode token: %s", err), http.StatusBadRequest) - return - } - +func handleAccountLogout(token []byte) error { if len(token) != 32 { - httpError(w, r, "invalid token", http.StatusBadRequest) - return + return fmt.Errorf("invalid token") } - err = db.RemoveSessionFromToken(token) + err := db.RemoveSessionFromToken(token) if err != nil { if err == sql.ErrNoRows { - httpError(w, r, "token not found", http.StatusBadRequest) - return + return fmt.Errorf("token not found") } - httpError(w, r, "failed to remove account session", http.StatusInternalServerError) - return + return fmt.Errorf("failed to remove account session") } - w.WriteHeader(http.StatusOK) + return nil } diff --git a/api/daily.go b/api/daily.go index 7d8273a..107703c 100644 --- a/api/daily.go +++ b/api/daily.go @@ -5,15 +5,13 @@ import ( "crypto/rand" "encoding/base64" "encoding/binary" - "encoding/json" "fmt" "log" - "net/http" "os" - "strconv" "time" "github.com/Flashfyre/pokerogue-server/db" + "github.com/Flashfyre/pokerogue-server/defs" "github.com/go-co-op/gocron" ) @@ -81,79 +79,27 @@ func deriveDailyRunSeed(seedTime time.Time) []byte { return hashedSeed[:] } -// /daily/seed - fetch daily run seed -func (s *Server) handleSeed(w http.ResponseWriter) { - w.Write([]byte(dailyRunSeed)) -} - // /daily/rankings - fetch daily rankings -func (s *Server) handleRankings(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) - if err != nil { - httpError(w, r, err.Error(), http.StatusBadRequest) - return - } - - err = db.UpdateAccountLastActivity(uuid) +func handleRankings(uuid []byte, category, page int) ([]defs.DailyRanking, error) { + err := db.UpdateAccountLastActivity(uuid) if err != nil { log.Print("failed to update account last activity") } - var category int - if r.URL.Query().Has("category") { - category, err = strconv.Atoi(r.URL.Query().Get("category")) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to convert category: %s", err), http.StatusBadRequest) - return - } - } - - page := 1 - if r.URL.Query().Has("page") { - page, err = strconv.Atoi(r.URL.Query().Get("page")) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to convert page: %s", err), http.StatusBadRequest) - return - } - } - rankings, err := db.FetchRankings(category, page) if err != nil { log.Print("failed to retrieve rankings") } - response, err := json.Marshal(rankings) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) - return - } - - w.Write(response) + return rankings, nil } // /daily/rankingpagecount - fetch daily ranking page count -func (s *Server) handleRankingPageCount(w http.ResponseWriter, r *http.Request) { - var err error - var category int - - if r.URL.Query().Has("category") { - category, err = strconv.Atoi(r.URL.Query().Get("category")) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to convert category: %s", err), http.StatusBadRequest) - return - } - } - +func handleRankingPageCount(category int) (int, error) { pageCount, err := db.FetchRankingPageCount(category) if err != nil { log.Print("failed to retrieve ranking page count") } - response, err := json.Marshal(pageCount) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) - return - } - - w.Write(response) + return pageCount, nil } diff --git a/api/error.go b/api/error.go index 1576523..2c60438 100644 --- a/api/error.go +++ b/api/error.go @@ -5,7 +5,7 @@ import ( "net/http" ) -func httpError(w http.ResponseWriter, r *http.Request, error string, code int) { - log.Printf("%s: %s\n", r.URL.Path, error) - http.Error(w, error, code) +func httpError(w http.ResponseWriter, r *http.Request, err error, code int) { + log.Printf("%s: %s\n", r.URL.Path, err) + http.Error(w, err.Error(), code) } diff --git a/api/game.go b/api/game.go index c288346..c22b5c9 100644 --- a/api/game.go +++ b/api/game.go @@ -1,14 +1,10 @@ package api import ( - "encoding/json" - "fmt" "log" - "net/http" "time" "github.com/Flashfyre/pokerogue-server/db" - "github.com/Flashfyre/pokerogue-server/defs" "github.com/go-co-op/gocron" ) @@ -30,49 +26,14 @@ func updateStats() { if err != nil { log.Print(err) } + battleCount, err = db.FetchBattleCount() if err != nil { log.Print(err) } + classicSessionCount, err = db.FetchClassicSessionCount() if err != nil { log.Print(err) } } - -// /game/playercount - get player count -func (s *Server) handlePlayerCountGet(w http.ResponseWriter, r *http.Request) { - response, err := json.Marshal(playerCount) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) - return - } - - w.Write(response) -} - -// /game/titlestats - get title stats -func (s *Server) handleTitleStatsGet(w http.ResponseWriter, r *http.Request) { - titleStats := &defs.TitleStats{ - PlayerCount: playerCount, - BattleCount: battleCount, - } - response, err := json.Marshal(titleStats) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) - return - } - - w.Write(response) -} - -// /game/classicsessioncount - get classic session count -func (s *Server) handleClassicSessionCountGet(w http.ResponseWriter, r *http.Request) { - response, err := json.Marshal(classicSessionCount) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) - return - } - - w.Write(response) -} diff --git a/api/generic.go b/api/generic.go index 7235c10..8d7606f 100644 --- a/api/generic.go +++ b/api/generic.go @@ -1,14 +1,26 @@ package api import ( + "encoding/base64" "encoding/gob" + "encoding/json" + "fmt" "net/http" + "strconv" + + "github.com/Flashfyre/pokerogue-server/defs" ) type Server struct { Debug bool } +/* + The caller of endpoint handler functions are responsible for extracting the necessary data from the request. + Handler functions are responsible for checking the validity of this data and returning a result or error. + Handlers should not return serialized JSON, instead return the struct itself. +*/ + func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { gob.Register([]interface{}{}) gob.Register(map[string]interface{}{}) @@ -25,37 +37,238 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } switch r.URL.Path { + // /account case "/account/info": - s.handleAccountInfo(w, r) + username, err := getUsernameFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + + uuid, err := getUUIDFromRequest(r) // lazy + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + + info, err := handleAccountInfo(username, uuid) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + + response, err := json.Marshal(info) + if err != nil { + httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError) + return + } + + w.Write(response) case "/account/register": - s.handleAccountRegister(w, r) + var request AccountRegisterRequest + err := json.NewDecoder(r.Body).Decode(&request) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + err = handleAccountRegister(request.Username, request.Password) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) case "/account/login": - s.handleAccountLogin(w, r) + var request AccountLoginRequest + err := json.NewDecoder(r.Body).Decode(&request) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + token, err := handleAccountLogin(request.Username, request.Password) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + + response, err := json.Marshal(token) + if err != nil { + httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError) + return + } + + w.Write(response) case "/account/logout": - s.handleAccountLogout(w, r) + token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest) + return + } + + err = handleAccountLogout(token) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + + // /game case "/game/playercount": - s.handlePlayerCountGet(w, r) + w.Write([]byte(strconv.Itoa(playerCount))) case "/game/titlestats": - s.handleTitleStatsGet(w, r) + response, err := json.Marshal(&defs.TitleStats{ + PlayerCount: playerCount, + BattleCount: battleCount, + }) + if err != nil { + httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError) + return + } + + w.Write(response) case "/game/classicsessioncount": - s.handleClassicSessionCountGet(w, r) + w.Write([]byte(strconv.Itoa(classicSessionCount))) - case "/savedata/get": - s.handleSavedataGet(w, r) - case "/savedata/update": - s.handleSavedataUpdate(w, r) - case "/savedata/delete": - s.handleSavedataDelete(w, r) - case "/savedata/clear": - s.handleSavedataClear(w, r) + // /savedata + case "/savedata/get", "/savedata/update", "/savedata/delete", "/savedata/clear": + uuid, err := getUUIDFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + + datatype := -1 + if r.URL.Query().Has("datatype") { + datatype, err = strconv.Atoi(r.URL.Query().Get("datatype")) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + } + + var slot int + if r.URL.Query().Has("slot") { + slot, err = strconv.Atoi(r.URL.Query().Get("slot")) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + } + + var save any + // /savedata/delete specifies datatype, but doesn't expect data in body + if r.URL.Path != "/savedata/delete" { + if datatype == 0 { + var system defs.SystemSaveData + err = json.NewDecoder(r.Body).Decode(&system) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + save = system + // /savedata/clear doesn't specify datatype, it is assumed to be 1 (session) + } else if datatype == 1 || r.URL.Path == "/savedata/clear" { + var session defs.SessionSaveData + err = json.NewDecoder(r.Body).Decode(&session) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + save = session + } + } + + switch r.URL.Path { + case "/savedata/get": + save, err = handleSavedataGet(uuid, datatype, slot) + case "/savedata/update": + err = handleSavedataUpdate(uuid, slot, save) + case "/savedata/delete": + err = handleSavedataDelete(uuid, datatype, slot) + case "/savedata/clear": + // doesn't return a save, but it works + save, err = handleSavedataClear(uuid, slot, save.(defs.SessionSaveData)) + } + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + + if save == nil { + w.WriteHeader(http.StatusOK) + } + + response, err := json.Marshal(save) + if err != nil { + httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError) + return + } + + w.Write(response) + // /daily case "/daily/seed": - s.handleSeed(w) + w.Write([]byte(dailyRunSeed)) case "/daily/rankings": - s.handleRankings(w, r) + uuid, err := getUUIDFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + + var category int + if r.URL.Query().Has("category") { + category, err = strconv.Atoi(r.URL.Query().Get("category")) + if err != nil { + httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) + return + } + } + + page := 1 + if r.URL.Query().Has("page") { + page, err = strconv.Atoi(r.URL.Query().Get("page")) + if err != nil { + httpError(w, r, fmt.Errorf("failed to convert page: %s", err), http.StatusBadRequest) + return + } + } + + rankings, err := handleRankings(uuid, category, page) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + + response, err := json.Marshal(rankings) + if err != nil { + httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError) + return + } + + w.Write(response) case "/daily/rankingpagecount": - s.handleRankingPageCount(w, r) + var category int + if r.URL.Query().Has("category") { + var err error + category, err = strconv.Atoi(r.URL.Query().Get("category")) + if err != nil { + httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) + return + } + } + + count, err := handleRankingPageCount(category) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + } + + w.Write([]byte(strconv.Itoa(count))) } } diff --git a/api/savedata.go b/api/savedata.go index 388366b..e52e1ed 100644 --- a/api/savedata.go +++ b/api/savedata.go @@ -4,10 +4,8 @@ import ( "bytes" "encoding/gob" "encoding/hex" - "encoding/json" "fmt" "log" - "net/http" "os" "strconv" @@ -19,228 +17,146 @@ import ( const sessionSlotCount = 3 // /savedata/get - get save data -func (s *Server) handleSavedataGet(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) - if err != nil { - httpError(w, r, err.Error(), http.StatusBadRequest) - return - } - - switch r.URL.Query().Get("datatype") { - case "0": // System +func handleSavedataGet(uuid []byte, datatype, slot int) (any, error) { + switch datatype { + case 0: // System system, err := readSystemSaveData(uuid) if err != nil { - httpError(w, r, err.Error(), http.StatusInternalServerError) - return + return nil, err } - saveJson, err := json.Marshal(system) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal save to json: %s", err), http.StatusInternalServerError) - return + return system, nil + case 1: // Session + if slot < 0 || slot >= sessionSlotCount { + return nil, fmt.Errorf("slot id %d out of range", slot) } - w.Write(saveJson) - case "1": // Session - slotID, err := strconv.Atoi(r.URL.Query().Get("slot")) + session, err := readSessionSaveData(uuid, slot) if err != nil { - httpError(w, r, fmt.Sprintf("failed to convert slot id: %s", err), http.StatusBadRequest) - return - } - - if slotID < 0 || slotID >= sessionSlotCount { - httpError(w, r, fmt.Sprintf("slot id %d out of range", slotID), http.StatusBadRequest) - return + return nil, err } - session, err := readSessionSaveData(uuid, slotID) - if err != nil { - httpError(w, r, err.Error(), http.StatusInternalServerError) - return - } - - saveJson, err := json.Marshal(session) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal save to json: %s", err), http.StatusInternalServerError) - return - } - - w.Write(saveJson) + return session, nil default: - httpError(w, r, "invalid data type", http.StatusBadRequest) - return + return nil, fmt.Errorf("invalid data type") } } // /savedata/update - update save data -func (s *Server) handleSavedataUpdate(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) - if err != nil { - httpError(w, r, err.Error(), http.StatusBadRequest) - return - } - - err = db.UpdateAccountLastActivity(uuid) +func handleSavedataUpdate(uuid []byte, slot int, save any) error { + err := db.UpdateAccountLastActivity(uuid) if err != nil { log.Print("failed to update account last activity") } hexUUID := hex.EncodeToString(uuid) - switch r.URL.Query().Get("datatype") { - case "0": // System - var system defs.SystemSaveData - err = json.NewDecoder(r.Body).Decode(&system) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest) - return - } - - if system.TrainerID == 0 && system.SecretID == 0 { - httpError(w, r, "invalid system data", http.StatusInternalServerError) - return + switch save := save.(type) { + case defs.SystemSaveData: // System + if save.TrainerID == 0 && save.SecretID == 0 { + return fmt.Errorf("invalid system data") } - err = db.UpdateAccountStats(uuid, system.GameStats) + err = db.UpdateAccountStats(uuid, save.GameStats) if err != nil { - httpError(w, r, fmt.Sprintf("failed to update account stats: %s", err), http.StatusBadRequest) - return + return fmt.Errorf("failed to update account stats: %s", err) } var gobBuffer bytes.Buffer - err = gob.NewEncoder(&gobBuffer).Encode(system) + err = gob.NewEncoder(&gobBuffer).Encode(save) if err != nil { - httpError(w, r, fmt.Sprintf("failed to serialize save: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to serialize save: %s", err) } zstdWriter, err := zstd.NewWriter(nil) if err != nil { - httpError(w, r, fmt.Sprintf("failed to create zstd writer, %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to create zstd writer, %s", err) } compressed := zstdWriter.EncodeAll(gobBuffer.Bytes(), nil) err = os.MkdirAll("userdata/"+hexUUID, 0755) if err != nil && !os.IsExist(err) { - httpError(w, r, fmt.Sprintf("failed to create userdata folder: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to create userdata folder: %s", err) } err = os.WriteFile("userdata/"+hexUUID+"/system.pzs", compressed, 0644) if err != nil { - httpError(w, r, fmt.Sprintf("failed to write save file: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to write save file: %s", err) } - case "1": // Session - slotID, err := strconv.Atoi(r.URL.Query().Get("slot")) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to convert slot id: %s", err), http.StatusBadRequest) - return - } - - if slotID < 0 || slotID >= sessionSlotCount { - httpError(w, r, fmt.Sprintf("slot id %d out of range", slotID), http.StatusBadRequest) - return + case defs.SessionSaveData: // Session + if slot < 0 || slot >= sessionSlotCount { + return fmt.Errorf("slot id %d out of range", slot) } fileName := "session" - if slotID != 0 { - fileName += strconv.Itoa(slotID) - } - - var session defs.SessionSaveData - err = json.NewDecoder(r.Body).Decode(&session) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest) - return + if slot != 0 { + fileName += strconv.Itoa(slot) } var gobBuffer bytes.Buffer - err = gob.NewEncoder(&gobBuffer).Encode(session) + err = gob.NewEncoder(&gobBuffer).Encode(save) if err != nil { - httpError(w, r, fmt.Sprintf("failed to serialize save: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to serialize save: %s", err) } zstdWriter, err := zstd.NewWriter(nil) if err != nil { - httpError(w, r, fmt.Sprintf("failed to create zstd writer, %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to create zstd writer, %s", err) } compressed := zstdWriter.EncodeAll(gobBuffer.Bytes(), nil) err = os.MkdirAll("userdata/"+hexUUID, 0755) if err != nil && !os.IsExist(err) { - httpError(w, r, fmt.Sprintf("failed to create userdata folder: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf(fmt.Sprintf("failed to create userdata folder: %s", err)) } err = os.WriteFile(fmt.Sprintf("userdata/%s/%s.pzs", hexUUID, fileName), compressed, 0644) if err != nil { - httpError(w, r, fmt.Sprintf("failed to write save file: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to write save file: %s", err) } default: - httpError(w, r, "invalid data type", http.StatusBadRequest) - return + return fmt.Errorf("invalid data type") } - w.WriteHeader(http.StatusOK) + return nil } // /savedata/delete - delete save data -func (s *Server) handleSavedataDelete(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) - if err != nil { - httpError(w, r, err.Error(), http.StatusBadRequest) - return - } - - err = db.UpdateAccountLastActivity(uuid) +func handleSavedataDelete(uuid []byte, datatype, slot int) error { + err := db.UpdateAccountLastActivity(uuid) if err != nil { log.Print("failed to update account last activity") } hexUUID := hex.EncodeToString(uuid) - switch r.URL.Query().Get("datatype") { - case "0": // System + switch datatype { + case 0: // System err := os.Remove("userdata/" + hexUUID + "/system.pzs") if err != nil && !os.IsNotExist(err) { - httpError(w, r, fmt.Sprintf("failed to delete save file: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to delete save file: %s", err) } - case "1": // Session - slotID, err := strconv.Atoi(r.URL.Query().Get("slot")) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to convert slot id: %s", err), http.StatusBadRequest) - return - } - - if slotID < 0 || slotID >= sessionSlotCount { - httpError(w, r, fmt.Sprintf("slot id %d out of range", slotID), http.StatusBadRequest) - return + case 1: // Session + if slot < 0 || slot >= sessionSlotCount { + return fmt.Errorf("slot id %d out of range", slot) } fileName := "session" - if slotID != 0 { - fileName += strconv.Itoa(slotID) + if slot != 0 { + fileName += strconv.Itoa(slot) } err = os.Remove(fmt.Sprintf("userdata/%s/%s.pzs", hexUUID, fileName)) if err != nil && !os.IsNotExist(err) { - httpError(w, r, fmt.Sprintf("failed to delete save file: %s", err), http.StatusInternalServerError) - return + return fmt.Errorf("failed to delete save file: %s", err) } default: - httpError(w, r, "invalid data type", http.StatusBadRequest) - return + return fmt.Errorf("invalid data type") } - w.WriteHeader(http.StatusOK) + return nil } type SavedataClearResponse struct { @@ -248,73 +164,46 @@ type SavedataClearResponse struct { } // /savedata/clear - mark session save data as cleared and delete -func (s *Server) handleSavedataClear(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) - if err != nil { - httpError(w, r, err.Error(), http.StatusBadRequest) - return - } - - err = db.UpdateAccountLastActivity(uuid) +func handleSavedataClear(uuid []byte, slot int, save defs.SessionSaveData) (SavedataClearResponse, error) { + err := db.UpdateAccountLastActivity(uuid) if err != nil { log.Print("failed to update account last activity") } - slotID, err := strconv.Atoi(r.URL.Query().Get("slot")) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to convert slot id: %s", err), http.StatusBadRequest) - return - } - - if slotID < 0 || slotID >= sessionSlotCount { - httpError(w, r, fmt.Sprintf("slot id %d out of range", slotID), http.StatusBadRequest) - return - } - - var session defs.SessionSaveData - err = json.NewDecoder(r.Body).Decode(&session) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest) - return + if slot < 0 || slot >= sessionSlotCount { + return SavedataClearResponse{}, fmt.Errorf("slot id %d out of range", slot) } - sessionCompleted := validateSessionCompleted(session) + sessionCompleted := validateSessionCompleted(save) newCompletion := false - if session.GameMode == 3 && session.Seed == dailyRunSeed { - waveCompleted := session.WaveIndex + if save.GameMode == 3 && save.Seed == dailyRunSeed { + waveCompleted := save.WaveIndex if !sessionCompleted { waveCompleted-- } - err = db.AddOrUpdateAccountDailyRun(uuid, session.Score, waveCompleted) + err = db.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted) if err != nil { log.Printf("failed to add or update daily run record: %s", err) } } if sessionCompleted { - newCompletion, err = db.TryAddSeedCompletion(uuid, session.Seed, int(session.GameMode)) + newCompletion, err = db.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode)) if err != nil { log.Printf("failed to mark seed as completed: %s", err) } } - response, err := json.Marshal(SavedataClearResponse{Success: newCompletion}) - if err != nil { - httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) - return - } - fileName := "session" - if slotID != 0 { - fileName += strconv.Itoa(slotID) + if slot != 0 { + fileName += strconv.Itoa(slot) } err = os.Remove(fmt.Sprintf("userdata/%s/%s.pzs", hex.EncodeToString(uuid), fileName)) if err != nil && !os.IsNotExist(err) { - httpError(w, r, fmt.Sprintf("failed to delete save file: %s", err), http.StatusInternalServerError) - return + return SavedataClearResponse{}, fmt.Errorf("failed to delete save file: %s", err) } - w.Write(response) + return SavedataClearResponse{Success: newCompletion}, nil }