diff --git a/api/daily/common.go b/api/daily/common.go index fa2604a..9c1df94 100644 --- a/api/daily/common.go +++ b/api/daily/common.go @@ -18,21 +18,15 @@ package daily import ( - "bytes" - "context" "crypto/md5" "crypto/rand" "encoding/base64" "encoding/binary" - "encoding/json" "fmt" "log" "os" "time" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/pagefaultgames/rogueserver/db" "github.com/robfig/cron/v3" ) @@ -90,17 +84,6 @@ func Init() error { scheduler.Start() - if os.Getenv("AWS_ENDPOINT_URL_S3") != "" { - go func() { - for { - err = S3SaveMigration() - if err != nil { - return - } - } - }() - } - return nil } @@ -116,61 +99,3 @@ func deriveSeed(seedTime time.Time) []byte { return hashedSeed[:] } - -func S3SaveMigration() error { - cfg, _ := config.LoadDefaultConfig(context.TODO()) - - svc := s3.NewFromConfig(cfg, func(o *s3.Options) { - o.BaseEndpoint = aws.String(os.Getenv("AWS_ENDPOINT_URL_S3")) - }) - - _, err := svc.CreateBucket(context.Background(), &s3.CreateBucketInput{ - Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")), - }) - if err != nil { - log.Printf("error while creating bucket (already exists?): %s", err) - } - - // retrieve accounts from db - accounts, err := db.GetLocalSystemAccounts() - if err != nil { - return fmt.Errorf("failed to retrieve old accounts: %s", err) - } - - for _, user := range accounts { - data, err := db.ReadSystemSaveData(user) - if err != nil { - continue - } - - username, err := db.FetchUsernameFromUUID(user) - if err != nil { - continue - } - - json, err := json.Marshal(data) - if err != nil { - continue - } - - _, err = svc.PutObject(context.Background(), &s3.PutObjectInput{ - Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")), - Key: aws.String(username), - Body: bytes.NewReader(json), - }) - if err != nil { - log.Printf("error while saving data in S3 for user %s: %s", username, err) - continue - } - - err = db.DeleteSystemSaveData(user) - if err != nil { - log.Printf("failed to delete old save for user %s: %s", username, err) - continue - } - - log.Printf("saved data in S3 for user %s", username) - } - - return nil -} diff --git a/api/savedata/system.go b/api/savedata/system.go index 6c02f2a..cf07bf5 100644 --- a/api/savedata/system.go +++ b/api/savedata/system.go @@ -19,13 +19,21 @@ package savedata import ( "fmt" + "os" "github.com/pagefaultgames/rogueserver/db" "github.com/pagefaultgames/rogueserver/defs" ) func GetSystem(uuid []byte) (defs.SystemSaveData, error) { - system, err := db.ReadSystemSaveData(uuid) + var system defs.SystemSaveData + var err error + + if os.Getenv("AWS_ENDPOINT_URL_S3") != "" { // use S3 + system, err = db.GetSystemSaveFromS3(uuid) + } else { // use database + system, err = db.ReadSystemSaveData(uuid) + } if err != nil { return system, err } @@ -43,7 +51,16 @@ func UpdateSystem(uuid []byte, data defs.SystemSaveData) error { return fmt.Errorf("failed to update account stats: %s", err) } - return db.StoreSystemSaveData(uuid, data) + if os.Getenv("AWS_ENDPOINT_URL_S3") != "" { // use S3 + err = db.StoreSystemSaveDataS3(uuid, data) + } else { + err = db.StoreSystemSaveData(uuid, data) + } + if err != nil { + return err + } + + return nil } func DeleteSystem(uuid []byte) error { diff --git a/db/savedata.go b/db/savedata.go index 3765bc2..f0cda99 100644 --- a/db/savedata.go +++ b/db/savedata.go @@ -60,53 +60,53 @@ func ReadSeedCompleted(uuid []byte, seed string) (bool, error) { } func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) { - // get and return save from S3 - system, err := GetSystemSaveFromS3(uuid) - if err == nil { - return system, nil - } + var system defs.SystemSaveData - // otherwise look in database and try to move it var data []byte - err = handle.QueryRow("SELECT data FROM systemSaveData WHERE uuid = ?", uuid).Scan(&data) + err := handle.QueryRow("SELECT data FROM systemSaveData WHERE uuid = ?", uuid).Scan(&data) if err != nil { return system, err } - dec, err := zstd.NewReader(nil) + zr, err := zstd.NewReader(bytes.NewReader(data)) if err != nil { return system, err } - defer dec.Close() + defer zr.Close() - decompressed, err := dec.DecodeAll(data, nil) - if err == nil { - // replace if it worked, otherwise use the original data - data = decompressed + err = gob.NewDecoder(zr).Decode(&system) + if err != nil { + return system, err } - err = gob.NewDecoder(bytes.NewReader(data)).Decode(&system) + return system, nil +} + +func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { + buf := new(bytes.Buffer) + + zw, err := zstd.NewWriter(buf) if err != nil { - return system, err + return err } - // put it in S3 - err = StoreSystemSaveData(uuid, system) + defer zw.Close() + + err = gob.NewEncoder(zw).Encode(data) if err != nil { - return system, err + return err } - // delete the one in db - err = DeleteSystemSaveData(uuid) + _, err = handle.Exec("REPLACE INTO systemSaveData (uuid, data, timestamp) VALUES (?, ?, UTC_TIMESTAMP())", uuid, buf.Bytes()) if err != nil { - return system, err + return err } - return system, nil + return nil } -func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { +func StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error { cfg, _ := config.LoadDefaultConfig(context.TODO()) client := s3.NewFromConfig(cfg, func(o *s3.Options) { @@ -118,7 +118,9 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { return err } - json, err := json.Marshal(data) + buf := new(bytes.Buffer) + + err = json.NewEncoder(buf).Encode(data) if err != nil { return err } @@ -126,7 +128,7 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { _, err = client.PutObject(context.Background(), &s3.PutObjectInput{ Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")), Key: aws.String(username), - Body: bytes.NewReader(json), + Body: buf, }) if err != nil { return err @@ -153,20 +155,14 @@ func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) { return session, err } - dec, err := zstd.NewReader(nil) + zr, err := zstd.NewReader(bytes.NewReader(data)) if err != nil { return session, err } - defer dec.Close() + defer zr.Close() - decompressed, err := dec.DecodeAll(data, nil) - if err == nil { - // replace if it worked, otherwise use the original data - data = decompressed - } - - err = gob.NewDecoder(bytes.NewReader(data)).Decode(&session) + err = gob.NewDecoder(zr).Decode(&session) if err != nil { return session, err } @@ -185,20 +181,21 @@ func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) { } func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error { - var buf bytes.Buffer - err := gob.NewEncoder(&buf).Encode(data) + buf := new(bytes.Buffer) + + zw, err := zstd.NewWriter(buf) if err != nil { return err } - enc, err := zstd.NewWriter(nil) + defer zw.Close() + + err = gob.NewEncoder(zw).Encode(data) if err != nil { return err } - defer enc.Close() - - _, err = handle.Exec("REPLACE INTO sessionSaveData (uuid, slot, data, timestamp) VALUES (?, ?, ?, UTC_TIMESTAMP())", uuid, slot, enc.EncodeAll(buf.Bytes(), nil)) + _, err = handle.Exec("REPLACE INTO sessionSaveData (uuid, slot, data, timestamp) VALUES (?, ?, ?, UTC_TIMESTAMP())", uuid, slot, buf.Bytes()) if err != nil { return err } @@ -250,13 +247,12 @@ func GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) { return system, err } - var session defs.SystemSaveData - err = json.NewDecoder(resp.Body).Decode(&session) + err = json.NewDecoder(resp.Body).Decode(&system) if err != nil { return system, err } - return session, nil + return system, nil } func GetLocalSystemAccounts() ([][]byte, error) {