// Package stardb is a small convenience layer over database/sql. By default
// query results are scanned eagerly into memory and can then be read back
// through lenient Must* converters, which fall back to the zero value instead
// of returning errors.
package stardb

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"time"
)

// StarDB is a thin wrapper around *sql.DB.
type StarDB struct {
	// Db is the wrapped connection pool; set it directly or via Open.
	Db *sql.DB
	// ManualScan, when true, stops the query methods from reading the result
	// set; call StarRows.Rescan (or use StarRows.Rows directly) afterwards.
	ManualScan bool
}

// StarTx is a transaction started by StarDB.Begin or StarDB.BeginTx.
type StarTx struct {
	Db *StarDB
	Tx *sql.Tx
}

// StarRows is a query result set that, once parsed, can be read row by row
// (Row) or column by column (Col).
type StarRows struct {
	// Rows is the underlying cursor.
	Rows *sql.Rows
	// Length is the number of parsed rows.
	Length int
	// StringResult holds every parsed row as a column-name -> text map.
	StringResult []map[string]string
	// Columns lists the column names in result order.
	Columns []string
	// ColumnsType holds the driver-reported scan type of every column.
	ColumnsType []reflect.Type

	columnref map[string]int
	result    [][]interface{}
	parsed    bool
}

// StarDBStmt is a prepared statement created from a StarDB or a StarTx.
type StarDBStmt struct {
	Stmt *sql.Stmt
	Db   *StarDB
}

// StarResult is a single row of a parsed result set.
type StarResult struct {
	Result      []interface{}
	Columns     []string
	columnref   map[string]int
	ColumnsType []reflect.Type
}

// StarResultCol is a single column of a parsed result set.
type StarResultCol struct {
	Result []interface{}
}

// ---------------------------------------------------------------------------
// Value conversion helpers.
//
// Each helper accepts any value produced by scanning into *interface{}
// (nil, int64, float64, bool, []byte, string, time.Time) plus a few extra
// Go numeric types. Parse errors are deliberately ignored: the Must* API
// yields the zero value (or strconv's clamped value on overflow) instead.
// Values of any other type yield the zero value rather than panicking.
// ---------------------------------------------------------------------------

// formatValue renders v as text. Floats use the 'f' format with precision
// prec (-1 means the shortest representation that round-trips at the value's
// own bit size). NULL becomes "", and unexpected types fall back to fmt.Sprint.
func formatValue(v interface{}, prec int) string {
	switch val := v.(type) {
	case nil:
		return ""
	case string:
		return val
	case []byte:
		return string(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case int32:
		return strconv.FormatInt(int64(val), 10)
	case int:
		return strconv.Itoa(val)
	case uint64:
		return strconv.FormatUint(val, 10)
	case bool:
		return strconv.FormatBool(val)
	case float64:
		return strconv.FormatFloat(val, 'f', prec, 64)
	case float32:
		return strconv.FormatFloat(float64(val), 'f', prec, 32)
	case time.Time:
		return val.String()
	default:
		return fmt.Sprint(val)
	}
}

// boolToInt64 maps true to 1 and false to 0.
func boolToInt64(b bool) int64 {
	if b {
		return 1
	}
	return 0
}

// toBool converts v to bool: numbers are true only when strictly positive,
// text goes through strconv.ParseBool, everything else is false.
func toBool(v interface{}) bool {
	switch val := v.(type) {
	case bool:
		return val
	case float64:
		return val > 0
	case float32:
		return val > 0
	case int:
		return val > 0
	case int32:
		return val > 0
	case int64:
		return val > 0
	case uint64:
		return val > 0
	case string:
		b, _ := strconv.ParseBool(val)
		return b
	case []byte:
		b, _ := strconv.ParseBool(string(val))
		return b
	}
	return false
}

// toInt64 converts v to int64; time.Time becomes its Unix seconds.
func toInt64(v interface{}) int64 {
	switch val := v.(type) {
	case int64:
		return val
	case int:
		return int64(val)
	case int32:
		return int64(val)
	case uint64:
		return int64(val)
	case float64:
		return int64(val)
	case float32:
		return int64(val)
	case bool:
		return boolToInt64(val)
	case time.Time:
		return val.Unix()
	case string:
		n, _ := strconv.ParseInt(val, 10, 64)
		return n
	case []byte:
		n, _ := strconv.ParseInt(string(val), 10, 64)
		return n
	}
	return 0
}

// toInt converts v to int (via the int64 conversion).
func toInt(v interface{}) int {
	return int(toInt64(v))
}

// toInt32 converts v to int32; text is parsed with a 32-bit range so
// out-of-range strings clamp instead of wrapping.
func toInt32(v interface{}) int32 {
	switch val := v.(type) {
	case int32:
		return val
	case int:
		return int32(val)
	case int64:
		return int32(val)
	case uint64:
		return int32(val)
	case float64:
		return int32(val)
	case float32:
		return int32(val)
	case bool:
		return int32(boolToInt64(val))
	case time.Time:
		return int32(val.Unix())
	case string:
		n, _ := strconv.ParseInt(val, 10, 32)
		return int32(n)
	case []byte:
		n, _ := strconv.ParseInt(string(val), 10, 32)
		return int32(n)
	}
	return 0
}

// toUint64 converts v to uint64; time.Time becomes its Unix seconds.
func toUint64(v interface{}) uint64 {
	switch val := v.(type) {
	case uint64:
		return val
	case int:
		return uint64(val)
	case int32:
		return uint64(val)
	case int64:
		return uint64(val)
	case float64:
		return uint64(val)
	case float32:
		return uint64(val)
	case bool:
		return uint64(boolToInt64(val))
	case time.Time:
		return uint64(val.Unix())
	case string:
		n, _ := strconv.ParseUint(val, 10, 64)
		return n
	case []byte:
		n, _ := strconv.ParseUint(string(val), 10, 64)
		return n
	}
	return 0
}

// toFloat64 converts v to float64; time.Time becomes its Unix seconds.
func toFloat64(v interface{}) float64 {
	switch val := v.(type) {
	case float64:
		return val
	case float32:
		return float64(val)
	case int:
		return float64(val)
	case int32:
		return float64(val)
	case int64:
		return float64(val)
	case uint64:
		return float64(val)
	case bool:
		return float64(boolToInt64(val))
	case time.Time:
		return float64(val.Unix())
	case string:
		f, _ := strconv.ParseFloat(val, 64)
		return f
	case []byte:
		f, _ := strconv.ParseFloat(string(val), 64)
		return f
	}
	return 0
}

// toFloat32 converts v to float32; text is parsed at 32-bit precision.
func toFloat32(v interface{}) float32 {
	switch val := v.(type) {
	case float32:
		return val
	case float64:
		return float32(val)
	case int:
		return float32(val)
	case int32:
		return float32(val)
	case int64:
		return float32(val)
	case uint64:
		return float32(val)
	case bool:
		return float32(boolToInt64(val))
	case time.Time:
		return float32(val.Unix())
	case string:
		f, _ := strconv.ParseFloat(val, 32)
		return float32(f)
	case []byte:
		f, _ := strconv.ParseFloat(string(val), 32)
		return float32(f)
	}
	return 0
}

// unixFloat turns fractional Unix seconds into a time.Time, keeping the
// sub-second part (the fraction must be scaled before truncating to int64).
func unixFloat(f float64) time.Time {
	sec := int64(f)
	return time.Unix(sec, int64((f-float64(sec))*1e9))
}

// toDate converts v to time.Time: numbers are Unix seconds, text is parsed
// with layout, anything else (including NULL) is the zero time.
func toDate(v interface{}, layout string) time.Time {
	switch val := v.(type) {
	case time.Time:
		return val
	case float64:
		return unixFloat(val)
	case float32:
		return unixFloat(float64(val))
	case int:
		return time.Unix(int64(val), 0)
	case int32:
		return time.Unix(int64(val), 0)
	case int64:
		return time.Unix(val, 0)
	case uint64:
		return time.Unix(int64(val), 0)
	case string:
		t, _ := time.Parse(layout, val)
		return t
	case []byte:
		t, _ := time.Parse(layout, string(val))
		return t
	}
	return time.Time{}
}

// toBytes converts v to []byte: []byte is returned as-is (not copied), NULL
// becomes nil, anything else is converted through its text form.
func toBytes(v interface{}) []byte {
	switch val := v.(type) {
	case nil:
		return nil
	case []byte:
		return val
	case string:
		return []byte(val)
	}
	return []byte(formatValue(v, -1))
}

// ---------------------------------------------------------------------------
// StarResultCol: column-wise accessors.
// ---------------------------------------------------------------------------

// MustBytes returns every value of the column as []byte (NULL becomes nil).
func (star *StarResultCol) MustBytes() [][]byte {
	var res [][]byte
	for _, v := range star.Result {
		res = append(res, toBytes(v))
	}
	return res
}

// MustBool returns every value of the column as bool.
func (star *StarResultCol) MustBool() []bool {
	var res []bool
	for _, v := range star.Result {
		res = append(res, toBool(v))
	}
	return res
}

// MustFloat32 returns every value of the column as float32.
func (star *StarResultCol) MustFloat32() []float32 {
	var res []float32
	for _, v := range star.Result {
		res = append(res, toFloat32(v))
	}
	return res
}

// MustFloat64 returns every value of the column as float64.
func (star *StarResultCol) MustFloat64() []float64 {
	var res []float64
	for _, v := range star.Result {
		res = append(res, toFloat64(v))
	}
	return res
}

// MustString returns every value of the column as a string; floats are
// formatted with 10 decimal places.
func (star *StarResultCol) MustString() []string {
	var res []string
	for _, v := range star.Result {
		res = append(res, formatValue(v, 10))
	}
	return res
}

// MustInt32 returns every value of the column as int32.
func (star *StarResultCol) MustInt32() []int32 {
	var res []int32
	for _, v := range star.Result {
		res = append(res, toInt32(v))
	}
	return res
}

// MustInt64 returns every value of the column as int64.
func (star *StarResultCol) MustInt64() []int64 {
	var res []int64
	for _, v := range star.Result {
		res = append(res, toInt64(v))
	}
	return res
}

// MustUint64 returns every value of the column as uint64.
func (star *StarResultCol) MustUint64() []uint64 {
	var res []uint64
	for _, v := range star.Result {
		res = append(res, toUint64(v))
	}
	return res
}

// MustInt returns every value of the column as int.
func (star *StarResultCol) MustInt() []int {
	var res []int
	for _, v := range star.Result {
		res = append(res, toInt(v))
	}
	return res
}

// MustDate returns every value of the column as time.Time; text values are
// parsed with layout, numeric values are treated as Unix seconds.
func (star *StarResultCol) MustDate(layout string) []time.Time {
	var res []time.Time
	for _, v := range star.Result {
		res = append(res, toDate(v, layout))
	}
	return res
}

// IsNil reports, for every value of the column, whether it is NULL.
// The name parameter is unused and kept only for API compatibility.
func (star *StarResultCol) IsNil(name string) []bool {
	var res []bool
	for _, v := range star.Result {
		res = append(res, v == nil)
	}
	return res
}

// ---------------------------------------------------------------------------
// StarResult: row-wise accessors. A missing column yields the zero value.
// ---------------------------------------------------------------------------

// value returns the raw value of column name; ok is false when the row has
// no such column.
func (star *StarResult) value(name string) (interface{}, bool) {
	idx, ok := star.columnref[name]
	if !ok || idx < 0 || idx >= len(star.Result) {
		return nil, false
	}
	return star.Result[idx], true
}

// IsNil reports whether column name exists and is NULL.
func (star *StarResult) IsNil(name string) bool {
	v, ok := star.value(name)
	return ok && v == nil
}

// MustDate returns column name as time.Time (see StarResultCol.MustDate).
func (star *StarResult) MustDate(name, layout string) time.Time {
	v, _ := star.value(name)
	return toDate(v, layout)
}

// MustInt64 returns column name as int64.
func (star *StarResult) MustInt64(name string) int64 {
	v, _ := star.value(name)
	return toInt64(v)
}

// MustInt32 returns column name as int32.
func (star *StarResult) MustInt32(name string) int32 {
	v, _ := star.value(name)
	return toInt32(v)
}

// MustUint64 returns column name as uint64.
func (star *StarResult) MustUint64(name string) uint64 {
	v, _ := star.value(name)
	return toUint64(v)
}

// MustString returns column name as a string; floats are formatted with 10
// decimal places.
func (star *StarResult) MustString(name string) string {
	v, _ := star.value(name)
	return formatValue(v, 10)
}

// MustFloat64 returns column name as float64.
func (star *StarResult) MustFloat64(name string) float64 {
	v, _ := star.value(name)
	return toFloat64(v)
}

// MustFloat32 returns column name as float32.
func (star *StarResult) MustFloat32(name string) float32 {
	v, _ := star.value(name)
	return toFloat32(v)
}

// MustInt returns column name as int.
func (star *StarResult) MustInt(name string) int {
	v, _ := star.value(name)
	return toInt(v)
}

// MustBool returns column name as bool.
func (star *StarResult) MustBool(name string) bool {
	v, _ := star.value(name)
	return toBool(v)
}

// MustBytes returns column name as []byte: an empty slice when the column is
// missing, nil when it is NULL.
func (star *StarResult) MustBytes(name string) []byte {
	v, ok := star.value(name)
	if !ok {
		return []byte{}
	}
	return toBytes(v)
}

// ---------------------------------------------------------------------------
// StarRows.
// ---------------------------------------------------------------------------

// Rescan (re)parses the remaining rows of star.Rows into memory. It is meant
// for result sets obtained with StarDB.ManualScan set. Parse errors are
// discarded to keep the original signature. If the cursor is already closed
// (for example because the rows were parsed before), the previously parsed
// data is left untouched.
func (star *StarRows) Rescan() {
	_ = star.parserows()
}

// Col selects a column of the parsed result set by name. Unknown columns
// yield an empty StarResultCol.
func (star *StarRows) Col(name string) *StarResultCol {
	result := new(StarResultCol)
	idx, ok := star.columnref[name]
	if !ok {
		return result
	}
	for _, row := range star.result {
		result.Result = append(result.Result, row[idx])
	}
	return result
}

// Row selects a row of the parsed result set by zero-based index.
// Out-of-range indexes (including negative ones) yield an empty StarResult.
func (star *StarRows) Row(id int) *StarResult {
	result := new(StarResult)
	if id < 0 || id >= len(star.result) {
		return result
	}
	result.Result = star.result[id]
	result.Columns = star.Columns
	result.ColumnsType = star.ColumnsType
	result.columnref = star.columnref
	return result
}

// Close closes the underlying cursor. It is a no-op when there is none,
// which is the case for the StarRows returned alongside a query error.
func (star *StarRows) Close() error {
	if star.Rows == nil {
		return nil
	}
	return star.Rows.Close()
}

// parserows reads every remaining row of star.Rows into memory, filling
// Columns, ColumnsType, StringResult, Length and the internal lookup tables.
// On a scan error the cursor is closed so the connection is released.
func (star *StarRows) parserows() error {
	defer func() {
		star.Length = len(star.StringResult)
		star.parsed = true
	}()
	if star.Rows == nil {
		return errors.New("no rows to parse")
	}
	columns, err := star.Rows.Columns()
	if err != nil {
		return err
	}
	types, err := star.Rows.ColumnTypes()
	if err != nil {
		return err
	}

	// Reset all derived state so that Rescan never accumulates duplicates.
	star.Columns = columns
	star.ColumnsType = nil
	for _, ct := range types {
		star.ColumnsType = append(star.ColumnsType, ct.ScanType())
	}
	star.columnref = make(map[string]int, len(columns))
	for i, name := range columns {
		star.columnref[name] = i
	}
	star.result = [][]interface{}{}
	star.StringResult = []map[string]string{}

	values := make([]interface{}, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	for star.Rows.Next() {
		if err := star.Rows.Scan(scanArgs...); err != nil {
			_ = star.Rows.Close() // release the connection; the scan error is what matters
			return err
		}
		// values is reused for every row, so keep a copy of the interface
		// headers (database/sql already clones []byte for *interface{}).
		row := make([]interface{}, len(values))
		copy(row, values)
		record := make(map[string]string, len(columns))
		for i, v := range values {
			record[columns[i]] = formatValue(v, -1)
		}
		star.result = append(star.result, row)
		star.StringResult = append(star.StringResult, record)
	}
	// Next returning false may mean an iteration error rather than EOF.
	return star.Rows.Err()
}

// ---------------------------------------------------------------------------
// Shared query/exec helpers.
// ---------------------------------------------------------------------------

// orBackground substitutes context.Background() for a nil ctx so the
// non-Context method variants and careless callers behave the same.
func orBackground(ctx context.Context) context.Context {
	if ctx == nil {
		return context.Background()
	}
	return ctx
}

// splitArgs separates the leading SQL string in args from its bind
// parameters. emptyMsg is the error text used when args is empty.
func splitArgs(args []interface{}, emptyMsg string) (string, []interface{}, error) {
	if len(args) == 0 {
		return "", nil, errors.New(emptyMsg)
	}
	query, ok := args[0].(string)
	if !ok {
		return "", nil, fmt.Errorf("first argument must be a SQL string, got %T", args[0])
	}
	return query, args[1:], nil
}

// newStarRows wraps rows and, unless manual is set, parses them eagerly.
func newStarRows(rows *sql.Rows, manual bool) (*StarRows, error) {
	effect := &StarRows{Rows: rows}
	if manual {
		return effect, nil
	}
	return effect, effect.parserows()
}

// prepareFunc matches StarDB.PrepareContext and StarTx.PrepareContext.
type prepareFunc func(ctx context.Context, sqlStr string) (*StarDBStmt, error)

// execPrepared prepares args[0], executes it once with args[1:] and closes
// the statement.
func execPrepared(ctx context.Context, prepare prepareFunc, args []interface{}) (sql.Result, error) {
	ctx = orBackground(ctx)
	query, params, err := splitArgs(args, "parameter not enough")
	if err != nil {
		return nil, err
	}
	stmt, err := prepare(ctx, query)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()
	return stmt.ExecContext(ctx, params...)
}

// queryPrepared prepares args[0], queries it once with args[1:] and closes
// the statement. NOTE(review): with ManualScan inside a transaction, closing
// the statement before the rows are consumed is driver-dependent.
func queryPrepared(ctx context.Context, prepare prepareFunc, args []interface{}) (*StarRows, error) {
	ctx = orBackground(ctx)
	query, params, err := splitArgs(args, "parameter not enough")
	if err != nil {
		return nil, err
	}
	stmt, err := prepare(ctx, query)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()
	return stmt.QueryContext(ctx, params...)
}

// ---------------------------------------------------------------------------
// StarTx.
// ---------------------------------------------------------------------------

// Begin starts a transaction.
func (star *StarDB) Begin() (*StarTx, error) {
	return star.BeginTx(context.Background(), nil)
}

// BeginTx starts a transaction with the given context and options.
func (star *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) {
	tx, err := star.Db.BeginTx(orBackground(ctx), opts)
	if err != nil {
		return nil, err
	}
	return &StarTx{Db: star, Tx: tx}, nil
}

// Query runs args[0] with bind parameters args[1:] inside the transaction.
func (star *StarTx) Query(args ...interface{}) (*StarRows, error) {
	return star.query(context.Background(), args...)
}

// QueryContext is Query with a context.
func (star *StarTx) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) {
	return star.query(ctx, args...)
}

// ExecStmt prepares args[0] in the transaction and executes it with args[1:].
func (star *StarTx) ExecStmt(args ...interface{}) (sql.Result, error) {
	return execPrepared(context.Background(), star.PrepareContext, args)
}

// ExecStmtContext is ExecStmt with a context.
func (star *StarTx) ExecStmtContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
	return execPrepared(ctx, star.PrepareContext, args)
}

// QueryStmt prepares args[0] in the transaction and queries it with args[1:].
func (star *StarTx) QueryStmt(args ...interface{}) (*StarRows, error) {
	return queryPrepared(context.Background(), star.PrepareContext, args)
}

// QueryStmtContext is QueryStmt with a context.
func (star *StarTx) QueryStmtContext(ctx context.Context, args ...interface{}) (*StarRows, error) {
	return queryPrepared(ctx, star.PrepareContext, args)
}

// query runs a query inside the transaction. Unlike the StarDB variant it
// does not ping the pool: the transaction already owns a connection, and
// pinging could block forever on a pool limited to one connection.
func (star *StarTx) query(ctx context.Context, args ...interface{}) (*StarRows, error) {
	query, params, err := splitArgs(args, "no args")
	if err != nil {
		return new(StarRows), err
	}
	rows, err := star.Tx.QueryContext(orBackground(ctx), query, params...)
	if err != nil {
		return new(StarRows), err
	}
	return newStarRows(rows, star.Db != nil && star.Db.ManualScan)
}

// ---------------------------------------------------------------------------
// StarDB queries.
// ---------------------------------------------------------------------------

// Query runs args[0] with bind parameters args[1:].
func (star *StarDB) Query(args ...interface{}) (*StarRows, error) {
	return star.query(context.Background(), args...)
}

// QueryContext is Query with a context.
func (star *StarDB) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) {
	return star.query(ctx, args...)
}

// QueryStmt prepares args[0] and queries it once with args[1:].
func (star *StarDB) QueryStmt(args ...interface{}) (*StarRows, error) {
	return queryPrepared(context.Background(), star.PrepareContext, args)
}

// QueryStmtContext is QueryStmt with a context.
func (star *StarDB) QueryStmtContext(ctx context.Context, args ...interface{}) (*StarRows, error) {
	return queryPrepared(ctx, star.PrepareContext, args)
}

// ExecStmt prepares args[0] and executes it once with args[1:].
func (star *StarDB) ExecStmt(args ...interface{}) (sql.Result, error) {
	return execPrepared(context.Background(), star.PrepareContext, args)
}

// ExecStmtContext is ExecStmt with a context.
func (star *StarDB) ExecStmtContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
	return execPrepared(ctx, star.PrepareContext, args)
}

// ---------------------------------------------------------------------------
// StarDBStmt.
// ---------------------------------------------------------------------------

// Query runs the prepared statement with the given bind parameters.
func (star *StarDBStmt) Query(args ...interface{}) (*StarRows, error) {
	return star.query(context.Background(), args...)
}

// QueryContext is Query with a context.
func (star *StarDBStmt) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) {
	return star.query(ctx, args...)
}

// Exec executes the prepared statement with the given bind parameters.
func (star *StarDBStmt) Exec(args ...interface{}) (sql.Result, error) {
	return star.exec(context.Background(), args...)
}

// ExecContext is Exec with a context.
func (star *StarDBStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
	return star.exec(ctx, args...)
}

// Close closes the prepared statement.
func (star *StarDBStmt) Close() error {
	return star.Stmt.Close()
}

// query runs the statement; statements without placeholders may be called
// with no arguments.
func (star *StarDBStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) {
	rows, err := star.Stmt.QueryContext(orBackground(ctx), args...)
	if err != nil {
		return new(StarRows), err
	}
	return newStarRows(rows, star.Db != nil && star.Db.ManualScan)
}

// exec executes the statement; statements without placeholders may be
// called with no arguments.
func (star *StarDBStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
	return star.Stmt.ExecContext(orBackground(ctx), args...)
}

// Prepare creates a prepared statement on the pool.
func (star *StarDB) Prepare(sqlStr string) (*StarDBStmt, error) {
	return star.PrepareContext(context.Background(), sqlStr)
}

// PrepareContext is Prepare with a context.
func (star *StarDB) PrepareContext(ctx context.Context, sqlStr string) (*StarDBStmt, error) {
	stmt, err := star.Db.PrepareContext(orBackground(ctx), sqlStr)
	if err != nil {
		return nil, err
	}
	return &StarDBStmt{Stmt: stmt, Db: star}, nil
}

// Prepare creates a prepared statement bound to the transaction.
func (star *StarTx) Prepare(sqlStr string) (*StarDBStmt, error) {
	return star.PrepareContext(context.Background(), sqlStr)
}

// PrepareContext is Prepare with a context.
func (star *StarTx) PrepareContext(ctx context.Context, sqlStr string) (*StarDBStmt, error) {
	stmt, err := star.Tx.PrepareContext(orBackground(ctx), sqlStr)
	if err != nil {
		return nil, err
	}
	return &StarDBStmt{Stmt: stmt, Db: star.Db}, nil
}

// query pings the pool, then runs args[0] with bind parameters args[1:].
func (star *StarDB) query(ctx context.Context, args ...interface{}) (*StarRows, error) {
	ctx = orBackground(ctx)
	if err := star.Db.PingContext(ctx); err != nil {
		return new(StarRows), err
	}
	query, params, err := splitArgs(args, "no args")
	if err != nil {
		return new(StarRows), err
	}
	rows, err := star.Db.QueryContext(ctx, query, params...)
	if err != nil {
		return new(StarRows), err
	}
	return newStarRows(rows, star.ManualScan)
}

// ---------------------------------------------------------------------------
// StarDB pool management.
// ---------------------------------------------------------------------------

// Open opens a database with the given driver name and data source name.
func (star *StarDB) Open(Method, ConnStr string) error {
	var err error
	star.Db, err = sql.Open(Method, ConnStr)
	return err
}

// Close closes the database pool.
func (star *StarDB) Close() error {
	return star.Db.Close()
}

// Ping verifies that the database is reachable.
func (star *StarDB) Ping() error {
	return star.Db.Ping()
}

// Stats returns the pool statistics.
func (star *StarDB) Stats() sql.DBStats {
	return star.Db.Stats()
}

// SetMaxOpenConns limits the number of open connections.
func (star *StarDB) SetMaxOpenConns(n int) {
	star.Db.SetMaxOpenConns(n)
}

// SetMaxIdleConns limits the number of idle connections.
func (star *StarDB) SetMaxIdleConns(n int) {
	star.Db.SetMaxIdleConns(n)
}

// PingContext is Ping with a context.
func (star *StarDB) PingContext(ctx context.Context) error {
	return star.Db.PingContext(ctx)
}

// Conn returns a single dedicated connection from the pool.
func (star *StarDB) Conn(ctx context.Context) (*sql.Conn, error) {
	return star.Db.Conn(ctx)
}

// Exec executes args[0] with bind parameters args[1:].
func (star *StarDB) Exec(args ...interface{}) (sql.Result, error) {
	return star.exec(context.Background(), args...)
}

// ExecContext is Exec with a context.
func (star *StarDB) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
	return star.exec(ctx, args...)
}

// exec pings the pool, then executes args[0] with bind parameters args[1:].
func (star *StarDB) exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
	ctx = orBackground(ctx)
	if err := star.Db.PingContext(ctx); err != nil {
		return nil, err
	}
	query, params, err := splitArgs(args, "no args")
	if err != nil {
		return nil, err
	}
	return star.Db.ExecContext(ctx, query, params...)
}

// Exec executes args[0] with bind parameters args[1:] inside the transaction.
func (star *StarTx) Exec(args ...interface{}) (sql.Result, error) {
	return star.exec(context.Background(), args...)
}

// ExecContext is Exec with a context.
func (star *StarTx) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
	return star.exec(ctx, args...)
}

// exec executes inside the transaction; like StarTx.query it does not ping
// the pool.
func (star *StarTx) exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
	query, params, err := splitArgs(args, "no args")
	if err != nil {
		return nil, err
	}
	return star.Tx.ExecContext(orBackground(ctx), query, params...)
}

// Commit commits the transaction.
func (star *StarTx) Commit() error {
	return star.Tx.Commit()
}

// Rollback aborts the transaction.
func (star *StarTx) Rollback() error {
	return star.Tx.Rollback()
}

// FetchAll reads every remaining row of rows into a map keyed by zero-based
// row number; each row maps column name to the value's text form (NULL
// becomes ""). The error comes first in the result for historical reasons.
// The caller still owns rows.
func FetchAll(rows *sql.Rows) (error, map[int]map[string]string) {
	records := make(map[int]map[string]string)
	columns, err := rows.Columns()
	if err != nil {
		return err, records
	}
	values := make([]interface{}, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}
	for n := 0; rows.Next(); n++ {
		if err := rows.Scan(scanArgs...); err != nil {
			return err, records
		}
		record := make(map[string]string, len(columns))
		for i, v := range values {
			record[columns[i]] = formatValue(v, -1)
		}
		records[n] = record
	}
	// Surface iteration errors that made Next stop early.
	return rows.Err(), records
}