From 9c6374eedf4f9d3e278b7395ced15b3aa02e3304 Mon Sep 17 00:00:00 2001 From: Krystian Chmura Date: Wed, 15 May 2024 20:57:58 +0200 Subject: [PATCH] fix errors in /account --- api/account/changepw.go | 4 +-- api/account/common.go | 22 ++++++++++++++ api/account/common_test.go | 61 ++++++++++++++++++++++++++++++++++++++ api/account/login.go | 14 ++++----- api/account/register.go | 15 ++++++---- api/common.go | 30 ++++++++++++++----- api/common_test.go | 28 +++++++++++++++++ api/endpoints.go | 34 ++++++++++----------- db/account.go | 9 +++++- db/db.go | 2 -- db/errors.go | 8 +++++ errors/errors.go | 14 +++++++++ go.mod | 8 ++++- go.sum | 10 +++++++ 14 files changed, 215 insertions(+), 44 deletions(-) create mode 100644 api/account/common_test.go create mode 100644 api/common_test.go create mode 100644 db/errors.go create mode 100644 errors/errors.go diff --git a/api/account/changepw.go b/api/account/changepw.go index 2e79971..1f56e54 100644 --- a/api/account/changepw.go +++ b/api/account/changepw.go @@ -25,8 +25,8 @@ import ( ) func ChangePW(uuid []byte, password string) error { - if len(password) < 6 { - return fmt.Errorf("invalid password") + if err := validatePassword(password); err != nil { + return err } salt := make([]byte, ArgonSaltSize) diff --git a/api/account/common.go b/api/account/common.go index 8e8c0ff..13c8cfa 100644 --- a/api/account/common.go +++ b/api/account/common.go @@ -18,9 +18,12 @@ package account import ( + "cmp" + "net/http" "regexp" "runtime" + "github.com/pagefaultgames/rogueserver/errors" "golang.org/x/crypto/argon2" ) @@ -52,3 +55,22 @@ func deriveArgon2IDKey(password, salt []byte) []byte { return argon2.IDKey(password, salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize) } + +func validateUsernamePassword(username string, password string) error { + return cmp.Or(validateUsername(username), validatePassword(password)) +} + +func validateUsername(username string) error { + if !isValidUsername(username) { + return errors.NewHttpError(http.StatusBadRequest, "invalid username") + } + return nil +} + +func validatePassword(password string) error { + if len(password) < 6 { + return errors.NewHttpError(http.StatusBadRequest, "invalid password") + } + + return nil +} diff --git a/api/account/common_test.go b/api/account/common_test.go new file mode 100644 index 0000000..205dafc --- /dev/null +++ b/api/account/common_test.go @@ -0,0 +1,61 @@ +package account + +import ( + "net/http" + "testing" + + "github.com/pagefaultgames/rogueserver/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateUsernamePassword(t *testing.T) { + t.Run("valid username and password", func(t *testing.T) { + err := validateUsernamePassword("validUser", "validPass") + assert.NoError(t, err) + }) + + t.Run("invalid username", func(t *testing.T) { + err := validateUsernamePassword("", "validPass") + require.NotNil(t, err) + assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username")) + }) + + t.Run("invalid password", func(t *testing.T) { + err := validateUsernamePassword("validUser", "123") + require.NotNil(t, err) + assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid password")) + }) + + t.Run("invalid username and password", func(t *testing.T) { + err := validateUsernamePassword("", "123") + require.NotNil(t, err) + assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username")) + }) +} + +func TestValidateUsername(t *testing.T) { + t.Run("valid username", func(t *testing.T) { + err := validateUsername("validUser") + assert.NoError(t, err) + }) + + t.Run("invalid username", func(t *testing.T) { + err := validateUsername("") + require.NotNil(t, err) + assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username")) + }) +} + +func TestValidatePassword(t *testing.T) { + t.Run("valid password", func(t *testing.T) { + err := validatePassword("validPass") + assert.NoError(t, err) + }) + + t.Run("invalid password", func(t *testing.T) { + err := validatePassword("123") + require.NotNil(t, err) + assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid password")) + }) +} diff --git a/api/account/login.go b/api/account/login.go index bfd11be..0730e36 100644 --- a/api/account/login.go +++ b/api/account/login.go @@ -23,8 +23,10 @@ import ( "database/sql" "encoding/base64" "fmt" + "net/http" "github.com/pagefaultgames/rogueserver/db" + "github.com/pagefaultgames/rogueserver/errors" ) type LoginResponse GenericAuthResponse @@ -33,25 +35,21 @@ type LoginResponse GenericAuthResponse func Login(username, password string) (LoginResponse, error) { var response LoginResponse - if !isValidUsername(username) { - return response, fmt.Errorf("invalid username") - } - - if len(password) < 6 { - return response, fmt.Errorf("invalid password") + if err := validateUsernamePassword(username, password); err != nil { + return response, err } key, salt, err := db.FetchAccountKeySaltFromUsername(username) if err != nil { if err == sql.ErrNoRows { - return response, fmt.Errorf("account doesn't exist") + return response, errors.NewHttpError(http.StatusNotFound, "account doesn't exist") } return response, err } if !bytes.Equal(key, deriveArgon2IDKey([]byte(password), salt)) { - return response, fmt.Errorf("password doesn't match") + return response, errors.NewHttpError(http.StatusUnauthorized, "password doesn't match") } token := make([]byte, TokenSize) diff --git a/api/account/register.go b/api/account/register.go index f2ed611..1fe1917 100644 --- a/api/account/register.go +++ b/api/account/register.go @@ -19,18 +19,18 @@ package account import ( "crypto/rand" + stderrors "errors" "fmt" + "net/http" + "github.com/pagefaultgames/rogueserver/db" + "github.com/pagefaultgames/rogueserver/errors" ) // /account/register - register account func Register(username, password string) error { - if !isValidUsername(username) { - return fmt.Errorf("invalid username") - } - - if len(password) < 6 { - return fmt.Errorf("invalid password") + if err := validateUsernamePassword(username, password); err != nil { + return err } uuid := make([]byte, UUIDSize) @@ -47,6 +47,9 @@ func Register(username, password string) error { err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt) if err != nil { + if stderrors.Is(err, db.ErrAccountAlreadyExists) { + return errors.NewHttpError(http.StatusConflict, fmt.Sprintf(`username "%s" already taken`, username)) + } return fmt.Errorf("failed to add account record: %s", err) } diff --git a/api/common.go b/api/common.go index bdd5aab..f06fd72 100644 --- a/api/common.go +++ b/api/common.go @@ -20,12 +20,15 @@ package api import ( "encoding/base64" "encoding/json" + stderrors "errors" "fmt" + "log" + "net/http" + "github.com/pagefaultgames/rogueserver/api/account" "github.com/pagefaultgames/rogueserver/api/daily" "github.com/pagefaultgames/rogueserver/db" - "log" - "net/http" + "github.com/pagefaultgames/rogueserver/errors" ) func Init(mux *http.ServeMux) error { @@ -69,16 +72,16 @@ func Init(mux *http.ServeMux) error { func tokenFromRequest(r *http.Request) ([]byte, error) { if r.Header.Get("Authorization") == "" { - return nil, fmt.Errorf("missing token") + return nil, errors.NewHttpError(http.StatusBadRequest, "missing token") } token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) if err != nil { - return nil, fmt.Errorf("failed to decode token: %s", err) + return nil, errors.NewHttpError(http.StatusBadRequest, "failed to decode token") } if len(token) != account.TokenSize { - return nil, fmt.Errorf("invalid token length: got %d, expected %d", len(token), account.TokenSize) + return nil, errors.NewHttpError(http.StatusBadRequest, "invalid token length") } return token, nil @@ -97,14 +100,17 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) { uuid, err := db.FetchUUIDFromToken(token) if err != nil { - return nil, nil, fmt.Errorf("failed to validate token: %s", err) + if stderrors.Is(err, db.ErrTokenNotFound) { + return nil, nil, errors.NewHttpError(http.StatusUnauthorized, "bad token") + } + return nil, nil, fmt.Errorf("failed to fetch uuid from db: %w", err) } return token, uuid, nil } func httpError(w http.ResponseWriter, r *http.Request, err error, code int) { - log.Printf("%s: %s\n", r.URL.Path, err) + log.Printf("%s: %s\n", r.URL.Path, err.Error()) http.Error(w, err.Error(), code) } @@ -116,3 +122,13 @@ func jsonResponse(w http.ResponseWriter, r *http.Request, data any) { return } } + +func statusCodeFromError(err error) int { + var httpErr *errors.HttpError + + if stderrors.As(err, &httpErr) { + return httpErr.Code + } + + return http.StatusInternalServerError +} diff --git a/api/common_test.go b/api/common_test.go new file mode 100644 index 0000000..299aaa2 --- /dev/null +++ b/api/common_test.go @@ -0,0 +1,28 @@ +package api + +import ( + stderrors "errors" + "net/http" + "testing" + + "github.com/pagefaultgames/rogueserver/errors" + "github.com/stretchr/testify/assert" +) + +func TestStatusCodeFromError(t *testing.T) { + t.Run("nil", func(t *testing.T) { + code := statusCodeFromError(nil) + assert.Equal(t, http.StatusInternalServerError, code) + }) + t.Run("http error", func(t *testing.T) { + err := errors.NewHttpError(http.StatusTeapot, "teapot") + code := statusCodeFromError(err) + assert.Equal(t, http.StatusTeapot, code) + }) + + t.Run("standard error", func(t *testing.T) { + err := stderrors.New("standard error") + code := statusCodeFromError(err) + assert.Equal(t, http.StatusInternalServerError, code) + }) +} diff --git a/api/endpoints.go b/api/endpoints.go index 47c5847..7b60adb 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -43,7 +43,7 @@ import ( func handleAccountInfo(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -71,11 +71,11 @@ func handleAccountRegister(w http.ResponseWriter, r *http.Request) { err = account.Register(r.Form.Get("username"), r.Form.Get("password")) if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + httpError(w, r, err, statusCodeFromError(err)) return } - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusCreated) } func handleAccountLogin(w http.ResponseWriter, r *http.Request) { @@ -87,7 +87,7 @@ func handleAccountLogin(w http.ResponseWriter, r *http.Request) { response, err := account.Login(r.Form.Get("username"), r.Form.Get("password")) if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -103,17 +103,17 @@ func handleAccountChangePW(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } err = account.ChangePW(uuid, r.Form.Get("password")) if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + httpError(w, r, err, statusCodeFromError(err)) return } - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusNoContent) } func handleAccountLogout(w http.ResponseWriter, r *http.Request) { @@ -129,7 +129,7 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) { return } - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusNoContent) } // game @@ -149,7 +149,7 @@ func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) { func handleGetSessionData(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -195,7 +195,7 @@ const legacyClientSessionId = "LEGACY_CLIENT" func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -244,7 +244,7 @@ func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) { func clearSessionData(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -332,7 +332,7 @@ func clearSessionData(w http.ResponseWriter, r *http.Request) { func deleteSystemSave(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -412,7 +412,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) { func legacyHandleSaveData(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -581,7 +581,7 @@ type CombinedSaveData struct { func handleUpdateAll(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -653,7 +653,7 @@ type SystemVerifyRequest struct { func handleSystemVerify(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -705,7 +705,7 @@ func handleSystemVerify(w http.ResponseWriter, r *http.Request) { func handleGetSystemData(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } @@ -741,7 +741,7 @@ func handleGetSystemData(w http.ResponseWriter, r *http.Request) { func legacyHandleNewClear(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, err, statusCodeFromError(err)) return } diff --git a/db/account.go b/db/account.go index 23a1a6d..04569e0 100644 --- a/db/account.go +++ b/db/account.go @@ -23,13 +23,17 @@ import ( "fmt" "slices" - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" "github.com/pagefaultgames/rogueserver/defs" ) func AddAccountRecord(uuid []byte, username string, key, salt []byte) error { _, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt) if err != nil { + var mysqlErr *mysql.MySQLError + if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 { + return ErrAccountAlreadyExists + } return err } @@ -240,6 +244,9 @@ func FetchUUIDFromToken(token []byte) ([]byte, error) { var uuid []byte err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrTokenNotFound + } return nil, err } diff --git a/db/db.go b/db/db.go index 998af39..42f3bb5 100644 --- a/db/db.go +++ b/db/db.go @@ -24,8 +24,6 @@ import ( "log" "os" "time" - - _ "github.com/go-sql-driver/mysql" ) var handle *sql.DB diff --git a/db/errors.go b/db/errors.go new file mode 100644 index 0000000..bf65871 --- /dev/null +++ b/db/errors.go @@ -0,0 +1,8 @@ +package db + +import "errors" + +var ( + ErrAccountAlreadyExists = errors.New("account already exists") + ErrTokenNotFound = errors.New("token not found") +) diff --git a/errors/errors.go b/errors/errors.go new file mode 100644 index 0000000..bc73af3 --- /dev/null +++ b/errors/errors.go @@ -0,0 +1,14 @@ +package errors + +type HttpError struct { + Code int + Message string +} + +func NewHttpError(code int, message string) *HttpError { + return &HttpError{Code: code, Message: message} +} + +func (h HttpError) Error() string { + return h.Message +} diff --git a/go.mod b/go.mod index ed86e1f..72fae90 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,13 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/klauspost/compress v1.17.4 github.com/robfig/cron/v3 v3.0.1 + github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.16.0 ) -require golang.org/x/sys v0.15.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.15.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 88cd836..0352ce2 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,20 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=