diff --git a/api/savedata.go b/api/savedata.go index f7b6827..31656b1 100644 --- a/api/savedata.go +++ b/api/savedata.go @@ -25,6 +25,16 @@ func handleSavedataGet(uuid []byte, datatype, slot int) (any, error) { return nil, err } + compensations, err := db.FetchAndClaimAccountCompensations(uuid) + if err != nil { + return nil, fmt.Errorf("failed to fetch compensations: %s", err) + } + + for k, v := range compensations { + typeKey := strconv.Itoa(k) + system.VoucherCounts[typeKey] += v + } + return system, nil case 1: // Session if slot < 0 || slot >= sessionSlotCount { @@ -116,6 +126,8 @@ func handleSavedataUpdate(uuid []byte, slot int, save any) error { if err != nil { return fmt.Errorf("failed to write save file: %s", err) } + + db.DeleteClaimedAccountCompensations(uuid) default: return fmt.Errorf("invalid data type") } diff --git a/db/account.go b/db/account.go index 7319451..fe94953 100644 --- a/db/account.go +++ b/db/account.go @@ -100,6 +100,43 @@ func UpdateAccountStats(uuid []byte, stats defs.GameStats) error { return nil } +func FetchAndClaimAccountCompensations(uuid []byte) (map[int]int, error) { + var compensations = make(map[int]int) + + results, err := handle.Query("SELECT voucherType, count FROM accountCompensations WHERE uuid = ?", uuid) + if err != nil { + return nil, err + } + + defer results.Close() + + for results.Next() { + var voucherType int + var count int + err := results.Scan(&voucherType, &count) + if err != nil { + return compensations, err + } + compensations[voucherType] = count + } + + _, err = handle.Exec("UPDATE accountCompensations SET claimed = 1 WHERE uuid = ?", uuid) + if err != nil { + return compensations, err + } + + return compensations, nil +} + +func DeleteClaimedAccountCompensations(uuid []byte) error { + _, err := handle.Exec("DELETE FROM accountCompensations WHERE uuid = ? AND claimed = 1", uuid) + if err != nil { + return err + } + + return nil +} + func FetchUsernameFromToken(token []byte) (string, error) { var username string err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON s.uuid = a.uuid WHERE s.token = ? AND s.expire > UTC_TIMESTAMP()", token).Scan(&username)