diff --git a/orm_v1.go b/orm_v1.go index ea09c06..95642ea 100644 --- a/orm_v1.go +++ b/orm_v1.go @@ -1,6 +1,7 @@ package stardb import ( + "context" "database/sql" "errors" "reflect" @@ -22,14 +23,15 @@ func (star *StarRows) Orm(ins interface{}) error { //now convert to slice t = t.Elem() v = v.Elem() - if star.Length == 0 { - return nil - } if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { //get type of slice sigType := t.Elem() var result reflect.Value result = reflect.New(t).Elem() + if star.Length == 0 { + v.Set(result) + return nil + } for i := 0; i < star.Length; i++ { val := reflect.New(sigType) star.setAllRefValue(val.Interface(), "db", i) @@ -38,16 +40,21 @@ func (star *StarRows) Orm(ins interface{}) error { v.Set(result) return nil } - + if star.Length == 0 { + return nil + } return star.setAllRefValue(ins, "db", 0) } -func (star *StarDB) QueryX(sql string, ins interface{}, args ...interface{}) (*StarRows, error) { +func (star *StarDB) queryX(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { kvMap, err := getAllRefValue(ins, "db") if err != nil { return nil, err } for k, v := range args { + if k == 0 { + continue + } switch v.(type) { case string: str := v.(string) @@ -64,15 +71,77 @@ func (star *StarDB) QueryX(sql string, ins interface{}, args ...interface{}) (*S } } } - return star.Query(sql, args) + return star.query(ctx, args...) +} +func (star *StarDB) QueryX(ins interface{}, args ...interface{}) (*StarRows, error) { + return star.queryX(nil, ins, args) +} +func (star *StarDB) QueryXS(ins interface{}, args ...interface{}) ([]*StarRows, error) { + var starRes []*StarRows + t := reflect.TypeOf(ins) + v := reflect.ValueOf(ins) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + //now convert to slice + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + result, err := star.queryX(nil, v.Index(i).Interface(), args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + } else { + result, err := star.queryX(nil, ins, args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + return starRes, nil } -func (star *StarDB) ExecX(sql string, ins interface{}, args ...interface{}) (sql.Result, error) { +func (star *StarDB) ExecXS(ins interface{}, args ...interface{}) ([]sql.Result, error) { + var starRes []sql.Result + t := reflect.TypeOf(ins) + v := reflect.ValueOf(ins) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + //now convert to slice + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + result, err := star.execX(nil, v.Index(i).Interface(), args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + } else { + result, err := star.execX(nil, ins, args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + return starRes, nil +} + +func (star *StarDB) ExecX(ins interface{}, args ...interface{}) (sql.Result, error) { + return star.execX(nil, ins, args...) +} +func (star *StarDB) execX(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { kvMap, err := getAllRefValue(ins, "db") if err != nil { return nil, err } for k, v := range args { + if k == 0 { + continue + } switch v.(type) { case string: str := v.(string) @@ -89,5 +158,244 @@ func (star *StarDB) ExecX(sql string, ins interface{}, args ...interface{}) (sql } } } - return star.Exec(sql, args) + return star.exec(ctx, args...) +} + +func (star *StarDB) QueryXContext(ctx context.Context,ins interface{}, args ...interface{}) (*StarRows, error) { + return star.queryX(ctx, ins, args) +} +func (star *StarDB) QueryXSContext(ctx context.Context,ins interface{}, args ...interface{}) ([]*StarRows, error) { + var starRes []*StarRows + t := reflect.TypeOf(ins) + v := reflect.ValueOf(ins) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + //now convert to slice + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + result, err := star.queryX(ctx, v.Index(i).Interface(), args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + } else { + result, err := star.queryX(ctx, ins, args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + return starRes, nil +} + +func (star *StarDB) ExecXSContext(ctx context.Context,ins interface{}, args ...interface{}) ([]sql.Result, error) { + var starRes []sql.Result + t := reflect.TypeOf(ins) + v := reflect.ValueOf(ins) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + //now convert to slice + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + result, err := star.execX(ctx, v.Index(i).Interface(), args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + } else { + result, err := star.execX(ctx, ins, args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + return starRes, nil } + +func (star *StarDB) ExecXContext(ctx context.Context,ins interface{}, args ...interface{}) (sql.Result, error) { + return star.execX(ctx, ins, args...) +} + + + +func (star *StarTx) queryX(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { + kvMap, err := getAllRefValue(ins, "db") + if err != nil { + return nil, err + } + for k, v := range args { + if k == 0 { + continue + } + switch v.(type) { + case string: + str := v.(string) + if strings.Index(str, ":") == 0 { + if _, ok := kvMap[str[1:]]; ok { + args[k] = kvMap[str[1:]] + } else { + args[k] = "" + } + continue + } + if strings.Index(str, `\:`) == 0 { + args[k] = kvMap[str[1:]] + } + } + } + return star.query(ctx, args...) +} +func (star *StarTx) QueryX(ins interface{}, args ...interface{}) (*StarRows, error) { + return star.queryX(nil, ins, args) +} +func (star *StarTx) QueryXS(ins interface{}, args ...interface{}) ([]*StarRows, error) { + var starRes []*StarRows + t := reflect.TypeOf(ins) + v := reflect.ValueOf(ins) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + //now convert to slice + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + result, err := star.queryX(nil, v.Index(i).Interface(), args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + } else { + result, err := star.queryX(nil, ins, args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + return starRes, nil +} + +func (star *StarTx) ExecXS(ins interface{}, args ...interface{}) ([]sql.Result, error) { + var starRes []sql.Result + t := reflect.TypeOf(ins) + v := reflect.ValueOf(ins) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + //now convert to slice + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + result, err := star.execX(nil, v.Index(i).Interface(), args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + } else { + result, err := star.execX(nil, ins, args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + return starRes, nil +} + +func (star *StarTx) ExecX(ins interface{}, args ...interface{}) (sql.Result, error) { + return star.execX(nil, ins, args...) +} +func (star *StarTx) execX(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { + kvMap, err := getAllRefValue(ins, "db") + if err != nil { + return nil, err + } + for k, v := range args { + if k == 0 { + continue + } + switch v.(type) { + case string: + str := v.(string) + if strings.Index(str, ":") == 0 { + if _, ok := kvMap[str[1:]]; ok { + args[k] = kvMap[str[1:]] + } else { + args[k] = "" + } + continue + } + if strings.Index(str, `\:`) == 0 { + args[k] = kvMap[str[1:]] + } + } + } + return star.exec(ctx, args...) +} + +func (star *StarTx) QueryXContext(ctx context.Context,ins interface{}, args ...interface{}) (*StarRows, error) { + return star.queryX(ctx, ins, args) +} +func (star *StarTx) QueryXSContext(ctx context.Context,ins interface{}, args ...interface{}) ([]*StarRows, error) { + var starRes []*StarRows + t := reflect.TypeOf(ins) + v := reflect.ValueOf(ins) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + //now convert to slice + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + result, err := star.queryX(ctx, v.Index(i).Interface(), args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + } else { + result, err := star.queryX(ctx, ins, args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + return starRes, nil +} + +func (star *StarTx) ExecXSContext(ctx context.Context,ins interface{}, args ...interface{}) ([]sql.Result, error) { + var starRes []sql.Result + t := reflect.TypeOf(ins) + v := reflect.ValueOf(ins) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + //now convert to slice + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + result, err := star.execX(ctx, v.Index(i).Interface(), args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + } else { + result, err := star.execX(ctx, ins, args...) + if err != nil { + return starRes, err + } + starRes = append(starRes, result) + } + return starRes, nil +} + +func (star *StarTx) ExecXContext(ctx context.Context,ins interface{}, args ...interface{}) (sql.Result, error) { + return star.execX(ctx, ins, args...) +} \ No newline at end of file diff --git a/reflect.go b/reflect.go index 021abbd..e5593a7 100644 --- a/reflect.go +++ b/reflect.go @@ -7,16 +7,36 @@ import ( func (star *StarRows) setAllRefValue(stc interface{}, skey string, rows int) error { t := reflect.TypeOf(stc) - v := reflect.ValueOf(stc).Elem() - if t.Kind() != reflect.Ptr || !v.CanSet() { + v := reflect.ValueOf(stc) + if t.Kind() == reflect.Ptr { + v = v.Elem() + } + if t.Kind() != reflect.Ptr && !v.CanSet() { return errors.New("interface{} is not writable") } + if t.Kind() == reflect.Ptr { + t = t.Elem() + } if v.Kind() != reflect.Struct { return errors.New("interface{} is not a struct") } - t = t.Elem() + for i := 0; i < t.NumField(); i++ { tp := t.Field(i) + srFrd := v.Field(i) + + if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { + sp := reflect.New(reflect.TypeOf(srFrd.Interface()).Elem()).Interface() + star.setAllRefValue(sp, skey, rows) + v.Field(i).Set(reflect.ValueOf(sp)) + continue + } + if srFrd.Kind() == reflect.Struct { + sp := reflect.New(reflect.TypeOf(v.Field(i).Interface())).Interface() + star.setAllRefValue(sp, skey, rows) + v.Field(i).Set(reflect.ValueOf(sp).Elem()) + continue + } seg := tp.Tag.Get(skey) if seg == "" { continue @@ -25,7 +45,7 @@ func (star *StarRows) setAllRefValue(stc interface{}, skey string, rows int) err continue } myInt64 := star.Row(rows).MustInt64(seg) - myUint64 := uint64(star.Row(rows).MustInt64(seg)) + myUint64 := star.Row(rows).MustUint64(seg) switch v.Field(i).Kind() { case reflect.String: v.Field(i).SetString(star.Row(rows).MustString(seg)) @@ -58,30 +78,6 @@ func (star *StarRows) setAllRefValue(stc interface{}, skey string, rows int) err return nil } -func setAllRefValue(stc interface{}, skey string, kv map[string]interface{}) error { - t := reflect.TypeOf(stc) - v := reflect.ValueOf(stc).Elem() - if t.Kind() != reflect.Ptr || !v.CanSet() { - return errors.New("interface{} is not writable") - } - if v.Kind() != reflect.Struct { - return errors.New("interface{} is not a struct") - } - t = t.Elem() - for i := 0; i < t.NumField(); i++ { - tp := t.Field(i) - seg := tp.Tag.Get(skey) - if seg == "" { - continue - } - if _, ok := kv[seg]; !ok { - continue - } - v.Field(i).Set(reflect.ValueOf(kv[seg])) - } - return nil -} - func setRefValue(stc interface{}, skey, key string, value interface{}) error { t := reflect.TypeOf(stc) v := reflect.ValueOf(stc).Elem() @@ -116,6 +112,27 @@ func getAllRefValue(stc interface{}, skey string) (map[string]interface{}, error } for i := 0; i < t.NumField(); i++ { tp := t.Field(i) + srFrd := v.Field(i) + if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { + res, err := getAllRefValue(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey) + if err != nil { + return result, err + } + for k, v := range res { + result[k] = v + } + continue + } + if v.Field(i).Kind() == reflect.Struct { + res, err := getAllRefValue(v.Field(i).Interface(), skey) + if err != nil { + return result, err + } + for k, v := range res { + result[k] = v + } + continue + } seg := tp.Tag.Get(skey) if seg == "" { continue @@ -136,10 +153,32 @@ func getAllRefKey(stc interface{}, skey string) ([]string, error) { return []string{}, errors.New("interface{} is not a struct") } t := reflect.TypeOf(stc) + v := reflect.ValueOf(stc) if t.Kind() == reflect.Ptr { t = t.Elem() + v = v.Elem() } for i := 0; i < t.NumField(); i++ { + srFrd := v.Field(i) + if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { + res, err := getAllRefKey(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey) + if err != nil { + return result, err + } + for _, v := range res { + result = append(result, v) + } + continue + } + if v.Field(i).Kind() == reflect.Struct { + res, err := getAllRefKey(v.Field(i).Interface(), skey) + if err != nil { + return result, err + } + for _, v := range res { + result = append(result, v) + } + } profile := t.Field(i) seg := profile.Tag.Get(skey) if seg != "" { @@ -152,7 +191,7 @@ func getAllRefKey(stc interface{}, skey string) ([]string, error) { func isWritableStruct(stc interface{}) (isWritable bool, isStruct bool) { t := reflect.TypeOf(stc) v := reflect.ValueOf(stc) - if t.Kind() == reflect.Ptr && v.Elem().CanSet() { + if t.Kind() == reflect.Ptr || v.CanSet() { isWritable = true } if v.Kind() == reflect.Struct { diff --git a/reflect_test.go b/reflect_test.go index 9445aad..99a58c9 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -2,43 +2,34 @@ package stardb import ( "fmt" - "reflect" "testing" ) type Useless struct { Leader string `db:"leader"` Usable bool `db:"use"` + O *Whoami +} + +type Whoami struct { + Hehe string `db:"hehe"` } func Test_SetRefVal(t *testing.T) { var hehe = Useless{ Leader: "no", } - mmval := map[string]interface{}{ - "leader": "hehe", - "use": true, - } fmt.Printf("%+v\n", hehe) fmt.Println(setRefValue(&hehe, "db", "leader", "sb")) fmt.Printf("%+v\n", hehe) - fmt.Println(setAllRefValue(&hehe, "db", mmval)) fmt.Printf("%+v\n", hehe) fmt.Println(getAllRefKey(hehe, "db")) } func Test_Ref(t *testing.T) { - var me []Useless - p := reflect.TypeOf(&me).Elem() - v := reflect.ValueOf(&me).Elem() - mmval := map[string]interface{}{ - "leader": "hehe", - "use": true, + oooooo := Useless{ + Leader: "Heheeee", } - newVal := reflect.New(p) - val := reflect.New(p.Elem()) - setAllRefValue(val.Interface(), "db", mmval) - mynum:= reflect.Append(newVal.Elem(), val.Elem()) - v.Set(mynum) - fmt.Println(val.Interface(), me, v) + oooooo.O = &Whoami{"fuck"} + fmt.Println(getAllRefKey(oooooo,"db")) } diff --git a/stardb_v1.go b/stardb_v1.go index 4db6e5f..ba71320 100644 --- a/stardb_v1.go +++ b/stardb_v1.go @@ -1,6 +1,7 @@ package stardb import ( + "context" "database/sql" "errors" "reflect" @@ -10,7 +11,13 @@ import ( // StarDB 一个简单封装的DB库 type StarDB struct { - DB *sql.DB + Db *sql.DB + ManualScan bool +} + +type StarTx struct { + Db *sql.DB + Tx *sql.Tx ManualScan bool } @@ -88,6 +95,12 @@ func (star *StarResultCol) MustBool() []bool { } else { tmp = false } + case uint64: + if vtype > 0 { + tmp = true + } else { + tmp = false + } case string: tmp, _ = strconv.ParseBool(vtype) default: @@ -119,6 +132,8 @@ func (star *StarResultCol) MustFloat32() []float32 { tmp = float32(vtype) case int64: tmp = float32(vtype) + case uint64: + tmp = float32(vtype) case time.Time: tmp = float32(vtype.Unix()) default: @@ -151,6 +166,8 @@ func (star *StarResultCol) MustFloat64() []float64 { tmp = float64(vtype) case int64: tmp = float64(vtype) + case uint64: + tmp = float64(vtype) case time.Time: tmp = float64(vtype.Unix()) default: @@ -185,6 +202,8 @@ func (star *StarResultCol) MustString() []string { tmp = strconv.FormatFloat(float64(vtype), 'f', 10, 32) case int: tmp = strconv.Itoa(vtype) + case uint64: + tmp = strconv.FormatUint(vtype, 10) case time.Time: tmp = vtype.String() default: @@ -214,6 +233,8 @@ func (star *StarResultCol) MustInt32() []int32 { tmp = int32(vtype) case int64: tmp = int32(vtype) + case uint64: + tmp = int32(vtype) case int32: tmp = vtype case time.Time: @@ -247,6 +268,8 @@ func (star *StarResultCol) MustInt64() []int64 { tmp = int64(vtype) case int32: tmp = int64(vtype) + case uint64: + tmp = int64(vtype) case int64: tmp = vtype case time.Time: @@ -260,6 +283,39 @@ func (star *StarResultCol) MustInt64() []int64 { return res } +// MustUint64 列查询结果转Int64 +func (star *StarResultCol) MustUint64() []uint64 { + var res []uint64 + var tmp uint64 + for _, v := range star.Result { + switch vtype := v.(type) { + case nil: + tmp = 0 + case float64: + tmp = uint64(vtype) + case float32: + tmp = uint64(vtype) + case string: + tmp, _ = strconv.ParseUint(vtype, 10, 64) + case int: + tmp = uint64(vtype) + case int32: + tmp = uint64(vtype) + case int64: + tmp = uint64(vtype) + case uint64: + tmp = vtype + case time.Time: + tmp = uint64(vtype.Unix()) + default: + tmpt := string(vtype.([]byte)) + tmp, _ = strconv.ParseUint(tmpt, 10, 64) + } + res = append(res, tmp) + } + return res +} + // MustInt 列查询结果转Int func (star *StarResultCol) MustInt() []int { var res []int @@ -281,6 +337,8 @@ func (star *StarResultCol) MustInt() []int { tmp = int(vtype) case int64: tmp = int(vtype) + case uint64: + tmp = int(vtype) case time.Time: tmp = int(vtype.Unix()) default: @@ -313,6 +371,8 @@ func (star *StarResultCol) MustDate(layout string) []time.Time { tmp = time.Unix(int64(vtype), 0) case int64: tmp = time.Unix(vtype, 0) + case uint64: + tmp = time.Unix(int64(vtype), 0) case time.Time: tmp = vtype default: @@ -378,6 +438,8 @@ func (star *StarResult) MustDate(name, layout string) time.Time { res = time.Unix(int64(vtype), 0) case int64: res = time.Unix(vtype, 0) + case uint64: + res = time.Unix(int64(vtype), 0) case time.Time: res = vtype default: @@ -407,6 +469,8 @@ func (star *StarResult) MustInt64(name string) int64 { res = int64(vtype) case int32: res = int64(vtype) + case uint64: + res = int64(vtype) case int64: res = vtype case time.Time: @@ -441,6 +505,8 @@ func (star *StarResult) MustInt32(name string) int32 { res = vtype case int64: res = int32(vtype) + case uint64: + res = int32(vtype) case time.Time: res = int32(vtype.Unix()) default: @@ -450,6 +516,37 @@ func (star *StarResult) MustInt32(name string) int32 { return res } +// MustUint 列查询结果转uint +func (star *StarResult) MustUint64(name string) uint64 { + var res uint64 + num, ok := star.columnref[name] + if !ok { + return 0 + } + tmp := star.Result[num] + switch vtype := tmp.(type) { + case nil: + res = 0 + case float64: + res = uint64(vtype) + case float32: + res = uint64(vtype) + case string: + res, _ = strconv.ParseUint(vtype, 10, 64) + case uint64: + res = vtype + case int32: + res = uint64(vtype) + case int64: + res = uint64(vtype) + case time.Time: + res = uint64(vtype.Unix()) + default: + res, _ = strconv.ParseUint(string(tmp.([]byte)), 10, 64) + } + return res +} + // MustString 列查询结果转string func (star *StarResult) MustString(name string) string { var res string @@ -474,6 +571,8 @@ func (star *StarResult) MustString(name string) string { res = strconv.FormatFloat(float64(vtype), 'f', 10, 32) case int: res = strconv.Itoa(vtype) + case uint64: + res = strconv.FormatUint(vtype, 10) case time.Time: res = vtype.String() default: @@ -504,6 +603,8 @@ func (star *StarResult) MustFloat64(name string) float64 { res = float64(vtype) case float32: res = float64(vtype) + case uint64: + res = float64(vtype) case time.Time: res = float64(vtype.Unix()) default: @@ -535,6 +636,8 @@ func (star *StarResult) MustFloat32(name string) float32 { res = float32(vtype) case int32: res = float32(vtype) + case uint64: + res = float32(vtype) case time.Time: res = float32(vtype.Unix()) default: @@ -568,6 +671,8 @@ func (star *StarResult) MustInt(name string) int { res = int(vtype) case int64: res = int(vtype) + case uint64: + res = int(vtype) case time.Time: res = int(vtype.Unix()) default: @@ -620,6 +725,12 @@ func (star *StarResult) MustBool(name string) bool { } else { res = false } + case uint64: + if vtype > 0 { + res = true + } else { + res = false + } case string: res, _ = strconv.ParseBool(vtype) default: @@ -731,28 +842,135 @@ func (star *StarRows) parserows() error { return nil } -// Query 进行Query操作 +func (star *StarDB) Begin() (*StarTx, error) { + tx, err := star.Db.Begin() + if err != nil { + return nil, err + } + stx := new(StarTx) + stx.Db = star.Db + stx.Tx = tx + return stx, err +} + +func (star *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) { + tx, err := star.Db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + stx := new(StarTx) + stx.Db = star.Db + stx.Tx = tx + return stx, err +} + +func (star *StarTx) Query(args ...interface{}) (*StarRows, error) { + return star.query(nil, args...) +} + +func (star *StarTx) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) { + return star.query(ctx, args...) +} + +func (star *StarTx) query(ctx context.Context, args ...interface{}) (*StarRows, error) { + var err error + var rows *sql.Rows + var stmt *sql.Stmt + effect := new(StarRows) + if len(args) == 0 { + return effect, errors.New("no args") + } + if len(args) == 1 { + sqlStr := args[0] + if ctx == nil { + if rows, err = star.Tx.Query(sqlStr.(string)); err != nil { + return effect, err + } + } else { + if rows, err = star.Tx.QueryContext(ctx, sqlStr.(string)); err != nil { + return effect, err + } + } + effect.Rows = rows + err = effect.parserows() + return effect, err + } + sqlStr := args[0] + if ctx == nil { + stmt, err = star.Tx.Prepare(sqlStr.(string)) + } else { + stmt, err = star.Tx.PrepareContext(ctx, sqlStr.(string)) + } + if err != nil { + return effect, err + } + defer stmt.Close() + var para []interface{} + for k, v := range args { + if k != 0 { + switch vtype := v.(type) { + default: + para = append(para, vtype) + } + } + } + if ctx == nil { + if rows, err = stmt.Query(para...); err != nil { + return effect, err + } + } else { + if rows, err = stmt.QueryContext(ctx, para...); err != nil { + return effect, err + } + } + effect.Rows = rows + if !star.ManualScan { + err = effect.parserows() + } + return effect, err +} + func (star *StarDB) Query(args ...interface{}) (*StarRows, error) { + return star.query(nil, args...) +} + +func (star *StarDB) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) { + return star.query(ctx, args...) +} + +// Query 进行Query操作 +func (star *StarDB) query(ctx context.Context, args ...interface{}) (*StarRows, error) { var err error var rows *sql.Rows + var stmt *sql.Stmt effect := new(StarRows) - if err = star.DB.Ping(); err != nil { + if err = star.Db.Ping(); err != nil { return effect, err } if len(args) == 0 { return effect, errors.New("no args") } if len(args) == 1 { - sql := args[0] - if rows, err = star.DB.Query(sql.(string)); err != nil { - return effect, err + sqlStr := args[0] + if ctx == nil { + if rows, err = star.Db.Query(sqlStr.(string)); err != nil { + return effect, err + } + } else { + if rows, err = star.Db.Query(sqlStr.(string)); err != nil { + return effect, err + } } effect.Rows = rows - effect.parserows() - return effect, nil + err = effect.parserows() + return effect, err + } + sqlStr := args[0] + if ctx == nil { + stmt, err = star.Db.Prepare(sqlStr.(string)) + } else { + stmt, err = star.Db.PrepareContext(ctx, sqlStr.(string)) } - sql := args[0] - stmt, err := star.DB.Prepare(sql.(string)) if err != nil { return effect, err } @@ -766,8 +984,14 @@ func (star *StarDB) Query(args ...interface{}) (*StarRows, error) { } } } - if rows, err = stmt.Query(para...); err != nil { - return effect, err + if ctx == nil { + if rows, err = stmt.Query(para...); err != nil { + return effect, err + } + } else { + if rows, err = stmt.QueryContext(ctx, para...); err != nil { + return effect, err + } } effect.Rows = rows if !star.ManualScan { @@ -779,41 +1003,76 @@ func (star *StarDB) Query(args ...interface{}) (*StarRows, error) { // Open 打开一个新的数据库 func (star *StarDB) Open(Method, ConnStr string) error { var err error - star.DB, err = sql.Open(Method, ConnStr) - if err != nil { - return err - } - err = star.DB.Ping() + star.Db, err = sql.Open(Method, ConnStr) return err } // Close 关闭打开的数据库 func (star *StarDB) Close() error { - if err := star.DB.Close(); err != nil { - return err - } - return star.DB.Close() + return star.Db.Close() +} + +func (star *StarDB) Ping() error { + return star.Db.Ping() +} + +func (star *StarDB) Stats() sql.DBStats { + return star.Db.Stats() +} + +func (star *StarDB) SetMaxOpenConns(n int) { + star.Db.SetMaxOpenConns(n) +} + +func (star *StarDB) SetMaxIdleConns(n int) { + star.Db.SetMaxIdleConns(n) +} + +func (star *StarDB) PingContext(ctx context.Context) error { + return star.Db.PingContext(ctx) +} + +func (star *StarDB) Conn(ctx context.Context) (*sql.Conn, error) { + return star.Db.Conn(ctx) } -// Exec 执行Exec操作 func (star *StarDB) Exec(args ...interface{}) (sql.Result, error) { + return star.exec(nil, args...) +} +func (star *StarDB) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + return star.exec(ctx, args...) +} + +// Exec 执行Exec操作 +func (star *StarDB) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { var err error var effect sql.Result - if err = star.DB.Ping(); err != nil { + if err = star.Db.Ping(); err != nil { return effect, err } if len(args) == 0 { return effect, errors.New("no args") } if len(args) == 1 { - sql := args[0] - if _, err = star.DB.Exec(sql.(string)); err != nil { - return effect, err + sqlStr := args[0] + if ctx == nil { + if effect, err = star.Db.Exec(sqlStr.(string)); err != nil { + return effect, err + } + } else { + if effect, err = star.Db.ExecContext(ctx, sqlStr.(string)); err != nil { + return effect, err + } } return effect, nil } - sql := args[0] - stmt, err := star.DB.Prepare(sql.(string)) + sqlStr := args[0] + var stmt *sql.Stmt + if ctx == nil { + stmt, err = star.Db.Prepare(sqlStr.(string)) + } else { + stmt, err = star.Db.PrepareContext(ctx, sqlStr.(string)) + } if err != nil { return effect, err } @@ -827,12 +1086,84 @@ func (star *StarDB) Exec(args ...interface{}) (sql.Result, error) { } } } - if effect, err = stmt.Exec(para...); err != nil { + if ctx == nil { + if effect, err = stmt.Exec(para...); err != nil { + return effect, err + } + } else { + if effect, err = stmt.ExecContext(ctx, para...); err != nil { + return effect, err + } + } + return effect, nil +} + +func (star *StarTx) Exec(args ...interface{}) (sql.Result, error) { + return star.exec(nil, args...) +} +func (star *StarTx) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + return star.exec(ctx, args...) +} + +func (star *StarTx) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { + var err error + var effect sql.Result + var stmt *sql.Stmt + if len(args) == 0 { + return effect, errors.New("no args") + } + if len(args) == 1 { + sqlStr := args[0] + if ctx == nil { + if _, err = star.Tx.Exec(sqlStr.(string)); err != nil { + return effect, err + } + } else { + if _, err = star.Tx.ExecContext(ctx, sqlStr.(string)); err != nil { + return effect, err + } + } + return effect, nil + } + sqlStr := args[0] + if ctx == nil { + stmt, err = star.Tx.Prepare(sqlStr.(string)) + } else { + stmt, err = star.Tx.PrepareContext(ctx, sqlStr.(string)) + } + if err != nil { return effect, err } + defer stmt.Close() + var para []interface{} + for k, v := range args { + if k != 0 { + switch vtype := v.(type) { + default: + para = append(para, vtype) + } + } + } + if ctx == nil { + if effect, err = stmt.Exec(para...); err != nil { + return effect, err + } + } else { + if effect, err = stmt.ExecContext(ctx, para...); err != nil { + return effect, err + } + } return effect, nil } +func (star *StarTx) Commit() error { + return star.Tx.Commit() +} + +func (star *StarTx) Rollback() error { + return star.Tx.Rollback() +} + // FetchAll 把结果集全部转为key-value型数据 func FetchAll(rows *sql.Rows) (error, map[int]map[string]string) { var ii int = 0