From 8e20875453bc3031d7954becb9ec4affb3a59196 Mon Sep 17 00:00:00 2001 From: Frederico Santos Date: Sat, 14 Sep 2024 03:32:49 +0100 Subject: [PATCH] feat: Add admin Discord link endpoint (#49) * feat: Add admin Discord link endpoint * feat: Add Discord Guild ID flag to server configuration * feat: Add logging for Discord ID addition in admin Discord link endpoint * chore: Update variable name for Discord guild ID in account package * chore: Add logging for Discord ID addition in admin Discord link endpoint * chore: Add admin Discord link endpoint * chore: Add logging for Discord ID addition in admin Discord link endpoint * chore: Remove unnecessary code in handleAdminDiscordLink function * chore: Update logging format in handleAdminDiscordLink function * chore: Refactor handleAdminDiscordLink function for improved logging * chore: Update Discord Bot Token and Discord Guild ID flags in server configuration * chore: Refactor handleAccountInfo function for improved readability and error handling * chore: Update server configuration flags for Discord Bot Token and Guild ID * Refactor handleAdminDiscordLink function for improved error handling and logging * feat: Add "Helper" role to Discord admin check for enhanced access control --- api/account/discord.go | 37 +++++++++++++++++++- api/account/info.go | 4 ++- api/common.go | 4 +++ api/endpoints.go | 45 +++++++++++++++++++++++- db/account.go | 28 +++++++++++++++ go.mod | 6 +++- go.sum | 10 ++++++ rogueserver.go | 79 ++++++++++++++++++++++++------------------ 8 files changed, 175 insertions(+), 38 deletions(-) diff --git a/api/account/discord.go b/api/account/discord.go index 792afe6..bf2a00f 100644 --- a/api/account/discord.go +++ b/api/account/discord.go @@ -22,12 +22,17 @@ import ( "errors" "net/http" "net/url" + + "github.com/bwmarrin/discordgo" ) var ( DiscordClientID string DiscordClientSecret string DiscordCallbackURL string + + DiscordSession *discordgo.Session + DiscordGuildID string ) func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) { @@ -36,7 +41,6 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro http.Redirect(w, r, GameURL, http.StatusSeeOther) return "", errors.New("code is empty") } - discordId, err := RetrieveDiscordId(code) if err != nil { http.Redirect(w, r, GameURL, http.StatusSeeOther) @@ -106,3 +110,34 @@ func RetrieveDiscordId(code string) (string, error) { return user.Id, nil } + +func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) { + // fetch all roles from discord + roles, err := DiscordSession.GuildRoles(discordGuildID) + if err != nil { + return false, err + } + + // fetch all roles from user + userRoles, err := DiscordSession.GuildMember(discordGuildID, discordId) + if err != nil { + return false, err + } + + // check if user has a "Dev" or a "Division Heads" role + var hasRole bool + for _, role := range userRoles.Roles { + for _, guildRole := range roles { + if role == guildRole.ID && (guildRole.Name == "Dev" || guildRole.Name == "Division Heads" || guildRole.Name == "Helper") { + hasRole = true + break + } + } + } + + if !hasRole { + return false, nil + } + + return true, nil +} diff --git a/api/account/info.go b/api/account/info.go index 6802238..2d29a78 100644 --- a/api/account/info.go +++ b/api/account/info.go @@ -26,16 +26,18 @@ type InfoResponse struct { DiscordId string `json:"discordId"` GoogleId string `json:"googleId"` LastSessionSlot int `json:"lastSessionSlot"` + HasAdminRole bool `json:"hasAdminRole"` } // /account/info - get account info -func Info(username string, discordId string, googleId string, uuid []byte) (InfoResponse, error) { +func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) { slot, _ := db.GetLatestSessionSaveDataSlot(uuid) response := InfoResponse{ Username: username, LastSessionSlot: slot, DiscordId: discordId, GoogleId: googleId, + HasAdminRole: hasAdminRole, } return response, nil } diff --git a/api/common.go b/api/common.go index 7e3c2aa..6abad39 100644 --- a/api/common.go +++ b/api/common.go @@ -66,6 +66,10 @@ func Init(mux *http.ServeMux) error { // auth mux.HandleFunc("/auth/{provider}/callback", handleProviderCallback) mux.HandleFunc("/auth/{provider}/logout", handleProviderLogout) + + // admin + mux.HandleFunc("POST /admin/account/discord-link", handleAdminDiscordLink) + return nil } diff --git a/api/endpoints.go b/api/endpoints.go index b9802fe..d9fb802 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -23,6 +23,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "strconv" "strings" @@ -68,7 +69,13 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) { return } } - response, err := account.Info(username, discordId, googleId, uuid) + + var hasAdminRole bool + if discordId != "" { + hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID) + } + + response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -660,3 +667,39 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) } + +func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + if err != nil { + httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) + return + } + + uuid, err := uuidFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusUnauthorized) + return + } + + userDiscordId, err := db.FetchDiscordIdByUUID(uuid) + if err != nil { + httpError(w, r, err, http.StatusUnauthorized) + return + } + + hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) + if !hasRole || err != nil { + httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) + return + } + + err = db.AddDiscordIdByUsername(r.Form.Get("discordId"), r.Form.Get("username")) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + + log.Printf("%s: %s added discord id %s to username %s", r.URL.Path, userDiscordId, r.Form.Get("discordId"), r.Form.Get("username")) + + w.WriteHeader(http.StatusOK) +} diff --git a/db/account.go b/db/account.go index d463211..d628df7 100644 --- a/db/account.go +++ b/db/account.go @@ -116,6 +116,34 @@ func FetchGoogleIdByUsername(username string) (string, error) { return googleId.String, nil } +func FetchDiscordIdByUUID(uuid []byte) (string, error) { + var discordId sql.NullString + err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId) + if err != nil { + return "", err + } + + if !discordId.Valid { + return "", nil + } + + return discordId.String, nil +} + +func FetchGoogleIdByUUID(uuid []byte) (string, error) { + var googleId sql.NullString + err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId) + if err != nil { + return "", err + } + + if !googleId.Valid { + return "", nil + } + + return googleId.String, nil +} + func FetchUsernameBySessionToken(token []byte) (string, error) { var username string err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username) diff --git a/go.mod b/go.mod index 04ae233..14f1e75 100644 --- a/go.mod +++ b/go.mod @@ -13,4 +13,8 @@ require ( github.com/klauspost/compress v1.17.9 ) -require golang.org/x/sys v0.19.0 // indirect +require ( + github.com/bwmarrin/discordgo v0.28.1 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + golang.org/x/sys v0.19.0 // indirect +) diff --git a/go.sum b/go.sum index 7e0a6c3..51feab8 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,22 @@ +github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4= +github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/rogueserver.go b/rogueserver.go index 0b007ae..c54c6de 100644 --- a/rogueserver.go +++ b/rogueserver.go @@ -19,66 +19,69 @@ package main import ( "encoding/gob" - "flag" "log" "net" "net/http" "os" + "strconv" + "github.com/bwmarrin/discordgo" "github.com/pagefaultgames/rogueserver/api" "github.com/pagefaultgames/rogueserver/api/account" "github.com/pagefaultgames/rogueserver/db" ) func main() { - // flag stuff - debug := flag.Bool("debug", false, "use debug mode") + // env stuff + debug, _ := strconv.ParseBool(os.Getenv("debug")) - proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)") - addr := flag.String("addr", "0.0.0.0:8001", "network address for api to listen on") - tlscert := flag.String("tlscert", "", "tls certificate path") - tlskey := flag.String("tlskey", "", "tls key path") + proto := getEnv("proto", "tcp") + addr := getEnv("addr", "0.0.0.0:8001") + tlscert := getEnv("tlscert", "") + tlskey := getEnv("tlskey", "") - dbuser := flag.String("dbuser", "pokerogue", "database username") - dbpass := flag.String("dbpass", "pokerogue", "database password") - dbproto := flag.String("dbproto", "tcp", "protocol for database connection") - dbaddr := flag.String("dbaddr", "localhost", "database address") - dbname := flag.String("dbname", "pokeroguedb", "database name") + dbuser := getEnv("dbuser", "pokerogue") + dbpass := getEnv("dbpass", "pokerogue") + dbproto := getEnv("dbproto", "tcp") + dbaddr := getEnv("dbaddr", "localhost") + dbname := getEnv("dbname", "pokeroguedb") - discordclientid := flag.String("discordclientid", "dcid", "Discord Oauth2 Client ID") - discordsecretid := flag.String("discordsecretid", "dsid", "Discord Oauth2 Secret ID") + discordclientid := getEnv("discordclientid", "") + discordsecretid := getEnv("discordsecretid", "") - googleclientid := flag.String("googleclientid", "gcid", "Google Oauth2 Client ID") - googlesecretid := flag.String("googlesecretid", "gsid", "Google Oauth2 Secret ID") + googleclientid := getEnv("googleclientid", "") + googlesecretid := getEnv("googlesecretid", "") - callbackurl := flag.String("callbackurl", "http://localhost:8001/", "Callback URL for Oauth2 Client") + callbackurl := getEnv("callbackurl", "http://localhost:8001/") - gameurl := flag.String("gameurl", "https://pokerogue.net", "URL for game server") + gameurl := getEnv("gameurl", "https://pokerogue.net") - flag.Parse() + discordbottoken := getEnv("discordbottoken", "") + discordguildid := getEnv("discordguildid", "") - account.GameURL = *gameurl + account.GameURL = gameurl - account.DiscordClientID = *discordclientid - account.DiscordClientSecret = *discordsecretid - account.DiscordCallbackURL = *callbackurl + "/auth/discord/callback" - - account.GoogleClientID = *googleclientid - account.GoogleClientSecret = *googlesecretid - account.GoogleCallbackURL = *callbackurl + "/auth/google/callback" + account.DiscordClientID = discordclientid + account.DiscordClientSecret = discordsecretid + account.DiscordCallbackURL = callbackurl + "/auth/discord/callback" + account.GoogleClientID = googleclientid + account.GoogleClientSecret = googlesecretid + account.GoogleCallbackURL = callbackurl + "/auth/google/callback" + account.DiscordSession, _ = discordgo.New("Bot " + discordbottoken) + account.DiscordGuildID = discordguildid // register gob types gob.Register([]interface{}{}) gob.Register(map[string]interface{}{}) // get database connection - err := db.Init(*dbuser, *dbpass, *dbproto, *dbaddr, *dbname) + err := db.Init(dbuser, dbpass, dbproto, dbaddr, dbname) if err != nil { log.Fatalf("failed to initialize database: %s", err) } // create listener - listener, err := createListener(*proto, *addr) + listener, err := createListener(proto, addr) if err != nil { log.Fatalf("failed to create net listener: %s", err) } @@ -92,14 +95,14 @@ func main() { // start web server handler := prodHandler(mux, gameurl) - if *debug { + if debug { handler = debugHandler(mux) } - if *tlscert == "" { + if tlscert == "" { err = http.Serve(listener, handler) } else { - err = http.ServeTLS(listener, handler, *tlscert, *tlskey) + err = http.ServeTLS(listener, handler, tlscert, tlskey) } if err != nil { log.Fatalf("failed to create http server or server errored: %s", err) @@ -126,11 +129,11 @@ func createListener(proto, addr string) (net.Listener, error) { return listener, nil } -func prodHandler(router *http.ServeMux, clienturl *string) http.Handler { +func prodHandler(router *http.ServeMux, clienturl string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST") - w.Header().Set("Access-Control-Allow-Origin", *clienturl) + w.Header().Set("Access-Control-Allow-Origin", clienturl) if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) @@ -155,3 +158,11 @@ func debugHandler(router *http.ServeMux) http.Handler { router.ServeHTTP(w, r) }) } + +func getEnv(key string, defaultValue string) string { + if value, ok := os.LookupEnv(key); ok { + return value + } + + return defaultValue +}