From e74077e7d53e04b9cf0a271df42501527c075104 Mon Sep 17 00:00:00 2001 From: starainrt Date: Wed, 21 Jul 2021 17:25:57 +0800 Subject: [PATCH] orm update add --- orm_v1.go | 132 +++++++++++++++++-- reflect.go | 103 ++++++++++----- reflect_test.go | 22 +++- stardb_v1.go | 344 ++++++++++++++++++++++++++++++------------------ 4 files changed, 424 insertions(+), 177 deletions(-) diff --git a/orm_v1.go b/orm_v1.go index 95642ea..901234d 100644 --- a/orm_v1.go +++ b/orm_v1.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "fmt" "reflect" "strings" ) @@ -133,6 +134,49 @@ func (star *StarDB) ExecXS(ins interface{}, args ...interface{}) ([]sql.Result, func (star *StarDB) ExecX(ins interface{}, args ...interface{}) (sql.Result, error) { return star.execX(nil, ins, args...) } + +func getUpdateSentence(ins interface{}, sheetName string, primaryKey ...string) (string, []string, error) { + Keys, err := getAllRefKey(ins, "db") + if err != nil { + return "", []string{}, err + } + var mystr string + for k, v := range Keys { + mystr += fmt.Sprintf("%s=? ", v) + Keys[k] = ":" + v + } + mystr = fmt.Sprintf("update %s set %s where ", sheetName, mystr) + var whereSlice []string + for _, v := range primaryKey { + whereSlice = append(whereSlice, v+"=?") + Keys = append(Keys, ":"+v) + } + mystr += strings.Join(whereSlice, " and ") + return mystr, Keys, nil +} + +func getInsertSentence(ins interface{}, sheetName string, autoIncrease ...string) (string, []string, error) { + Keys, err := getAllRefKey(ins, "db") + if err != nil { + return "", []string{}, err + } + var mystr, rps string + var rtnKeys []string +cns: + for _, v := range Keys { + for _, vs := range autoIncrease { + if v == vs { + rps += "null," + continue cns + } + } + rtnKeys = append(rtnKeys, ":"+v) + rps += "?," + } + mystr = fmt.Sprintf("insert into %s (%s) values (%s) ", sheetName, strings.Join(Keys, ","), rps[:len(rps)-1]) + return mystr, rtnKeys, nil +} + func (star *StarDB) execX(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { kvMap, err := getAllRefValue(ins, "db") if err != nil { @@ -161,10 +205,44 @@ func (star *StarDB) execX(ctx context.Context, ins interface{}, args ...interfac return star.exec(ctx, args...) } -func (star *StarDB) QueryXContext(ctx context.Context,ins interface{}, args ...interface{}) (*StarRows, error) { +func (star *StarDB) Update(ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + return star.updateinsert(nil, true, ins, sheetName, primaryKey...) +} + +func (star *StarDB) UpdateContext(ctx context.Context, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + return star.updateinsert(ctx, true, ins, sheetName, primaryKey...) +} + +func (star *StarDB) Insert(ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + return star.updateinsert(nil, false, ins, sheetName, primaryKey...) +} + +func (star *StarDB) InsertContext(ctx context.Context, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + return star.updateinsert(ctx, false, ins, sheetName, primaryKey...) +} + +func (star *StarDB) updateinsert(ctx context.Context, isUpdate bool, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + var sqlStr string + var para []string + var err error + if isUpdate { + sqlStr, para, err = getUpdateSentence(ins, sheetName, primaryKey...) + } else { + sqlStr, para, err = getInsertSentence(ins, sheetName, primaryKey...) + } + if err != nil { + return nil, err + } + tmpStr := append([]interface{}{}, sqlStr) + for _, v := range para { + tmpStr = append(tmpStr, v) + } + return star.execX(ctx, ins, tmpStr...) +} +func (star *StarDB) QueryXContext(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { return star.queryX(ctx, ins, args) } -func (star *StarDB) QueryXSContext(ctx context.Context,ins interface{}, args ...interface{}) ([]*StarRows, error) { +func (star *StarDB) QueryXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]*StarRows, error) { var starRes []*StarRows t := reflect.TypeOf(ins) v := reflect.ValueOf(ins) @@ -191,7 +269,7 @@ func (star *StarDB) QueryXSContext(ctx context.Context,ins interface{}, args ... return starRes, nil } -func (star *StarDB) ExecXSContext(ctx context.Context,ins interface{}, args ...interface{}) ([]sql.Result, error) { +func (star *StarDB) ExecXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]sql.Result, error) { var starRes []sql.Result t := reflect.TypeOf(ins) v := reflect.ValueOf(ins) @@ -218,12 +296,10 @@ func (star *StarDB) ExecXSContext(ctx context.Context,ins interface{}, args ...i return starRes, nil } -func (star *StarDB) ExecXContext(ctx context.Context,ins interface{}, args ...interface{}) (sql.Result, error) { +func (star *StarDB) ExecXContext(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { return star.execX(ctx, ins, args...) } - - func (star *StarTx) queryX(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { kvMap, err := getAllRefValue(ins, "db") if err != nil { @@ -254,6 +330,7 @@ func (star *StarTx) queryX(ctx context.Context, ins interface{}, args ...interfa func (star *StarTx) QueryX(ins interface{}, args ...interface{}) (*StarRows, error) { return star.queryX(nil, ins, args) } + func (star *StarTx) QueryXS(ins interface{}, args ...interface{}) ([]*StarRows, error) { var starRes []*StarRows t := reflect.TypeOf(ins) @@ -280,7 +357,40 @@ func (star *StarTx) QueryXS(ins interface{}, args ...interface{}) ([]*StarRows, } return starRes, nil } +func (star *StarTx) Update(ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + return star.updateinsert(nil, true, ins, sheetName, primaryKey...) +} +func (star *StarTx) UpdateContext(ctx context.Context, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + return star.updateinsert(ctx, true, ins, sheetName, primaryKey...) +} + +func (star *StarTx) Insert(ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + return star.updateinsert(nil, false, ins, sheetName, primaryKey...) +} + +func (star *StarTx) InsertContext(ctx context.Context, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + return star.updateinsert(ctx, false, ins, sheetName, primaryKey...) +} + +func (star *StarTx) updateinsert(ctx context.Context, isUpdate bool, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { + var sqlStr string + var para []string + var err error + if isUpdate { + sqlStr, para, err = getUpdateSentence(ins, sheetName, primaryKey...) + } else { + sqlStr, para, err = getInsertSentence(ins, sheetName, primaryKey...) + } + if err != nil { + return nil, err + } + tmpStr := append([]interface{}{}, sqlStr) + for _, v := range para { + tmpStr = append(tmpStr, v) + } + return star.execX(ctx, ins, tmpStr...) +} func (star *StarTx) ExecXS(ins interface{}, args ...interface{}) ([]sql.Result, error) { var starRes []sql.Result t := reflect.TypeOf(ins) @@ -339,10 +449,10 @@ func (star *StarTx) execX(ctx context.Context, ins interface{}, args ...interfac return star.exec(ctx, args...) } -func (star *StarTx) QueryXContext(ctx context.Context,ins interface{}, args ...interface{}) (*StarRows, error) { +func (star *StarTx) QueryXContext(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { return star.queryX(ctx, ins, args) } -func (star *StarTx) QueryXSContext(ctx context.Context,ins interface{}, args ...interface{}) ([]*StarRows, error) { +func (star *StarTx) QueryXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]*StarRows, error) { var starRes []*StarRows t := reflect.TypeOf(ins) v := reflect.ValueOf(ins) @@ -369,7 +479,7 @@ func (star *StarTx) QueryXSContext(ctx context.Context,ins interface{}, args ... return starRes, nil } -func (star *StarTx) ExecXSContext(ctx context.Context,ins interface{}, args ...interface{}) ([]sql.Result, error) { +func (star *StarTx) ExecXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]sql.Result, error) { var starRes []sql.Result t := reflect.TypeOf(ins) v := reflect.ValueOf(ins) @@ -396,6 +506,6 @@ func (star *StarTx) ExecXSContext(ctx context.Context,ins interface{}, args ...i return starRes, nil } -func (star *StarTx) ExecXContext(ctx context.Context,ins interface{}, args ...interface{}) (sql.Result, error) { +func (star *StarTx) ExecXContext(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { return star.execX(ctx, ins, args...) -} \ No newline at end of file +} diff --git a/reflect.go b/reflect.go index e5593a7..3c8aeba 100644 --- a/reflect.go +++ b/reflect.go @@ -3,6 +3,7 @@ package stardb import ( "errors" "reflect" + "time" ) func (star *StarRows) setAllRefValue(stc interface{}, skey string, rows int) error { @@ -24,20 +25,30 @@ func (star *StarRows) setAllRefValue(stc interface{}, skey string, rows int) err for i := 0; i < t.NumField(); i++ { tp := t.Field(i) srFrd := v.Field(i) + seg := tp.Tag.Get(skey) if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { - sp := reflect.New(reflect.TypeOf(srFrd.Interface()).Elem()).Interface() - star.setAllRefValue(sp, skey, rows) - v.Field(i).Set(reflect.ValueOf(sp)) - continue + if seg == "" { + continue + } + if seg == "---" { + sp := reflect.New(reflect.TypeOf(srFrd.Interface()).Elem()).Interface() + star.setAllRefValue(sp, skey, rows) + v.Field(i).Set(reflect.ValueOf(sp)) + continue + } } if srFrd.Kind() == reflect.Struct { - sp := reflect.New(reflect.TypeOf(v.Field(i).Interface())).Interface() - star.setAllRefValue(sp, skey, rows) - v.Field(i).Set(reflect.ValueOf(sp).Elem()) - continue + if seg == "" { + continue + } + if seg == "---" { + sp := reflect.New(reflect.TypeOf(v.Field(i).Interface())).Interface() + star.setAllRefValue(sp, skey, rows) + v.Field(i).Set(reflect.ValueOf(sp).Elem()) + continue + } } - seg := tp.Tag.Get(skey) if seg == "" { continue } @@ -71,6 +82,12 @@ func (star *StarRows) setAllRefValue(stc interface{}, skey string, rows int) err v.Field(i).SetFloat(star.Row(rows).MustFloat64(seg)) case reflect.Float32: v.Field(i).SetFloat(float64(star.Row(rows).MustFloat32(seg))) + case reflect.Interface, reflect.Struct, reflect.Ptr: + inf := star.Row(rows).Result[star.columnref[seg]] + switch vtype := inf.(type) { + case time.Time: + v.Field(i).Set(reflect.ValueOf(vtype)) + } default: } @@ -103,37 +120,47 @@ func getAllRefValue(stc interface{}, skey string) (map[string]interface{}, error result := make(map[string]interface{}) t := reflect.TypeOf(stc) v := reflect.ValueOf(stc) - if v.Kind() != reflect.Struct { - return nil, errors.New("interface{} is not a struct") - } if t.Kind() == reflect.Ptr { + if v.IsNil() { + return nil, errors.New("ptr interface{} is nil") + } t = t.Elem() v = v.Elem() } + if v.Kind() != reflect.Struct { + return nil, errors.New("interface{} is not a struct") + } for i := 0; i < t.NumField(); i++ { tp := t.Field(i) srFrd := v.Field(i) + seg := tp.Tag.Get(skey) if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { - res, err := getAllRefValue(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey) - if err != nil { - return result, err + if srFrd.IsNil() { + continue } - for k, v := range res { - result[k] = v + if seg == "---" { + res, err := getAllRefValue(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey) + if err != nil { + return result, err + } + for k, v := range res { + result[k] = v + } + continue } - continue } if v.Field(i).Kind() == reflect.Struct { res, err := getAllRefValue(v.Field(i).Interface(), skey) - if err != nil { - return result, err - } - for k, v := range res { - result[k] = v + if seg == "---" { + if err != nil { + return result, err + } + for k, v := range res { + result[k] = v + } + continue } - continue } - seg := tp.Tag.Get(skey) if seg == "" { continue } @@ -155,22 +182,32 @@ func getAllRefKey(stc interface{}, skey string) ([]string, error) { t := reflect.TypeOf(stc) v := reflect.ValueOf(stc) if t.Kind() == reflect.Ptr { + if v.IsNil() { + return []string{}, errors.New("ptr interface{} is nil") + } t = t.Elem() v = v.Elem() } for i := 0; i < t.NumField(); i++ { srFrd := v.Field(i) + profile := t.Field(i) + seg := profile.Tag.Get(skey) if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { - res, err := getAllRefKey(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey) - if err != nil { - return result, err + if srFrd.IsNil() { + continue } - for _, v := range res { - result = append(result, v) + if seg == "---" { + res, err := getAllRefKey(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey) + if err != nil { + return result, err + } + for _, v := range res { + result = append(result, v) + } + continue } - continue } - if v.Field(i).Kind() == reflect.Struct { + if v.Field(i).Kind() == reflect.Struct && seg == "---" { res, err := getAllRefKey(v.Field(i).Interface(), skey) if err != nil { return result, err @@ -179,8 +216,6 @@ func getAllRefKey(stc interface{}, skey string) ([]string, error) { result = append(result, v) } } - profile := t.Field(i) - seg := profile.Tag.Get(skey) if seg != "" { result = append(result, seg) } diff --git a/reflect_test.go b/reflect_test.go index 99a58c9..887d66b 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -6,15 +6,26 @@ import ( ) type Useless struct { - Leader string `db:"leader"` - Usable bool `db:"use"` - O *Whoami + Leader string `db:"leader"` + Usable bool `db:"use"` + O *Whoami `db:"---"` } type Whoami struct { Hehe string `db:"hehe"` } +func TestUpInOrm(t *testing.T) { + var hehe = Useless{ + Leader: "no", + Usable: false, + } + sqlstr, param, err := getUpdateSentence(hehe, "ryz", "leader") + fmt.Println(sqlstr, param, err) + sqlstr, param, err = getInsertSentence(hehe, "ryz", "use") + fmt.Println(sqlstr, param, err) +} + func Test_SetRefVal(t *testing.T) { var hehe = Useless{ Leader: "no", @@ -24,6 +35,7 @@ func Test_SetRefVal(t *testing.T) { fmt.Printf("%+v\n", hehe) fmt.Printf("%+v\n", hehe) fmt.Println(getAllRefKey(hehe, "db")) + fmt.Println(getAllRefValue(hehe, "db")) } func Test_Ref(t *testing.T) { @@ -31,5 +43,7 @@ func Test_Ref(t *testing.T) { Leader: "Heheeee", } oooooo.O = &Whoami{"fuck"} - fmt.Println(getAllRefKey(oooooo,"db")) + fmt.Println(getAllRefKey(oooooo, "db")) + fmt.Println(getAllRefValue(oooooo, "db")) + fmt.Println(getAllRefValue(&oooooo, "db")) } diff --git a/stardb_v1.go b/stardb_v1.go index ba71320..0da8483 100644 --- a/stardb_v1.go +++ b/stardb_v1.go @@ -16,9 +16,8 @@ type StarDB struct { } type StarTx struct { - Db *sql.DB - Tx *sql.Tx - ManualScan bool + Db *StarDB + Tx *sql.Tx } // StarRows 为查询结果集(按行) @@ -33,6 +32,11 @@ type StarRows struct { parsed bool } +type StarDBStmt struct { + Stmt *sql.Stmt + Db *StarDB +} + // StarResult 为查询结果集(总) type StarResult struct { Result []interface{} @@ -848,7 +852,7 @@ func (star *StarDB) Begin() (*StarTx, error) { return nil, err } stx := new(StarTx) - stx.Db = star.Db + stx.Db = star stx.Tx = tx return stx, err } @@ -859,7 +863,7 @@ func (star *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, return nil, err } stx := new(StarTx) - stx.Db = star.Db + stx.Db = star stx.Tx = tx return stx, err } @@ -871,40 +875,64 @@ func (star *StarTx) Query(args ...interface{}) (*StarRows, error) { func (star *StarTx) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) { return star.query(ctx, args...) } +func (star *StarTx) ExecStmt(args ...interface{}) (sql.Result, error) { + if len(args) <= 1 { + return nil, errors.New("parameter not enough") + } + stmt, err := star.Prepare(args[0].(string)) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Exec(args[1:]...) +} + +func (star *StarTx) ExecStmtContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + if len(args) <= 1 { + return nil, errors.New("parameter not enough") + } + stmt, err := star.PrepareContext(ctx, args[0].(string)) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.ExecContext(ctx, args[1:]...) +} + +func (star *StarTx) QueryStmt(args ...interface{}) (*StarRows, error) { + if len(args) <= 1 { + return nil, errors.New("parameter not enough") + } + stmt, err := star.Prepare(args[0].(string)) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Query(args[1:]...) +} + +func (star *StarTx) QueryStmtContext(ctx context.Context, args ...interface{}) (*StarRows, error) { + if len(args) <= 1 { + return nil, errors.New("parameter not enough") + } + stmt, err := star.PrepareContext(ctx, args[0].(string)) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.QueryContext(ctx, args[1:]...) +} func (star *StarTx) query(ctx context.Context, args ...interface{}) (*StarRows, error) { var err error var rows *sql.Rows - var stmt *sql.Stmt effect := new(StarRows) - if len(args) == 0 { - return effect, errors.New("no args") - } - if len(args) == 1 { - sqlStr := args[0] - if ctx == nil { - if rows, err = star.Tx.Query(sqlStr.(string)); err != nil { - return effect, err - } - } else { - if rows, err = star.Tx.QueryContext(ctx, sqlStr.(string)); err != nil { - return effect, err - } - } - effect.Rows = rows - err = effect.parserows() + if err = star.Db.Ping(); err != nil { return effect, err } - sqlStr := args[0] - if ctx == nil { - stmt, err = star.Tx.Prepare(sqlStr.(string)) - } else { - stmt, err = star.Tx.PrepareContext(ctx, sqlStr.(string)) - } - if err != nil { - return effect, err + if len(args) == 0 { + return effect, errors.New("no args") } - defer stmt.Close() var para []interface{} for k, v := range args { if k != 0 { @@ -915,16 +943,16 @@ func (star *StarTx) query(ctx context.Context, args ...interface{}) (*StarRows, } } if ctx == nil { - if rows, err = stmt.Query(para...); err != nil { + if rows, err = star.Tx.Query(args[0].(string), para...); err != nil { return effect, err } } else { - if rows, err = stmt.QueryContext(ctx, para...); err != nil { + if rows, err = star.Tx.QueryContext(ctx, args[0].(string), para...); err != nil { return effect, err } } effect.Rows = rows - if !star.ManualScan { + if !star.Db.ManualScan { err = effect.parserows() } return effect, err @@ -938,43 +966,162 @@ func (star *StarDB) QueryContext(ctx context.Context, args ...interface{}) (*Sta return star.query(ctx, args...) } -// Query 进行Query操作 -func (star *StarDB) query(ctx context.Context, args ...interface{}) (*StarRows, error) { +func (star *StarDB) QueryStmt(args ...interface{}) (*StarRows, error) { + if len(args) <= 1 { + return nil, errors.New("parameter not enough") + } + stmt, err := star.Prepare(args[0].(string)) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Query(args[1:]...) +} + +func (star *StarDB) QueryStmtContext(ctx context.Context, args ...interface{}) (*StarRows, error) { + if len(args) <= 1 { + return nil, errors.New("parameter not enough") + } + stmt, err := star.PrepareContext(ctx, args[0].(string)) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.QueryContext(ctx, args[1:]...) +} + +func (star *StarDB) ExecStmt(args ...interface{}) (sql.Result, error) { + if len(args) <= 1 { + return nil, errors.New("parameter not enough") + } + stmt, err := star.Prepare(args[0].(string)) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Exec(args[1:]...) +} + +func (star *StarDB) ExecStmtContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + if len(args) <= 1 { + return nil, errors.New("parameter not enough") + } + stmt, err := star.PrepareContext(ctx, args[0].(string)) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.ExecContext(ctx, args[1:]...) +} + +func (star *StarDBStmt) Query(args ...interface{}) (*StarRows, error) { + return star.query(nil, args...) +} + +func (star *StarDBStmt) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) { + return star.query(ctx, args...) +} + +func (star *StarDBStmt) Exec(args ...interface{}) (sql.Result, error) { + return star.exec(nil, args...) +} + +func (star *StarDBStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + return star.exec(ctx, args...) +} + +func (star *StarDBStmt) Close() error { + return star.Stmt.Close() +} + +func (star *StarDBStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) { var err error var rows *sql.Rows - var stmt *sql.Stmt effect := new(StarRows) - if err = star.Db.Ping(); err != nil { - return effect, err - } if len(args) == 0 { return effect, errors.New("no args") } - if len(args) == 1 { - sqlStr := args[0] - if ctx == nil { - if rows, err = star.Db.Query(sqlStr.(string)); err != nil { - return effect, err - } - } else { - if rows, err = star.Db.Query(sqlStr.(string)); err != nil { - return effect, err - } + if ctx == nil { + if rows, err = star.Stmt.Query(args...); err != nil { + return effect, err + } + } else { + if rows, err = star.Stmt.QueryContext(ctx, args...); err != nil { + return effect, err } - effect.Rows = rows + } + effect.Rows = rows + if !star.Db.ManualScan { err = effect.parserows() - return effect, err } - sqlStr := args[0] + return effect, err +} + +func (star *StarDBStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { + if len(args) == 0 { + return nil, errors.New("no args") + } if ctx == nil { - stmt, err = star.Db.Prepare(sqlStr.(string)) - } else { - stmt, err = star.Db.PrepareContext(ctx, sqlStr.(string)) + return star.Stmt.Exec(args...) + } + return star.Stmt.ExecContext(ctx, args...) +} + +func (star *StarDB) Prepare(sqlStr string) (*StarDBStmt, error) { + stmt := new(StarDBStmt) + stmtS, err := star.Db.Prepare(sqlStr) + if err != nil { + return nil, err + } + stmt.Stmt = stmtS + stmt.Db = star + return stmt, err +} + +func (star *StarDB) PrepareContext(ctx context.Context, sqlStr string) (*StarDBStmt, error) { + stmt := new(StarDBStmt) + stmtS, err := star.Db.PrepareContext(ctx, sqlStr) + if err != nil { + return nil, err } + stmt.Stmt = stmtS + stmt.Db = star + return stmt, err +} + +func (star *StarTx) Prepare(sqlStr string) (*StarDBStmt, error) { + stmt := new(StarDBStmt) + stmtS, err := star.Tx.Prepare(sqlStr) if err != nil { + return nil, err + } + stmt.Stmt = stmtS + stmt.Db = star.Db + return stmt, err +} + +func (star *StarTx) PrepareContext(ctx context.Context, sqlStr string) (*StarDBStmt, error) { + stmt := new(StarDBStmt) + stmtS, err := star.Tx.PrepareContext(ctx, sqlStr) + if err != nil { + return nil, err + } + stmt.Db = star.Db + stmt.Stmt = stmtS + return stmt, err +} + +// Query 进行Query操作 +func (star *StarDB) query(ctx context.Context, args ...interface{}) (*StarRows, error) { + var err error + var rows *sql.Rows + effect := new(StarRows) + if err = star.Db.Ping(); err != nil { return effect, err } - defer stmt.Close() + if len(args) == 0 { + return effect, errors.New("no args") + } var para []interface{} for k, v := range args { if k != 0 { @@ -985,11 +1132,11 @@ func (star *StarDB) query(ctx context.Context, args ...interface{}) (*StarRows, } } if ctx == nil { - if rows, err = stmt.Query(para...); err != nil { + if rows, err = star.Db.Query(args[0].(string), para...); err != nil { return effect, err } } else { - if rows, err = stmt.QueryContext(ctx, para...); err != nil { + if rows, err = star.Db.QueryContext(ctx, args[0].(string), para...); err != nil { return effect, err } } @@ -1046,37 +1193,12 @@ func (star *StarDB) ExecContext(ctx context.Context, args ...interface{}) (sql.R // Exec 执行Exec操作 func (star *StarDB) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { var err error - var effect sql.Result if err = star.Db.Ping(); err != nil { - return effect, err + return nil, err } if len(args) == 0 { - return effect, errors.New("no args") + return nil, errors.New("no args") } - if len(args) == 1 { - sqlStr := args[0] - if ctx == nil { - if effect, err = star.Db.Exec(sqlStr.(string)); err != nil { - return effect, err - } - } else { - if effect, err = star.Db.ExecContext(ctx, sqlStr.(string)); err != nil { - return effect, err - } - } - return effect, nil - } - sqlStr := args[0] - var stmt *sql.Stmt - if ctx == nil { - stmt, err = star.Db.Prepare(sqlStr.(string)) - } else { - stmt, err = star.Db.PrepareContext(ctx, sqlStr.(string)) - } - if err != nil { - return effect, err - } - defer stmt.Close() var para []interface{} for k, v := range args { if k != 0 { @@ -1087,15 +1209,9 @@ func (star *StarDB) exec(ctx context.Context, args ...interface{}) (sql.Result, } } if ctx == nil { - if effect, err = stmt.Exec(para...); err != nil { - return effect, err - } - } else { - if effect, err = stmt.ExecContext(ctx, para...); err != nil { - return effect, err - } + return star.Db.Exec(args[0].(string), para...) } - return effect, nil + return star.Db.ExecContext(ctx, args[0].(string), para...) } func (star *StarTx) Exec(args ...interface{}) (sql.Result, error) { @@ -1107,34 +1223,12 @@ func (star *StarTx) ExecContext(ctx context.Context, args ...interface{}) (sql.R func (star *StarTx) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { var err error - var effect sql.Result - var stmt *sql.Stmt - if len(args) == 0 { - return effect, errors.New("no args") - } - if len(args) == 1 { - sqlStr := args[0] - if ctx == nil { - if _, err = star.Tx.Exec(sqlStr.(string)); err != nil { - return effect, err - } - } else { - if _, err = star.Tx.ExecContext(ctx, sqlStr.(string)); err != nil { - return effect, err - } - } - return effect, nil - } - sqlStr := args[0] - if ctx == nil { - stmt, err = star.Tx.Prepare(sqlStr.(string)) - } else { - stmt, err = star.Tx.PrepareContext(ctx, sqlStr.(string)) + if err = star.Db.Ping(); err != nil { + return nil, err } - if err != nil { - return effect, err + if len(args) == 0 { + return nil, errors.New("no args") } - defer stmt.Close() var para []interface{} for k, v := range args { if k != 0 { @@ -1145,15 +1239,9 @@ func (star *StarTx) exec(ctx context.Context, args ...interface{}) (sql.Result, } } if ctx == nil { - if effect, err = stmt.Exec(para...); err != nil { - return effect, err - } - } else { - if effect, err = stmt.ExecContext(ctx, para...); err != nil { - return effect, err - } + return star.Tx.Exec(args[0].(string), para...) } - return effect, nil + return star.Tx.ExecContext(ctx, args[0].(string), para...) } func (star *StarTx) Commit() error {