You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

200 lines
4.5 KiB
Go

package db
import (
"database/sql"
"fmt"
"slices"
"github.com/Flashfyre/pokerogue-server/defs"
_ "github.com/go-sql-driver/mysql"
)
func AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
if err != nil {
return err
}
return nil
}
func AddAccountSession(username string, token []byte) error {
_, err := handle.Exec("INSERT INTO sessions (uuid, token, expire) SELECT a.uuid, ?, DATE_ADD(UTC_TIMESTAMP(), INTERVAL 1 WEEK) FROM accounts a WHERE a.username = ?", token, username)
if err != nil {
return err
}
_, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username)
if err != nil {
return err
}
return nil
}
func UpdateAccountLastActivity(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET lastActivity = UTC_TIMESTAMP() WHERE uuid = ?", uuid)
if err != nil {
return err
}
return nil
}
func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error {
var columns = []string{"playTime", "battles", "classicSessionsPlayed", "sessionsWon", "highestEndlessWave", "highestLevel", "pokemonSeen", "pokemonDefeated", "pokemonCaught", "pokemonHatched", "eggsPulled", "regularVouchers", "plusVouchers", "premiumVouchers", "goldenVouchers"}
var statCols []string
var statValues []interface{}
m, ok := stats.(map[string]interface{})
if !ok {
return fmt.Errorf("expected map[string]interface{}, got %T", stats)
}
for k, v := range m {
value, ok := v.(float64)
if !ok {
return fmt.Errorf("expected float64, got %T", v)
}
if slices.Contains(columns, k) {
statCols = append(statCols, k)
statValues = append(statValues, value)
}
}
for k, v := range voucherCounts {
var column string
switch k {
case "0":
column = "regularVouchers"
case "1":
column = "plusVouchers"
case "2":
column = "premiumVouchers"
case "3":
column = "goldenVouchers"
default:
continue
}
statCols = append(statCols, column)
statValues = append(statValues, v)
}
var statArgs []interface{}
statArgs = append(statArgs, uuid)
for range 2 {
statArgs = append(statArgs, statValues...)
}
query := "INSERT INTO accountStats (uuid"
for _, col := range statCols {
query += ", " + col
}
query += ") VALUES (?"
for range len(statCols) {
query += ", ?"
}
query += ") ON DUPLICATE KEY UPDATE "
for i, col := range statCols {
if i > 0 {
query += ", "
}
query += col + " = ?"
}
_, err := handle.Exec(query, statArgs...)
if err != nil {
return err
}
return nil
}
func FetchAndClaimAccountCompensations(uuid []byte) (map[int]int, error) {
var compensations = make(map[int]int)
results, err := handle.Query("SELECT voucherType, count FROM accountCompensations WHERE uuid = ?", uuid)
if err != nil {
return nil, err
}
defer results.Close()
for results.Next() {
var voucherType int
var count int
err := results.Scan(&voucherType, &count)
if err != nil {
return compensations, err
}
compensations[voucherType] = count
}
_, err = handle.Exec("UPDATE accountCompensations SET claimed = 1 WHERE uuid = ?", uuid)
if err != nil {
return compensations, err
}
return compensations, nil
}
func DeleteClaimedAccountCompensations(uuid []byte) error {
_, err := handle.Exec("DELETE FROM accountCompensations WHERE uuid = ? AND claimed = 1", uuid)
if err != nil {
return err
}
return nil
}
func FetchUsernameFromToken(token []byte) (string, error) {
var username string
err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON s.uuid = a.uuid WHERE s.token = ? AND s.expire > UTC_TIMESTAMP()", token).Scan(&username)
if err != nil {
return "", err
}
return username, nil
}
func FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
var key, salt []byte
err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
if err != nil {
return nil, nil, err
}
return key, salt, nil
}
func FetchUUIDFromToken(token []byte) ([]byte, error) {
var uuid []byte
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ? AND expire > UTC_TIMESTAMP()", token).Scan(&uuid)
if err != nil {
if err == sql.ErrNoRows {
return nil, err
}
return nil, err
}
return uuid, nil
}
func RemoveSessionFromToken(token []byte) error {
_, err := handle.Exec("DELETE FROM sessions WHERE token = ?", token)
if err != nil {
return err
}
return nil
}