From 46b1279852e119ce4e570901193307b94240dffe Mon Sep 17 00:00:00 2001 From: Frederico Santos Date: Wed, 21 Aug 2024 22:58:54 +0100 Subject: [PATCH] feat: Add admin Discord link endpoint --- api/account/discord.go | 37 ++++++++++++++++++++++++++++++++++++- api/account/info.go | 4 +++- api/common.go | 3 +++ api/endpoints.go | 40 +++++++++++++++++++++++++++++++++++++++- db/account.go | 28 ++++++++++++++++++++++++++++ go.mod | 6 +++++- go.sum | 10 ++++++++++ rogueserver.go | 4 ++++ 8 files changed, 128 insertions(+), 4 deletions(-) diff --git a/api/account/discord.go b/api/account/discord.go index 792afe6..29affb3 100644 --- a/api/account/discord.go +++ b/api/account/discord.go @@ -22,12 +22,17 @@ import ( "errors" "net/http" "net/url" + + "github.com/bwmarrin/discordgo" ) var ( DiscordClientID string DiscordClientSecret string DiscordCallbackURL string + + DiscordSession *discordgo.Session + DiscordGuildId string ) func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) { @@ -36,7 +41,6 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro http.Redirect(w, r, GameURL, http.StatusSeeOther) return "", errors.New("code is empty") } - discordId, err := RetrieveDiscordId(code) if err != nil { http.Redirect(w, r, GameURL, http.StatusSeeOther) @@ -106,3 +110,34 @@ func RetrieveDiscordId(code string) (string, error) { return user.Id, nil } + +func IsUserDiscordAdmin(discordId string, discordGuildId string) (bool, error) { + // fetch all roles from discord + roles, err := DiscordSession.GuildRoles(discordGuildId) + if err != nil { + return false, err + } + + // fetch all roles from user + userRoles, err := DiscordSession.GuildMember(discordGuildId, discordId) + if err != nil { + return false, err + } + + // check if user has a "Dev" or a "Division Heads" role + var hasRole bool + for _, role := range userRoles.Roles { + for _, guildRole := range roles { + if role == guildRole.ID && (guildRole.Name == "Dev" || guildRole.Name == "Division Heads") { + hasRole = true + break + } + } + } + + if !hasRole { + return false, nil + } + + return true, nil +} diff --git a/api/account/info.go b/api/account/info.go index 6802238..2d29a78 100644 --- a/api/account/info.go +++ b/api/account/info.go @@ -26,16 +26,18 @@ type InfoResponse struct { DiscordId string `json:"discordId"` GoogleId string `json:"googleId"` LastSessionSlot int `json:"lastSessionSlot"` + HasAdminRole bool `json:"hasAdminRole"` } // /account/info - get account info -func Info(username string, discordId string, googleId string, uuid []byte) (InfoResponse, error) { +func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) { slot, _ := db.GetLatestSessionSaveDataSlot(uuid) response := InfoResponse{ Username: username, LastSessionSlot: slot, DiscordId: discordId, GoogleId: googleId, + HasAdminRole: hasAdminRole, } return response, nil } diff --git a/api/common.go b/api/common.go index 7e3c2aa..d9b26ae 100644 --- a/api/common.go +++ b/api/common.go @@ -66,6 +66,9 @@ func Init(mux *http.ServeMux) error { // auth mux.HandleFunc("/auth/{provider}/callback", handleProviderCallback) mux.HandleFunc("/auth/{provider}/logout", handleProviderLogout) + + // admin + mux.HandleFunc("POST /admin/account/discord-link", handleAdminDiscordLink) return nil } diff --git a/api/endpoints.go b/api/endpoints.go index b9802fe..7d7924a 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -68,7 +68,10 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) { return } } - response, err := account.Info(username, discordId, googleId, uuid) + + hasAdminRole, _ := account.IsUserDiscordAdmin(discordId, account.DiscordGuildId) + + response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -660,3 +663,38 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) } + +func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + if err != nil { + httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) + return + } + + uuid, err := uuidFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusUnauthorized) + return + } + + userDiscordId, err := db.FetchDiscordIdByUUID(uuid) + if err != nil { + httpError(w, r, err, http.StatusUnauthorized) + return + } + + hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildId) + + if !hasRole || err != nil { + httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) + return + } + + err = db.AddDiscordIdByUsername(r.Form.Get("discordId"), r.Form.Get("username")) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/db/account.go b/db/account.go index d463211..d628df7 100644 --- a/db/account.go +++ b/db/account.go @@ -116,6 +116,34 @@ func FetchGoogleIdByUsername(username string) (string, error) { return googleId.String, nil } +func FetchDiscordIdByUUID(uuid []byte) (string, error) { + var discordId sql.NullString + err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId) + if err != nil { + return "", err + } + + if !discordId.Valid { + return "", nil + } + + return discordId.String, nil +} + +func FetchGoogleIdByUUID(uuid []byte) (string, error) { + var googleId sql.NullString + err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId) + if err != nil { + return "", err + } + + if !googleId.Valid { + return "", nil + } + + return googleId.String, nil +} + func FetchUsernameBySessionToken(token []byte) (string, error) { var username string err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username) diff --git a/go.mod b/go.mod index 04ae233..14f1e75 100644 --- a/go.mod +++ b/go.mod @@ -13,4 +13,8 @@ require ( github.com/klauspost/compress v1.17.9 ) -require golang.org/x/sys v0.19.0 // indirect +require ( + github.com/bwmarrin/discordgo v0.28.1 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + golang.org/x/sys v0.19.0 // indirect +) diff --git a/go.sum b/go.sum index 7e0a6c3..51feab8 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,22 @@ +github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4= +github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/rogueserver.go b/rogueserver.go index 0b007ae..2706813 100644 --- a/rogueserver.go +++ b/rogueserver.go @@ -25,6 +25,7 @@ import ( "net/http" "os" + "github.com/bwmarrin/discordgo" "github.com/pagefaultgames/rogueserver/api" "github.com/pagefaultgames/rogueserver/api/account" "github.com/pagefaultgames/rogueserver/db" @@ -55,6 +56,8 @@ func main() { gameurl := flag.String("gameurl", "https://pokerogue.net", "URL for game server") + discordbottoken := flag.String("discordbottoken", "dbt", "Discord Bot Token") + flag.Parse() account.GameURL = *gameurl @@ -66,6 +69,7 @@ func main() { account.GoogleClientID = *googleclientid account.GoogleClientSecret = *googlesecretid account.GoogleCallbackURL = *callbackurl + "/auth/google/callback" + account.DiscordSession, _ = discordgo.New("Bot " + *discordbottoken) // register gob types gob.Register([]interface{}{})