diff --git a/api/account/common.go b/api/account/common.go index 9102f4e..8e8c0ff 100644 --- a/api/account/common.go +++ b/api/account/common.go @@ -19,6 +19,7 @@ package account import ( "regexp" + "runtime" "golang.org/x/crypto/argon2" ) @@ -34,13 +35,13 @@ const ( ArgonKeySize = 32 ArgonSaltSize = 16 - ArgonMaxInstances = 16 - UUIDSize = 16 TokenSize = 32 ) var ( + ArgonMaxInstances = runtime.NumCPU() + isValidUsername = regexp.MustCompile(`^\w{1,16}$`).MatchString semaphore = make(chan bool, ArgonMaxInstances) ) diff --git a/api/common.go b/api/common.go index 62ef843..920d41e 100644 --- a/api/common.go +++ b/api/common.go @@ -40,7 +40,6 @@ func Init(mux *http.ServeMux) { mux.HandleFunc("GET /account/logout", handleAccountLogout) // game - mux.HandleFunc("GET /game/playercount", handleGamePlayerCount) mux.HandleFunc("GET /game/titlestats", handleGameTitleStats) mux.HandleFunc("GET /game/classicsessioncount", handleGameClassicSessionCount) diff --git a/api/endpoints.go b/api/endpoints.go index 7da61e0..2227fa6 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -18,7 +18,7 @@ package api import ( - "encoding/base64" + "database/sql" "encoding/json" "fmt" "net/http" @@ -145,10 +145,6 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) { // game -func handleGamePlayerCount(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(strconv.Itoa(playerCount))) -} - func handleGameTitleStats(w http.ResponseWriter, r *http.Request) { err := json.NewEncoder(w).Encode(defs.TitleStats{ PlayerCount: playerCount, @@ -285,6 +281,10 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/savedata/get": save, err = savedata.Get(uuid, datatype, slot) + if err == sql.ErrNoRows { + http.Error(w, err.Error(), http.StatusNotFound) + return + } case "/savedata/update": err = savedata.Update(uuid, slot, save) case "/savedata/delete": @@ -338,16 +338,8 @@ func handleDailySeed(w http.ResponseWriter, r *http.Request) { httpError(w, r, err, http.StatusInternalServerError) return } - bytes, err := base64.StdEncoding.DecodeString(seed) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } - _, err = w.Write(bytes) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } + + w.Write([]byte(seed)) } func handleDailyRankings(w http.ResponseWriter, r *http.Request) { diff --git a/db/db.go b/db/db.go index 3628d83..f246267 100644 --- a/db/db.go +++ b/db/db.go @@ -21,9 +21,11 @@ import ( "database/sql" "encoding/hex" "fmt" - _ "github.com/go-sql-driver/mysql" "log" "os" + "time" + + _ "github.com/go-sql-driver/mysql" ) var handle *sql.DB @@ -35,8 +37,16 @@ func Init(username, password, protocol, address, database string) error { if err != nil { return fmt.Errorf("failed to open database connection: %s", err) } + + conns := 1024 + if protocol != "unix" { + conns = 256 + } - handle.SetMaxOpenConns(1000) + handle.SetMaxOpenConns(conns) + handle.SetMaxIdleConns(conns/4) + + handle.SetConnMaxIdleTime(time.Second * 10) tx, err := handle.Begin() if err != nil { @@ -45,7 +55,6 @@ func Init(username, password, protocol, address, database string) error { // accounts tx.Exec("CREATE TABLE IF NOT EXISTS accounts (uuid BINARY(16) NOT NULL PRIMARY KEY, username VARCHAR(16) UNIQUE NOT NULL, hash BINARY(32) NOT NULL, salt BINARY(16) NOT NULL, registered TIMESTAMP NOT NULL, lastLoggedIn TIMESTAMP DEFAULT NULL, lastActivity TIMESTAMP DEFAULT NULL, banned TINYINT(1) NOT NULL DEFAULT 0, trainerId SMALLINT(5) UNSIGNED DEFAULT 0, secretId SMALLINT(5) UNSIGNED DEFAULT 0)") - tx.Exec("CREATE UNIQUE INDEX IF NOT EXISTS accountsByUsername ON accounts (username)") // sessions tx.Exec("CREATE TABLE IF NOT EXISTS sessions (token BINARY(32) NOT NULL PRIMARY KEY, uuid BINARY(16) NOT NULL, active TINYINT(1) NOT NULL DEFAULT 0, expire TIMESTAMP DEFAULT NULL, CONSTRAINT sessions_ibfk_1 FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)") @@ -65,7 +74,7 @@ func Init(username, password, protocol, address, database string) error { tx.Exec("CREATE TABLE IF NOT EXISTS dailyRunCompletions (uuid BINARY(16) NOT NULL, seed CHAR(24) CHARACTER SET ascii COLLATE ascii_bin NOT NULL, mode INT(11) NOT NULL DEFAULT 0, score INT(11) NOT NULL DEFAULT 0, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (uuid, seed), CONSTRAINT dailyRunCompletions_ibfk_1 FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)") tx.Exec("CREATE INDEX IF NOT EXISTS dailyRunCompletionsByUuidAndSeed ON dailyRunCompletions (uuid, seed)") - tx.Exec("CREATE TABLE IF NOT EXISTS accountDailyRuns (uuid BINARY(16) NOT NULL, date DATE NOT NULL, score INT(11) NOT NULL DEFAULT 0, WAVE INT(11) NOT NULL DEFAULT 0, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (uuid, date), CONSTRAINT accountDailyRuns_ibfk_1 FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT accountDailyRuns_ibfk_2 FOREIGN KEY (date) REFERENCES dailyRuns (date) ON DELETE NO ACTION ON UPDATE NO ACTION)") + tx.Exec("CREATE TABLE IF NOT EXISTS accountDailyRuns (uuid BINARY(16) NOT NULL, date DATE NOT NULL, score INT(11) NOT NULL DEFAULT 0, wave INT(11) NOT NULL DEFAULT 0, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (uuid, date), CONSTRAINT accountDailyRuns_ibfk_1 FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT accountDailyRuns_ibfk_2 FOREIGN KEY (date) REFERENCES dailyRuns (date) ON DELETE NO ACTION ON UPDATE NO ACTION)") tx.Exec("CREATE INDEX IF NOT EXISTS accountDailyRunsByDate ON accountDailyRuns (date)") // save data @@ -106,6 +115,12 @@ func Init(username, password, protocol, address, database string) error { continue } + var count int + err = handle.QueryRow("SELECT COUNT(*) FROM systemSaveData WHERE uuid = ?", uuid).Scan(&count) + if err != nil || count != 0 { + continue + } + // store new system data systemData, err := LegacyReadSystemSaveData(uuid) if err != nil { diff --git a/db/savedata.go b/db/savedata.go index 881345f..dc2c860 100644 --- a/db/savedata.go +++ b/db/savedata.go @@ -65,7 +65,7 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { return err } - _, err = handle.Exec("REPLACE INTO systemSaveData (uuid, data, timestamp) VALUES (?, ?, UTC_TIMESTAMP())", uuid, buf.Bytes()) + _, err = handle.Exec("INSERT INTO systemSaveData (uuid, data, timestamp) VALUES (?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE data = ?, timestamp = UTC_TIMESTAMP()", uuid, buf.Bytes(), buf.Bytes()) if err != nil { return err } @@ -116,7 +116,7 @@ func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) erro return err } - _, err = handle.Exec("REPLACE INTO sessionSaveData (uuid, slot, data, timestamp) VALUES (?, ?, ?, UTC_TIMESTAMP())", uuid, slot, buf.Bytes()) + _, err = handle.Exec("INSERT INTO sessionSaveData (uuid, slot, data, timestamp) VALUES (?, ?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE data = ?, timestamp = UTC_TIMESTAMP()", uuid, slot, buf.Bytes(), buf.Bytes()) if err != nil { return err } diff --git a/rogueserver.go b/rogueserver.go index 8686c94..91aa1a2 100644 --- a/rogueserver.go +++ b/rogueserver.go @@ -42,6 +42,9 @@ func main() { dbaddr := flag.String("dbaddr", "localhost", "database address") dbname := flag.String("dbname", "pokeroguedb", "database name") + tlscert := flag.String("tlscert", "", "tls certificate path") + tlskey := flag.String("tlskey", "", "tls key path") + flag.Parse() // register gob types @@ -66,10 +69,15 @@ func main() { api.Init(mux) // start web server + handler := prodHandler(mux) if *debug { - err = http.Serve(listener, debugHandler(mux)) + handler = debugHandler(mux) + } + + if *tlscert == "" { + err = http.Serve(listener, handler) } else { - err = http.Serve(listener, mux) + err = http.ServeTLS(listener, handler, *tlscert, *tlskey) } if err != nil { log.Fatalf("failed to create http server or server errored: %s", err) @@ -93,6 +101,21 @@ func createListener(proto, addr string) (net.Listener, error) { return listener, nil } +func prodHandler(router *http.ServeMux) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST") + w.Header().Set("Access-Control-Allow-Origin", "https://pokerogue.net") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + router.ServeHTTP(w, r) + }) +} + func debugHandler(router *http.ServeMux) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Headers", "*") @@ -107,3 +130,4 @@ func debugHandler(router *http.ServeMux) http.Handler { router.ServeHTTP(w, r) }) } +