diff --git a/api/account/common.go b/api/account/common.go index 9102f4e..8e8c0ff 100644 --- a/api/account/common.go +++ b/api/account/common.go @@ -19,6 +19,7 @@ package account import ( "regexp" + "runtime" "golang.org/x/crypto/argon2" ) @@ -34,13 +35,13 @@ const ( ArgonKeySize = 32 ArgonSaltSize = 16 - ArgonMaxInstances = 16 - UUIDSize = 16 TokenSize = 32 ) var ( + ArgonMaxInstances = runtime.NumCPU() + isValidUsername = regexp.MustCompile(`^\w{1,16}$`).MatchString semaphore = make(chan bool, ArgonMaxInstances) ) diff --git a/api/common.go b/api/common.go index 62ef843..920d41e 100644 --- a/api/common.go +++ b/api/common.go @@ -40,7 +40,6 @@ func Init(mux *http.ServeMux) { mux.HandleFunc("GET /account/logout", handleAccountLogout) // game - mux.HandleFunc("GET /game/playercount", handleGamePlayerCount) mux.HandleFunc("GET /game/titlestats", handleGameTitleStats) mux.HandleFunc("GET /game/classicsessioncount", handleGameClassicSessionCount) diff --git a/api/endpoints.go b/api/endpoints.go index 7da61e0..c23f693 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -18,7 +18,8 @@ package api import ( - "encoding/base64" + "encoding/base64" + "database/sql" "encoding/json" "fmt" "net/http" @@ -145,10 +146,6 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) { // game -func handleGamePlayerCount(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(strconv.Itoa(playerCount))) -} - func handleGameTitleStats(w http.ResponseWriter, r *http.Request) { err := json.NewEncoder(w).Encode(defs.TitleStats{ PlayerCount: playerCount, @@ -285,6 +282,10 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/savedata/get": save, err = savedata.Get(uuid, datatype, slot) + if err == sql.ErrNoRows { + http.Error(w, err.Error(), http.StatusNotFound) + return + } case "/savedata/update": err = savedata.Update(uuid, slot, save) case "/savedata/delete": diff --git a/db/db.go b/db/db.go index c59f5ef..f246267 100644 --- a/db/db.go +++ b/db/db.go @@ -21,9 +21,11 @@ import ( "database/sql" "encoding/hex" "fmt" - _ "github.com/go-sql-driver/mysql" "log" "os" + "time" + + _ "github.com/go-sql-driver/mysql" ) var handle *sql.DB @@ -35,8 +37,16 @@ func Init(username, password, protocol, address, database string) error { if err != nil { return fmt.Errorf("failed to open database connection: %s", err) } + + conns := 1024 + if protocol != "unix" { + conns = 256 + } - handle.SetMaxOpenConns(1000) + handle.SetMaxOpenConns(conns) + handle.SetMaxIdleConns(conns/4) + + handle.SetConnMaxIdleTime(time.Second * 10) tx, err := handle.Begin() if err != nil { @@ -105,6 +115,12 @@ func Init(username, password, protocol, address, database string) error { continue } + var count int + err = handle.QueryRow("SELECT COUNT(*) FROM systemSaveData WHERE uuid = ?", uuid).Scan(&count) + if err != nil || count != 0 { + continue + } + // store new system data systemData, err := LegacyReadSystemSaveData(uuid) if err != nil { diff --git a/db/savedata.go b/db/savedata.go index 881345f..dc2c860 100644 --- a/db/savedata.go +++ b/db/savedata.go @@ -65,7 +65,7 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { return err } - _, err = handle.Exec("REPLACE INTO systemSaveData (uuid, data, timestamp) VALUES (?, ?, UTC_TIMESTAMP())", uuid, buf.Bytes()) + _, err = handle.Exec("INSERT INTO systemSaveData (uuid, data, timestamp) VALUES (?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE data = ?, timestamp = UTC_TIMESTAMP()", uuid, buf.Bytes(), buf.Bytes()) if err != nil { return err } @@ -116,7 +116,7 @@ func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) erro return err } - _, err = handle.Exec("REPLACE INTO sessionSaveData (uuid, slot, data, timestamp) VALUES (?, ?, ?, UTC_TIMESTAMP())", uuid, slot, buf.Bytes()) + _, err = handle.Exec("INSERT INTO sessionSaveData (uuid, slot, data, timestamp) VALUES (?, ?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE data = ?, timestamp = UTC_TIMESTAMP()", uuid, slot, buf.Bytes(), buf.Bytes()) if err != nil { return err } diff --git a/rogueserver.go b/rogueserver.go index 8686c94..91aa1a2 100644 --- a/rogueserver.go +++ b/rogueserver.go @@ -42,6 +42,9 @@ func main() { dbaddr := flag.String("dbaddr", "localhost", "database address") dbname := flag.String("dbname", "pokeroguedb", "database name") + tlscert := flag.String("tlscert", "", "tls certificate path") + tlskey := flag.String("tlskey", "", "tls key path") + flag.Parse() // register gob types @@ -66,10 +69,15 @@ func main() { api.Init(mux) // start web server + handler := prodHandler(mux) if *debug { - err = http.Serve(listener, debugHandler(mux)) + handler = debugHandler(mux) + } + + if *tlscert == "" { + err = http.Serve(listener, handler) } else { - err = http.Serve(listener, mux) + err = http.ServeTLS(listener, handler, *tlscert, *tlskey) } if err != nil { log.Fatalf("failed to create http server or server errored: %s", err) @@ -93,6 +101,21 @@ func createListener(proto, addr string) (net.Listener, error) { return listener, nil } +func prodHandler(router *http.ServeMux) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST") + w.Header().Set("Access-Control-Allow-Origin", "https://pokerogue.net") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + router.ServeHTTP(w, r) + }) +} + func debugHandler(router *http.ServeMux) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Headers", "*") @@ -107,3 +130,4 @@ func debugHandler(router *http.ServeMux) http.Handler { router.ServeHTTP(w, r) }) } +