commit babb5963b5f6c0663a1108503ef919de5fc1041d Author: starainrt Date: Sat Jul 1 18:19:58 2023 +0800 init diff --git a/column.go b/column.go new file mode 100644 index 0000000..ed4fee6 --- /dev/null +++ b/column.go @@ -0,0 +1,318 @@ +// Modeling of columns + +package sqlbuilder + +import ( + "bytes" + "regexp" + + "github.com/dropbox/godropbox/errors" +) + +// XXX: Maybe add UIntColumn + +// Representation of a table for query generation +type Column interface { + isProjectionInterface + + Name() string + // Serialization for use in column lists + SerializeSqlForColumnList(out *bytes.Buffer) error + // Serialization for use in an expression (Clause) + SerializeSql(out *bytes.Buffer) error + + // Internal function for tracking table that a column belongs to + // for the purpose of serialization + setTableName(table string) error +} + +type NullableColumn bool + +const ( + Nullable NullableColumn = true + NotNullable NullableColumn = false +) + +// A column that can be refer to outside of the projection list +type NonAliasColumn interface { + Column + isOrderByClauseInterface + isExpressionInterface +} + +type Collation string + +const ( + UTF8CaseInsensitive Collation = "utf8_unicode_ci" + UTF8CaseSensitive Collation = "utf8_unicode" + UTF8Binary Collation = "utf8_bin" +) + +// Representation of MySQL charsets +type Charset string + +const ( + UTF8 Charset = "utf8" +) + +// The base type for real materialized columns. +type baseColumn struct { + isProjection + isExpression + name string + nullable NullableColumn + table string +} + +func (c *baseColumn) Name() string { + return c.name +} + +func (c *baseColumn) setTableName(table string) error { + c.table = table + return nil +} + +/*func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + if c.table != "" { + _ = out.WriteByte('`') + _, _ = out.WriteString(c.table) + _, _ = out.WriteString("`.") + } + _, _ = out.WriteString("`") + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} +*/ +func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + // Momo modified. we don't need prefixing table name + /* + if c.table != "" { + _ = out.WriteByte('`') + _, _ = out.WriteString(c.table) + _, _ = out.WriteString("`.") + } + */ + _, _ = out.WriteString("`") + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *baseColumn) SerializeSql(out *bytes.Buffer) error { + return c.SerializeSqlForColumnList(out) +} + +type bytesColumn struct { + baseColumn + isExpression +} + +// Representation of VARBINARY/BLOB columns +// This function will panic if name is not valid +func BytesColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in bytes column") + } + bc := &bytesColumn{} + bc.name = name + bc.nullable = nullable + return bc +} + +type stringColumn struct { + baseColumn + isExpression + charset Charset + collation Collation +} + +// Representation of VARCHAR/TEXT columns +// This function will panic if name is not valid +func StrColumn( + name string, + charset Charset, + collation Collation, + nullable NullableColumn) NonAliasColumn { + + if !validIdentifierName(name) { + panic("Invalid column name in str column") + } + sc := &stringColumn{charset: charset, collation: collation} + sc.name = name + sc.nullable = nullable + return sc +} + +type dateTimeColumn struct { + baseColumn + isExpression +} + +// Representation of DateTime columns +// This function will panic if name is not valid +func DateTimeColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in datetime column") + } + dc := &dateTimeColumn{} + dc.name = name + dc.nullable = nullable + return dc +} + +type integerColumn struct { + baseColumn + isExpression +} + +// Representation of any integer column +// This function will panic if name is not valid +func IntColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in int column") + } + ic := &integerColumn{} + ic.name = name + ic.nullable = nullable + return ic +} + +type doubleColumn struct { + baseColumn + isExpression +} + +// Representation of any double column +// This function will panic if name is not valid +func DoubleColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in int column") + } + ic := &doubleColumn{} + ic.name = name + ic.nullable = nullable + return ic +} + +type booleanColumn struct { + baseColumn + isExpression + + // XXX: Maybe allow isBoolExpression (for now, not included because + // the deferred lookup equivalent can never be isBoolExpression) +} + +// Representation of TINYINT used as a bool +// This function will panic if name is not valid +func BoolColumn(name string, nullable NullableColumn) NonAliasColumn { + if !validIdentifierName(name) { + panic("Invalid column name in bool column") + } + bc := &booleanColumn{} + bc.name = name + bc.nullable = nullable + return bc +} + +type aliasColumn struct { + baseColumn + expression Expression +} + +func (c *aliasColumn) SerializeSql(out *bytes.Buffer) error { + _ = out.WriteByte('`') + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *aliasColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + if !validIdentifierName(c.name) { + return errors.Newf( + "Invalid alias name `%s`. Generated sql: %s", + c.name, + out.String()) + } + if c.expression == nil { + return errors.Newf( + "Cannot alias a nil expression. Generated sql: %s", + out.String()) + } + + _ = out.WriteByte('(') + if c.expression == nil { + return errors.Newf("nil alias clause. Generate sql: %s", out.String()) + } + if err := c.expression.SerializeSql(out); err != nil { + return err + } + _, _ = out.WriteString(") AS `") + _, _ = out.WriteString(c.name) + _ = out.WriteByte('`') + return nil +} + +func (c *aliasColumn) setTableName(table string) error { + return errors.Newf( + "Alias column '%s' should never have setTableName called on it", + c.name) +} + +// Representation of aliased clauses (expression AS name) +func Alias(name string, c Expression) Column { + ac := &aliasColumn{} + ac.name = name + ac.expression = c + return ac +} + +// This is a strict subset of the actual allowed identifiers +var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$") + +// Returns true if the given string is suitable as an identifier. +func validIdentifierName(name string) bool { + //return validIdentifierRegexp.MatchString(name) + return true +} + +// Pseudo Column type returned by table.C(name) +type deferredLookupColumn struct { + isProjection + isExpression + table *Table + colName string + + cachedColumn NonAliasColumn +} + +func (c *deferredLookupColumn) Name() string { + return c.colName +} + +func (c *deferredLookupColumn) SerializeSqlForColumnList( + out *bytes.Buffer) error { + + return c.SerializeSql(out) +} + +func (c *deferredLookupColumn) SerializeSql(out *bytes.Buffer) error { + if c.cachedColumn != nil { + return c.cachedColumn.SerializeSql(out) + } + + col, err := c.table.getColumn(c.colName) + if err != nil { + return err + } + + c.cachedColumn = col + return col.SerializeSql(out) +} + +func (c *deferredLookupColumn) setTableName(table string) error { + return errors.Newf( + "Lookup column '%s' should never have setTableName called on it", + c.colName) +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..3f9170a --- /dev/null +++ b/doc.go @@ -0,0 +1,25 @@ +// A library for generating sql programmatically. +// +// SQL COMPATIBILITY NOTE: sqlbuilder is designed to generate valid MySQL sql +// statements. The generated statements may not work for other sql variants. +// For instances, the generated statements does not currently work for +// PostgreSQL since column identifiers are escaped with backquotes. +// Patches to support other sql flavors are welcome! (see +// https://godropbox/issues/33 for additional details). +// +// Known limitations for SELECT queries: +// - does not support subqueries (since mysql is bad at it) +// - does not currently support join table alias (and hence self join) +// - does not support NATURAL joins and join USING +// +// Known limitation for INSERT statements: +// - does not support "INSERT INTO SELECT" +// +// Known limitation for UPDATE statements: +// - does not support update without a WHERE clause (since it is dangerous) +// - does not support multi-table update +// +// Known limitation for DELETE statements: +// - does not support delete without a WHERE clause (since it is dangerous) +// - does not support multi-table delete +package sqlbuilder diff --git a/expression.go b/expression.go new file mode 100644 index 0000000..2a8d877 --- /dev/null +++ b/expression.go @@ -0,0 +1,733 @@ +// Query building functions for expression components +package sqlbuilder + +import ( + "bytes" + "reflect" + "strconv" + "strings" + "time" + + "b612.me/mysql/sqltypes" + "github.com/dropbox/godropbox/errors" +) + +type orderByClause struct { + isOrderByClause + expression Expression + ascent bool +} + +func (o *orderByClause) SerializeSql(out *bytes.Buffer) error { + if o.expression == nil { + return errors.Newf( + "nil order by clause. Generated sql: %s", + out.String()) + } + + if err := o.expression.SerializeSql(out); err != nil { + return err + } + + if o.ascent { + _, _ = out.WriteString(" ASC") + } else { + _, _ = out.WriteString(" DESC") + } + + return nil +} + +func Asc(expression Expression) OrderByClause { + return &orderByClause{expression: expression, ascent: true} +} + +func Desc(expression Expression) OrderByClause { + return &orderByClause{expression: expression, ascent: false} +} + +// Representation of an escaped literal +type literalExpression struct { + isExpression + value sqltypes.Value +} + +func (c literalExpression) SerializeSql(out *bytes.Buffer) error { + sqltypes.Value(c.value).EncodeSql(out) + return nil +} + +func serializeClauses( + clauses []Clause, + separator []byte, + out *bytes.Buffer) (err error) { + + if clauses == nil || len(clauses) == 0 { + return errors.Newf("Empty clauses. Generated sql: %s", out.String()) + } + + if clauses[0] == nil { + return errors.Newf("nil clause. Generated sql: %s", out.String()) + } + if err = clauses[0].SerializeSql(out); err != nil { + return + } + + for _, c := range clauses[1:] { + _, _ = out.Write(separator) + + if c == nil { + return errors.Newf("nil clause. Generated sql: %s", out.String()) + } + if err = c.SerializeSql(out); err != nil { + return + } + } + + return nil +} + +// Representation of n-ary conjunctions (AND/OR) +type conjunctExpression struct { + isExpression + isBoolExpression + expressions []BoolExpression + conjunction []byte +} + +func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer) (err error) { + if len(conj.expressions) == 0 { + return errors.Newf( + "Empty conjunction. Generated sql: %s", + out.String()) + } + + clauses := make([]Clause, len(conj.expressions), len(conj.expressions)) + for i, expr := range conj.expressions { + clauses[i] = expr + } + + useParentheses := len(clauses) > 1 + if useParentheses { + _ = out.WriteByte('(') + } + + if err = serializeClauses(clauses, conj.conjunction, out); err != nil { + return + } + + if useParentheses { + _ = out.WriteByte(')') + } + + return nil +} + +// Representation of n-ary arithmetic (+ - * /) +type arithmeticExpression struct { + isExpression + expressions []Expression + operator []byte +} + +func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer) (err error) { + if len(arith.expressions) == 0 { + return errors.Newf( + "Empty arithmetic expression. Generated sql: %s", + out.String()) + } + + clauses := make([]Clause, len(arith.expressions), len(arith.expressions)) + for i, expr := range arith.expressions { + clauses[i] = expr + } + + useParentheses := len(clauses) > 1 + if useParentheses { + _ = out.WriteByte('(') + } + + if err = serializeClauses(clauses, arith.operator, out); err != nil { + return + } + + if useParentheses { + _ = out.WriteByte(')') + } + + return nil +} + +type tupleExpression struct { + isExpression + elements listClause +} + +func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer) error { + if len(tuple.elements.clauses) < 1 { + return errors.Newf("Tuples must include at least one element") + } + return tuple.elements.SerializeSql(out) +} + +func Tuple(exprs ...Expression) Expression { + clauses := make([]Clause, 0, len(exprs)) + for _, expr := range exprs { + clauses = append(clauses, expr) + } + return &tupleExpression{ + elements: listClause{ + clauses: clauses, + includeParentheses: true, + }, + } +} + +// Representation of a tuple enclosed, comma separated list of clauses +type listClause struct { + clauses []Clause + includeParentheses bool +} + +func (list *listClause) SerializeSql(out *bytes.Buffer) error { + if list.includeParentheses { + _ = out.WriteByte('(') + } + + if err := serializeClauses(list.clauses, []byte(","), out); err != nil { + return err + } + + if list.includeParentheses { + _ = out.WriteByte(')') + } + return nil +} + +// A not expression which negates a expression value +type negateExpression struct { + isExpression + isBoolExpression + + nested BoolExpression +} + +func (c *negateExpression) SerializeSql(out *bytes.Buffer) (err error) { + _, _ = out.WriteString("NOT (") + + if c.nested == nil { + return errors.Newf("nil nested. Generated sql: %s", out.String()) + } + if err = c.nested.SerializeSql(out); err != nil { + return + } + + _ = out.WriteByte(')') + return nil +} + +// Returns a representation of "not expr" +func Not(expr BoolExpression) BoolExpression { + return &negateExpression{ + nested: expr, + } +} + +// Representation of binary operations (e.g. comparisons, arithmetic) +type binaryExpression struct { + isExpression + lhs, rhs Expression + operator []byte +} + +func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) { + if c.lhs == nil { + return errors.Newf("nil lhs. Generated sql: %s", out.String()) + } + if err = c.lhs.SerializeSql(out); err != nil { + return + } + + _, _ = out.Write(c.operator) + + if c.rhs == nil { + return errors.Newf("nil rhs. Generated sql: %s", out.String()) + } + if err = c.rhs.SerializeSql(out); err != nil { + return + } + + return nil +} + +// A binary expression that evaluates to a boolean value. +type boolExpression struct { + isBoolExpression + binaryExpression +} + +func newBoolExpression(lhs, rhs Expression, operator []byte) *boolExpression { + // go does not allow {} syntax for initializing promoted fields ... + expr := new(boolExpression) + expr.lhs = lhs + expr.rhs = rhs + expr.operator = operator + return expr +} + +type funcExpression struct { + isExpression + funcName string + args *listClause +} + +func (c *funcExpression) SerializeSql(out *bytes.Buffer) (err error) { + if !validIdentifierName(c.funcName) { + return errors.Newf( + "Invalid function name: %s. Generated sql: %s", + c.funcName, + out.String()) + } + _, _ = out.WriteString(c.funcName) + if c.args == nil { + _, _ = out.WriteString("()") + } else { + return c.args.SerializeSql(out) + } + return nil +} + +// Returns a representation of sql function call "func_call(c[0], ..., c[n-1]) +func SqlFunc(funcName string, expressions ...Expression) Expression { + f := &funcExpression{ + funcName: funcName, + } + if len(expressions) > 0 { + args := make([]Clause, len(expressions), len(expressions)) + for i, expr := range expressions { + args[i] = expr + } + + f.args = &listClause{ + clauses: args, + includeParentheses: true, + } + } + return f +} + +type intervalExpression struct { + isExpression + duration time.Duration + negative bool +} + +var intervalSep = ":" + +func (c *intervalExpression) SerializeSql(out *bytes.Buffer) (err error) { + hours := c.duration / time.Hour + minutes := (c.duration % time.Hour) / time.Minute + sec := (c.duration % time.Minute) / time.Second + msec := (c.duration % time.Second) / time.Microsecond + _, _ = out.WriteString("INTERVAL '") + if c.negative { + _, _ = out.WriteString("-") + } + _, _ = out.WriteString(strconv.FormatInt(int64(hours), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(sec), 10)) + _, _ = out.WriteString(intervalSep) + _, _ = out.WriteString(strconv.FormatInt(int64(msec), 10)) + _, _ = out.WriteString("' HOUR_MICROSECOND") + return nil +} + +// Interval returns a representation of duration +// in a form "INTERVAL `hour:min:sec:microsec` HOUR_MICROSECOND" +func Interval(duration time.Duration) Expression { + negative := false + if duration < 0 { + negative = true + duration = -duration + } + return &intervalExpression{ + duration: duration, + negative: negative, + } +} + +var likeEscaper = strings.NewReplacer("_", "\\_", "%", "\\%") + +func EscapeForLike(s string) string { + return likeEscaper.Replace(s) +} + +// Returns an escaped literal string +func Literal(v interface{}) Expression { + value, err := sqltypes.BuildValue(v) + if err != nil { + panic(errors.Wrap(err, "Invalid literal value")) + } + return &literalExpression{value: value} +} + +// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses +func And(expressions ...BoolExpression) BoolExpression { + return &conjunctExpression{ + expressions: expressions, + conjunction: []byte(" AND "), + } +} + +// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses +func Or(expressions ...BoolExpression) BoolExpression { + return &conjunctExpression{ + expressions: expressions, + conjunction: []byte(" OR "), + } +} + +func Like(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(" LIKE ")) +} + +func LikeL(lhs Expression, val string) BoolExpression { + return Like(lhs, Literal(val)) +} + +func Regexp(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(" REGEXP ")) +} + +func RegexpL(lhs Expression, val string) BoolExpression { + return Regexp(lhs, Literal(val)) +} + +// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses +func Add(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" + "), + } +} + +// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses +func Sub(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" - "), + } +} + +// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses +func Mul(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" * "), + } +} + +// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses +func Div(expressions ...Expression) Expression { + return &arithmeticExpression{ + expressions: expressions, + operator: []byte(" / "), + } +} + +// Returns a representation of "a=b" +func Eq(lhs, rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(lhs, rhs, []byte(" IS ")) + } + return newBoolExpression(lhs, rhs, []byte("=")) +} + +// Returns a representation of "a=b", where b is a literal +func EqL(lhs Expression, val interface{}) BoolExpression { + return Eq(lhs, Literal(val)) +} + +// Returns a representation of "a!=b" +func Neq(lhs, rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(lhs, rhs, []byte(" IS NOT ")) + } + return newBoolExpression(lhs, rhs, []byte("!=")) +} + +// Returns a representation of "a!=b", where b is a literal +func NeqL(lhs Expression, val interface{}) BoolExpression { + return Neq(lhs, Literal(val)) +} + +// Returns a representation of "ab" +func Gt(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(">")) +} + +// Returns a representation of "a>b", where b is a literal +func GtL(lhs Expression, val interface{}) BoolExpression { + return Gt(lhs, Literal(val)) +} + +// Returns a representation of "a>=b" +func Gte(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(">=")) +} + +// Returns a representation of "a>=b", where b is a literal +func GteL(lhs Expression, val interface{}) BoolExpression { + return Gte(lhs, Literal(val)) +} + +func BitOr(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" | "), + } +} + +func BitAnd(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" & "), + } +} + +func BitXor(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" ^ "), + } +} + +func Plus(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" + "), + } +} + +func Minus(lhs, rhs Expression) Expression { + return &binaryExpression{ + lhs: lhs, + rhs: rhs, + operator: []byte(" - "), + } +} + +// in expression representation +type inExpression struct { + isExpression + isBoolExpression + + lhs Expression + rhs *listClause + + err error +} + +func (c *inExpression) SerializeSql(out *bytes.Buffer) error { + if c.err != nil { + return errors.Wrap(c.err, "Invalid IN expression") + } + + if c.lhs == nil { + return errors.Newf( + "lhs of in expression is nil. Generated sql: %s", + out.String()) + } + + // We'll serialize the lhs even if we don't need it to ensure no error + buf := &bytes.Buffer{} + + err := c.lhs.SerializeSql(buf) + if err != nil { + return err + } + + if c.rhs == nil { + _, _ = out.WriteString("FALSE") + return nil + } + + _, _ = out.WriteString(buf.String()) + _, _ = out.WriteString(" IN ") + + err = c.rhs.SerializeSql(out) + if err != nil { + return err + } + + return nil +} + +// Returns a representation of "a IN (b[0], ..., b[n-1])", where b is a list +// of literals valList must be a slice type +func In(lhs Expression, valList interface{}) BoolExpression { + var clauses []Clause + switch val := valList.(type) { + // This atrocious body of copy-paste code is due to the fact that if you + // try to merge the cases, you can't treat val as a list + case []int: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []int32: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []int64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint32: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []float64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []string: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case [][]byte: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []time.Time: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Numeric: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Fractional: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.String: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Value: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + default: + return &inExpression{ + err: errors.Newf( + "Unknown value list type in IN clause: %s", + reflect.TypeOf(valList)), + } + } + + expr := &inExpression{lhs: lhs} + if len(clauses) > 0 { + expr.rhs = &listClause{clauses: clauses, includeParentheses: true} + } + return expr +} + +type ifExpression struct { + isExpression + conditional BoolExpression + trueExpression Expression + falseExpression Expression +} + +func (exp *ifExpression) SerializeSql(out *bytes.Buffer) error { + _, _ = out.WriteString("IF(") + _ = exp.conditional.SerializeSql(out) + _, _ = out.WriteString(",") + _ = exp.trueExpression.SerializeSql(out) + _, _ = out.WriteString(",") + _ = exp.falseExpression.SerializeSql(out) + _, _ = out.WriteString(")") + return nil +} + +// Returns a representation of an if-expression, of the form: +// +// IF (BOOLEAN TEST, VALUE-IF-TRUE, VALUE-IF-FALSE) +func If(conditional BoolExpression, + trueExpression Expression, + falseExpression Expression) Expression { + return &ifExpression{ + conditional: conditional, + trueExpression: trueExpression, + falseExpression: falseExpression, + } +} + +type columnValueExpression struct { + isExpression + column NonAliasColumn +} + +func ColumnValue(col NonAliasColumn) Expression { + return &columnValueExpression{ + column: col, + } +} + +func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error { + _, _ = out.WriteString("VALUES(") + _ = cv.column.SerializeSqlForColumnList(out) + _ = out.WriteByte(')') + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c17344a --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module b612.me/mysql/sqlbuilder + +go 1.20 + +require ( + b612.me/mysql/sqltypes v0.0.0-20230701101652-40406d9a2ff2 + github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c46d44f --- /dev/null +++ b/go.sum @@ -0,0 +1,33 @@ +b612.me/mysql/sqltypes v0.0.0-20230701101652-40406d9a2ff2 h1:gWGuBHC7hrmyhp9vX1UOOX5C9WRoFHPDjSvRfmP/nS4= +b612.me/mysql/sqltypes v0.0.0-20230701101652-40406d9a2ff2/go.mod h1:Py9XWC9lc2cDhzfSPO7gqk07qZcPpJfk0aQ0iUZC5CQ= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd h1:s2vYw+2c+7GR1ccOaDuDcKsmNB/4RIxyu5liBm1VRbs= +github.com/dropbox/godropbox v0.0.0-20230623171840-436d2007a9fd/go.mod h1:Vr/Q4p40Kce7JAHDITjDhiy/zk07W4tqD5YVi5FD0PA= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/statement.go b/statement.go new file mode 100644 index 0000000..92990b0 --- /dev/null +++ b/statement.go @@ -0,0 +1,1022 @@ +package sqlbuilder + +import ( + "bytes" + "fmt" + "regexp" + + "github.com/dropbox/godropbox/errors" +) + +type Statement interface { + // String returns generated SQL as string. + String(database string) (sql string, err error) +} + +type SelectStatement interface { + Statement + + Where(expression BoolExpression) SelectStatement + AndWhere(expression BoolExpression) SelectStatement + GroupBy(expressions ...Expression) SelectStatement + OrderBy(clauses ...OrderByClause) SelectStatement + Limit(limit int64) SelectStatement + Distinct() SelectStatement + WithSharedLock() SelectStatement + ForUpdate() SelectStatement + Offset(offset int64) SelectStatement + Comment(comment string) SelectStatement + Copy() SelectStatement +} + +type InsertStatement interface { + Statement + + // Add a row of values to the insert statement. + Add(row ...Expression) InsertStatement + AddOnDuplicateKeyUpdate(col NonAliasColumn, expr Expression) InsertStatement + Comment(comment string) InsertStatement + IgnoreDuplicates(ignore bool) InsertStatement +} + +// By default, rows selected by a UNION statement are out-of-order +// If you have an ORDER BY on an inner SELECT statement, the only thing +// it affects is the LIMIT clause on that inner statement (the ordering will +// still be out-of-order). +type UnionStatement interface { + Statement + + // Warning! You cannot include table names for the next 4 clauses, or + // you'll get errors like: + // Table 'server_file_journal' from one of the SELECTs cannot be used in + // global ORDER clause + Where(expression BoolExpression) UnionStatement + AndWhere(expression BoolExpression) UnionStatement + GroupBy(expressions ...Expression) UnionStatement + OrderBy(clauses ...OrderByClause) UnionStatement + + Limit(limit int64) UnionStatement + Offset(offset int64) UnionStatement +} + +type UpdateStatement interface { + Statement + + Set(column NonAliasColumn, expression Expression) UpdateStatement + Where(expression BoolExpression) UpdateStatement + OrderBy(clauses ...OrderByClause) UpdateStatement + Limit(limit int64) UpdateStatement + Comment(comment string) UpdateStatement +} + +type DeleteStatement interface { + Statement + + Where(expression BoolExpression) DeleteStatement + OrderBy(clauses ...OrderByClause) DeleteStatement + Limit(limit int64) DeleteStatement + Comment(comment string) DeleteStatement +} + +// LockStatement is used to take Read/Write lock on tables. +// See http://dev.mysql.com/doc/refman/5.0/en/lock-tables.html +type LockStatement interface { + Statement + + AddReadLock(table *Table) LockStatement + AddWriteLock(table *Table) LockStatement +} + +// UnlockStatement can be used to release table locks taken using LockStatement. +// NOTE: You can not selectively release a lock and continue to hold lock on +// another table. UnlockStatement releases all the lock held in the current +// session. +type UnlockStatement interface { + Statement +} + +// SetGtidNextStatement returns a SQL statement that can be used to explicitly set the next GTID. +type GtidNextStatement interface { + Statement +} + +// +// UNION SELECT Statement ====================================================== +// + +func Union(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + unique: true, + } +} + +func UnionAll(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + unique: false, + } +} + +// Similar to selectStatementImpl, but less complete +type unionStatementImpl struct { + selects []SelectStatement + where BoolExpression + group *listClause + order *listClause + limit, offset int64 + // True if results of the union should be deduped. + unique bool +} + +func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { + us.where = expression + return us +} + +// Further filter the query, instead of replacing the filter +func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { + if us.where == nil { + return us.Where(expression) + } + us.where = And(us.where, expression) + return us +} + +func (us *unionStatementImpl) GroupBy( + expressions ...Expression) UnionStatement { + + us.group = &listClause{ + clauses: make([]Clause, len(expressions), len(expressions)), + includeParentheses: false, + } + + for i, e := range expressions { + us.group.clauses[i] = e + } + return us +} + +func (us *unionStatementImpl) OrderBy( + clauses ...OrderByClause) UnionStatement { + + us.order = newOrderByListClause(clauses...) + return us +} + +func (us *unionStatementImpl) Limit(limit int64) UnionStatement { + us.limit = limit + return us +} + +func (us *unionStatementImpl) Offset(offset int64) UnionStatement { + us.offset = offset + return us +} + +func (us *unionStatementImpl) String(database string) (sql string, err error) { + if len(us.selects) == 0 { + return "", errors.Newf("Union statement must have at least one SELECT") + } + + if len(us.selects) == 1 { + return us.selects[0].String(database) + } + + // Union statements in MySQL require that the same number of columns in each subquery + var projections []Projection + + for _, statement := range us.selects { + // do a type assertion to get at the underlying struct + statementImpl, ok := statement.(*selectStatementImpl) + if !ok { + return "", errors.Newf( + "Expected inner select statement to be of type " + + "selectStatementImpl") + } + + // check that for limit for statements with order by clauses + if statementImpl.order != nil && statementImpl.limit < 0 { + return "", errors.Newf( + "All inner selects in Union statement must have LIMIT if " + + "they have ORDER BY") + } + + // check number of projections + if projections == nil { + projections = statementImpl.projections + } else { + if len(projections) != len(statementImpl.projections) { + return "", errors.Newf( + "All inner selects in Union statement must select the " + + "same number of columns. For sanity, you probably " + + "want to select the same table columns in the same " + + "order. If you are selecting on multiple tables, " + + "use Null to pad to the right number of fields.") + } + } + } + + buf := new(bytes.Buffer) + for i, statement := range us.selects { + if i != 0 { + if us.unique { + _, _ = buf.WriteString(" UNION ") + } else { + _, _ = buf.WriteString(" UNION ALL ") + } + } + _, _ = buf.WriteString("(") + selectSql, err := statement.String(database) + if err != nil { + return "", err + } + _, _ = buf.WriteString(selectSql) + _, _ = buf.WriteString(")") + } + + if us.where != nil { + _, _ = buf.WriteString(" WHERE ") + if err = us.where.SerializeSql(buf); err != nil { + return + } + } + + if us.group != nil { + _, _ = buf.WriteString(" GROUP BY ") + if err = us.group.SerializeSql(buf); err != nil { + return + } + } + + if us.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = us.order.SerializeSql(buf); err != nil { + return + } + } + + if us.limit >= 0 { + if us.offset >= 0 { + _, _ = buf.WriteString( + fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) + } else { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) + } + } + return buf.String(), nil +} + +// +// SELECT Statement ============================================================ +// + +func newSelectStatement( + table ReadableTable, + projections []Projection) SelectStatement { + + return &selectStatementImpl{ + table: table, + projections: projections, + limit: -1, + offset: -1, + withSharedLock: false, + forUpdate: false, + distinct: false, + } +} + +// NOTE: SelectStatement purposely does not implement the Table interface since +// mysql's subquery performance is horrible. +type selectStatementImpl struct { + table ReadableTable + projections []Projection + where BoolExpression + group *listClause + order *listClause + comment string + limit, offset int64 + withSharedLock bool + forUpdate bool + distinct bool +} + +func (s *selectStatementImpl) Copy() SelectStatement { + ret := *s + return &ret +} + +// Further filter the query, instead of replacing the filter +func (q *selectStatementImpl) AndWhere( + expression BoolExpression) SelectStatement { + + if q.where == nil { + return q.Where(expression) + } + q.where = And(q.where, expression) + return q +} + +func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement { + q.where = expression + return q +} + +func (q *selectStatementImpl) GroupBy( + expressions ...Expression) SelectStatement { + + q.group = &listClause{ + clauses: make([]Clause, len(expressions), len(expressions)), + includeParentheses: false, + } + + for i, e := range expressions { + q.group.clauses[i] = e + } + return q +} + +func (q *selectStatementImpl) OrderBy( + clauses ...OrderByClause) SelectStatement { + + q.order = newOrderByListClause(clauses...) + return q +} + +func (q *selectStatementImpl) Limit(limit int64) SelectStatement { + q.limit = limit + return q +} + +func (q *selectStatementImpl) Distinct() SelectStatement { + q.distinct = true + return q +} + +func (q *selectStatementImpl) WithSharedLock() SelectStatement { + // We don't need to grab a read lock if we're going to grab a write one + if !q.forUpdate { + q.withSharedLock = true + } + return q +} + +func (q *selectStatementImpl) ForUpdate() SelectStatement { + // Clear a request for a shared lock if we're asking for a write one + q.withSharedLock = false + q.forUpdate = true + return q +} + +func (q *selectStatementImpl) Offset(offset int64) SelectStatement { + q.offset = offset + return q +} + +func (q *selectStatementImpl) Comment(comment string) SelectStatement { + q.comment = comment + return q +} + +// Return the properly escaped SQL statement, against the specified database +func (q *selectStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("SELECT ") + + if err = writeComment(q.comment, buf); err != nil { + return + } + + if q.distinct { + _, _ = buf.WriteString("DISTINCT ") + } + + if q.projections == nil || len(q.projections) == 0 { + return "", errors.Newf( + "No column selected. Generated sql: %s", + buf.String()) + } + + for i, col := range q.projections { + if i > 0 { + _ = buf.WriteByte(',') + } + if col == nil { + return "", errors.Newf( + "nil column selected. Generated sql: %s", + buf.String()) + } + if err = col.SerializeSqlForColumnList(buf); err != nil { + return + } + } + + _, _ = buf.WriteString(" FROM ") + if q.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + if err = q.table.SerializeSql(database, buf); err != nil { + return + } + + if q.where != nil { + _, _ = buf.WriteString(" WHERE ") + if err = q.where.SerializeSql(buf); err != nil { + return + } + } + + if q.group != nil { + _, _ = buf.WriteString(" GROUP BY ") + if err = q.group.SerializeSql(buf); err != nil { + return + } + } + + if q.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = q.order.SerializeSql(buf); err != nil { + return + } + } + + if q.limit >= 0 { + if q.offset >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit)) + } else { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit)) + } + } + + if q.forUpdate { + _, _ = buf.WriteString(" FOR UPDATE") + } else if q.withSharedLock { + _, _ = buf.WriteString(" LOCK IN SHARE MODE") + } + + return buf.String(), nil +} + +// +// INSERT Statement ============================================================ +// + +func newInsertStatement( + t WritableTable, + columns ...NonAliasColumn) InsertStatement { + + return &insertStatementImpl{ + table: t, + columns: columns, + rows: make([][]Expression, 0, 1), + onDuplicateKeyUpdates: make([]columnAssignment, 0, 0), + } +} + +type columnAssignment struct { + col NonAliasColumn + expr Expression +} + +type insertStatementImpl struct { + table WritableTable + columns []NonAliasColumn + rows [][]Expression + onDuplicateKeyUpdates []columnAssignment + comment string + ignore bool +} + +func (s *insertStatementImpl) Add( + row ...Expression) InsertStatement { + + s.rows = append(s.rows, row) + return s +} + +func (s *insertStatementImpl) AddOnDuplicateKeyUpdate( + col NonAliasColumn, + expr Expression) InsertStatement { + + s.onDuplicateKeyUpdates = append( + s.onDuplicateKeyUpdates, + columnAssignment{col, expr}) + + return s +} + +func (s *insertStatementImpl) IgnoreDuplicates(ignore bool) InsertStatement { + s.ignore = ignore + return s +} + +func (s *insertStatementImpl) Comment(comment string) InsertStatement { + s.comment = comment + return s +} + +func (s *insertStatementImpl) String(database string) (sql string, err error) { + // Momo modified. if database empty, not validate it + if database != "" && !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("INSERT ") + if s.ignore { + _, _ = buf.WriteString("IGNORE ") + } + _, _ = buf.WriteString("INTO ") + + if err = writeComment(s.comment, buf); err != nil { + return + } + + if s.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = s.table.SerializeSql(database, buf); err != nil { + return + } + + if len(s.columns) == 0 { + return "", errors.Newf( + "No column specified. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" (") + for i, col := range s.columns { + if i > 0 { + _ = buf.WriteByte(',') + } + + if col == nil { + return "", errors.Newf( + "nil column in columns list. Generated sql: %s", + buf.String()) + } + + if err = col.SerializeSqlForColumnList(buf); err != nil { + return + } + } + + if len(s.rows) == 0 { + return "", errors.Newf( + "No row specified. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(") VALUES (") + for row_i, row := range s.rows { + if row_i > 0 { + _, _ = buf.WriteString(", (") + } + + if len(row) != len(s.columns) { + return "", errors.Newf( + "# of values does not match # of columns. Generated sql: %s", + buf.String()) + } + + for col_i, value := range row { + if col_i > 0 { + _ = buf.WriteByte(',') + } + + if value == nil { + return "", errors.Newf( + "nil value in row %d col %d. Generated sql: %s", + row_i, + col_i, + buf.String()) + } + + if err = value.SerializeSql(buf); err != nil { + return + } + } + _ = buf.WriteByte(')') + } + + if len(s.onDuplicateKeyUpdates) > 0 { + _, _ = buf.WriteString(" ON DUPLICATE KEY UPDATE ") + for i, colExpr := range s.onDuplicateKeyUpdates { + if i > 0 { + _, _ = buf.WriteString(", ") + } + + if colExpr.col == nil { + return "", errors.Newf( + ("nil column in on duplicate key update list. " + + "Generated sql: %s"), + buf.String()) + } + + if err = colExpr.col.SerializeSqlForColumnList(buf); err != nil { + return + } + + _ = buf.WriteByte('=') + + if colExpr.expr == nil { + return "", errors.Newf( + ("nil expression in on duplicate key update list. " + + "Generated sql: %s"), + buf.String()) + } + + if err = colExpr.expr.SerializeSql(buf); err != nil { + return + } + } + } + + return buf.String(), nil +} + +// +// UPDATE statement =========================================================== +// + +func newUpdateStatement(table WritableTable) UpdateStatement { + return &updateStatementImpl{ + table: table, + updateValues: make(map[NonAliasColumn]Expression), + limit: -1, + } +} + +type updateStatementImpl struct { + table WritableTable + updateValues map[NonAliasColumn]Expression + where BoolExpression + order *listClause + limit int64 + comment string +} + +func (u *updateStatementImpl) Set( + column NonAliasColumn, + expression Expression) UpdateStatement { + + u.updateValues[column] = expression + return u +} + +func (u *updateStatementImpl) Where(expression BoolExpression) UpdateStatement { + u.where = expression + return u +} + +func (u *updateStatementImpl) OrderBy( + clauses ...OrderByClause) UpdateStatement { + + u.order = newOrderByListClause(clauses...) + return u +} + +func (u *updateStatementImpl) Limit(limit int64) UpdateStatement { + u.limit = limit + return u +} + +func (u *updateStatementImpl) Comment(comment string) UpdateStatement { + u.comment = comment + return u +} + +func (u *updateStatementImpl) String(database string) (sql string, err error) { + // Momo modified. if database empty, not validate it + if database != "" && !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("UPDATE ") + + if err = writeComment(u.comment, buf); err != nil { + return + } + + if u.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = u.table.SerializeSql(database, buf); err != nil { + return + } + + if len(u.updateValues) == 0 { + return "", errors.Newf( + "No column updated. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" SET ") + addComma := false + + // Sorting is too hard in go, just create a second map ... + updateValues := make(map[string]Expression) + for col, expr := range u.updateValues { + if col == nil { + return "", errors.Newf( + "nil column. Generated sql: %s", + buf.String()) + } + + updateValues[col.Name()] = expr + } + + for _, col := range u.table.Columns() { + val, inMap := updateValues[col.Name()] + if !inMap { + continue + } + + if addComma { + _, _ = buf.WriteString(", ") + } + + if val == nil { + return "", errors.Newf( + "nil value. Generated sql: %s", + buf.String()) + } + + if err = col.SerializeSql(buf); err != nil { + return + } + + _ = buf.WriteByte('=') + if err = val.SerializeSql(buf); err != nil { + return + } + + addComma = true + } + + if u.where == nil { + return "", errors.Newf( + "Updating without a WHERE clause. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" WHERE ") + if err = u.where.SerializeSql(buf); err != nil { + return + } + + if u.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = u.order.SerializeSql(buf); err != nil { + return + } + } + + if u.limit >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", u.limit)) + } + + return buf.String(), nil +} + +// +// DELETE statement =========================================================== +// + +func newDeleteStatement(table WritableTable) DeleteStatement { + return &deleteStatementImpl{ + table: table, + limit: -1, + } +} + +type deleteStatementImpl struct { + table WritableTable + where BoolExpression + order *listClause + limit int64 + comment string +} + +func (d *deleteStatementImpl) Where(expression BoolExpression) DeleteStatement { + d.where = expression + return d +} + +func (d *deleteStatementImpl) OrderBy( + clauses ...OrderByClause) DeleteStatement { + + d.order = newOrderByListClause(clauses...) + return d +} + +func (d *deleteStatementImpl) Limit(limit int64) DeleteStatement { + d.limit = limit + return d +} + +func (d *deleteStatementImpl) Comment(comment string) DeleteStatement { + d.comment = comment + return d +} + +func (d *deleteStatementImpl) String(database string) (sql string, err error) { + // Momo modified. if database empty, not validate it + if database != "" && !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("DELETE FROM ") + + if err = writeComment(d.comment, buf); err != nil { + return + } + + if d.table == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = d.table.SerializeSql(database, buf); err != nil { + return + } + + if d.where == nil { + return "", errors.Newf( + "Deleting without a WHERE clause. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" WHERE ") + if err = d.where.SerializeSql(buf); err != nil { + return + } + + if d.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = d.order.SerializeSql(buf); err != nil { + return + } + } + + if d.limit >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", d.limit)) + } + + return buf.String(), nil +} + +// +// LOCK statement =========================================================== +// + +// NewLockStatement returns a SQL representing empty set of locks. You need to use +// AddReadLock/AddWriteLock to add tables that need to be locked. +// NOTE: You need at least one lock in the set for it to be a valid statement. +func NewLockStatement() LockStatement { + return &lockStatementImpl{} +} + +type lockStatementImpl struct { + locks []tableLock +} + +type tableLock struct { + t *Table + w bool +} + +// AddReadLock takes read lock on the table. +func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement { + s.locks = append(s.locks, tableLock{t: t, w: false}) + return s +} + +// AddWriteLock takes write lock on the table. +func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement { + s.locks = append(s.locks, tableLock{t: t, w: true}) + return s +} + +func (s *lockStatementImpl) String(database string) (sql string, err error) { + if !validIdentifierName(database) { + return "", errors.New("Invalid database name specified") + } + + if len(s.locks) == 0 { + return "", errors.New("No locks added") + } + + buf := new(bytes.Buffer) + _, _ = buf.WriteString("LOCK TABLES ") + + for idx, lock := range s.locks { + if lock.t == nil { + return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + } + + if err = lock.t.SerializeSql(database, buf); err != nil { + return + } + + if lock.w { + _, _ = buf.WriteString(" WRITE") + } else { + _, _ = buf.WriteString(" READ") + } + + if idx != len(s.locks)-1 { + _, _ = buf.WriteString(", ") + } + } + + return buf.String(), nil +} + +// NewUnlockStatement returns SQL statement that can be used to release table locks +// grabbed by the current session. +func NewUnlockStatement() UnlockStatement { + return &unlockStatementImpl{} +} + +type unlockStatementImpl struct { +} + +func (s *unlockStatementImpl) String(database string) (sql string, err error) { + return "UNLOCK TABLES", nil +} + +// Set GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID. +func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement { + return >idNextStatementImpl{ + sid: sid, + gno: gno, + } +} + +type gtidNextStatementImpl struct { + sid []byte + gno uint64 +} + +func (s *gtidNextStatementImpl) String(database string) (sql string, err error) { + // This statement sets a session local variable defining what the next transaction ID is. It + // does not interact with other MySQL sessions. It is neither a DDL nor DML statement, so we + // don't have to worry about data corruption. + // Because of the string formatting (hex plus an integer), can't morph into another statement. + // See: https://dev.mysql.com/doc/refman/5.7/en/replication-options-gtids.html + const gtidFormatString = "SET GTID_NEXT=\"%x-%x-%x-%x-%x:%d\"" + + buf := new(bytes.Buffer) + _, _ = buf.WriteString(fmt.Sprintf(gtidFormatString, + s.sid[:4], s.sid[4:6], s.sid[6:8], s.sid[8:10], s.sid[10:], s.gno)) + return buf.String(), nil +} + +// +// Util functions ============================================================= +// + +// Once again, teisenberger is lazy. Here's a quick filter on comments +var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$") + +func isValidComment(comment string) bool { + return validCommentRegexp.MatchString(comment) +} + +func writeComment(comment string, buf *bytes.Buffer) error { + if comment != "" { + _, _ = buf.WriteString("/* ") + if !isValidComment(comment) { + return errors.Newf("Invalid comment: %s", comment) + } + _, _ = buf.WriteString(comment) + _, _ = buf.WriteString(" */") + } + return nil +} + +func newOrderByListClause(clauses ...OrderByClause) *listClause { + ret := &listClause{ + clauses: make([]Clause, len(clauses), len(clauses)), + includeParentheses: false, + } + + for i, c := range clauses { + ret.clauses[i] = c + } + + return ret +} diff --git a/table.go b/table.go new file mode 100644 index 0000000..25c7a39 --- /dev/null +++ b/table.go @@ -0,0 +1,321 @@ +// Modeling of tables. This is where query preparation starts + +package sqlbuilder + +import ( + "bytes" + "fmt" + + "github.com/dropbox/godropbox/errors" +) + +// The sql table read interface. NOTE: NATURAL JOINs, and join "USING" clause +// are not supported. +type ReadableTable interface { + // Returns the list of columns that are in the current table expression. + Columns() []NonAliasColumn + + // Generates the sql string for the current table expression. Note: the + // generated string may not be a valid/executable sql statement. + // The database is the name of the database the table is on + SerializeSql(database string, out *bytes.Buffer) error + + // Generates a select query on the current table. + Select(projections ...Projection) SelectStatement + + // Creates a inner join table expression using onCondition. + InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a left join table expression using onCondition. + LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a right join table expression using onCondition. + RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable +} + +// The sql table write interface. +type WritableTable interface { + // Returns the list of columns that are in the table. + Columns() []NonAliasColumn + + // Generates the sql string for the current table expression. Note: the + // generated string may not be a valid/executable sql statement. + // The database is the name of the database the table is on + SerializeSql(database string, out *bytes.Buffer) error + + Insert(columns ...NonAliasColumn) InsertStatement + Update() UpdateStatement + Delete() DeleteStatement +} + +// Defines a physical table in the database that is both readable and writable. +// This function will panic if name is not valid +func NewTable(name string, columns ...NonAliasColumn) *Table { + if !validIdentifierName(name) { + panic("Invalid table name") + } + + t := &Table{ + name: name, + columns: columns, + columnLookup: make(map[string]NonAliasColumn), + } + for _, c := range columns { + err := c.setTableName(name) + if err != nil { + panic(err) + } + t.columnLookup[c.Name()] = c + } + + if len(columns) == 0 { + panic(fmt.Sprintf("Table %s has no columns", name)) + } + + return t +} + +type Table struct { + name string + columns []NonAliasColumn + columnLookup map[string]NonAliasColumn + // If not empty, the name of the index to force + forcedIndex string +} + +// Returns the specified column, or errors if it doesn't exist in the table +func (t *Table) getColumn(name string) (NonAliasColumn, error) { + if c, ok := t.columnLookup[name]; ok { + return c, nil + } + return nil, errors.Newf("No such column '%s' in table '%s'", name, t.name) +} + +// Returns a pseudo column representation of the column name. Error checking +// is deferred to SerializeSql. +func (t *Table) C(name string) NonAliasColumn { + return &deferredLookupColumn{ + table: t, + colName: name, + } +} + +// Returns all columns for a table as a slice of projections +func (t *Table) Projections() []Projection { + result := make([]Projection, 0) + + for _, col := range t.columns { + result = append(result, col) + } + + return result +} + +// Returns the table's name in the database +func (t *Table) Name() string { + return t.name +} + +// Returns a list of the table's columns +func (t *Table) Columns() []NonAliasColumn { + return t.columns +} + +// Returns a copy of this table, but with the specified index forced. +func (t *Table) ForceIndex(index string) *Table { + newTable := *t + newTable.forcedIndex = index + return &newTable +} + +// Generates the sql string for the current table expression. Note: the +// generated string may not be a valid/executable sql statement. +func (t *Table) SerializeSql(database string, out *bytes.Buffer) error { + //Momo modified. if database empty, not write + if database != "" { + _, _ = out.WriteString("`") + _, _ = out.WriteString(database) + _, _ = out.WriteString("`.") + } + _, _ = out.WriteString("`") + _, _ = out.WriteString(t.Name()) + _, _ = out.WriteString("`") + + if t.forcedIndex != "" { + if !validIdentifierName(t.forcedIndex) { + return errors.Newf("'%s' is not a valid identifier for an index", t.forcedIndex) + } + _, _ = out.WriteString(" FORCE INDEX (`") + _, _ = out.WriteString(t.forcedIndex) + _, _ = out.WriteString("`)") + } + + return nil +} + +// Generates a select query on the current table. +func (t *Table) Select(projections ...Projection) SelectStatement { + return newSelectStatement(t, projections) +} + +// Creates a inner join table expression using onCondition. +func (t *Table) InnerJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return InnerJoinOn(t, table, onCondition) +} + +// Creates a left join table expression using onCondition. +func (t *Table) LeftJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return LeftJoinOn(t, table, onCondition) +} + +// Creates a right join table expression using onCondition. +func (t *Table) RightJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return RightJoinOn(t, table, onCondition) +} + +func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement { + return newInsertStatement(t, columns...) +} + +func (t *Table) Update() UpdateStatement { + return newUpdateStatement(t) +} + +func (t *Table) Delete() DeleteStatement { + return newDeleteStatement(t) +} + +type joinType int + +const ( + INNER_JOIN joinType = iota + LEFT_JOIN + RIGHT_JOIN +) + +// Join expressions are pseudo readable tables. +type joinTable struct { + lhs ReadableTable + rhs ReadableTable + join_type joinType + onCondition BoolExpression +} + +func newJoinTable( + lhs ReadableTable, + rhs ReadableTable, + join_type joinType, + onCondition BoolExpression) ReadableTable { + + return &joinTable{ + lhs: lhs, + rhs: rhs, + join_type: join_type, + onCondition: onCondition, + } +} + +func InnerJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, INNER_JOIN, onCondition) +} + +func LeftJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, LEFT_JOIN, onCondition) +} + +func RightJoinOn( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition) +} + +func (t *joinTable) Columns() []NonAliasColumn { + columns := make([]NonAliasColumn, 0) + columns = append(columns, t.lhs.Columns()...) + columns = append(columns, t.rhs.Columns()...) + + return columns +} + +func (t *joinTable) SerializeSql( + database string, + out *bytes.Buffer) (err error) { + + if t.lhs == nil { + return errors.Newf("nil lhs. Generated sql: %s", out.String()) + } + if t.rhs == nil { + return errors.Newf("nil rhs. Generated sql: %s", out.String()) + } + if t.onCondition == nil { + return errors.Newf("nil onCondition. Generated sql: %s", out.String()) + } + + if err = t.lhs.SerializeSql(database, out); err != nil { + return + } + + switch t.join_type { + case INNER_JOIN: + _, _ = out.WriteString(" JOIN ") + case LEFT_JOIN: + _, _ = out.WriteString(" LEFT JOIN ") + case RIGHT_JOIN: + _, _ = out.WriteString(" RIGHT JOIN ") + } + + if err = t.rhs.SerializeSql(database, out); err != nil { + return + } + + _, _ = out.WriteString(" ON ") + if err = t.onCondition.SerializeSql(out); err != nil { + return + } + + return nil +} + +func (t *joinTable) Select(projections ...Projection) SelectStatement { + return newSelectStatement(t, projections) +} + +func (t *joinTable) InnerJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return InnerJoinOn(t, table, onCondition) +} + +func (t *joinTable) LeftJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return LeftJoinOn(t, table, onCondition) +} + +func (t *joinTable) RightJoinOn( + table ReadableTable, + onCondition BoolExpression) ReadableTable { + + return RightJoinOn(t, table, onCondition) +} diff --git a/test_utils.go b/test_utils.go new file mode 100644 index 0000000..a7a0250 --- /dev/null +++ b/test_utils.go @@ -0,0 +1,26 @@ +package sqlbuilder + +var table1Col1 = IntColumn("col1", Nullable) +var table1Col2 = IntColumn("col2", Nullable) +var table1Col3 = IntColumn("col3", Nullable) +var table1Col4 = DateTimeColumn("col4", Nullable) +var table1 = NewTable( + "table1", + table1Col1, + table1Col2, + table1Col3, + table1Col4) + +var table2Col3 = IntColumn("col3", Nullable) +var table2Col4 = IntColumn("col4", Nullable) +var table2 = NewTable( + "table2", + table2Col3, + table2Col4) + +var table3Col1 = IntColumn("col1", Nullable) +var table3Col2 = IntColumn("col2", Nullable) +var table3 = NewTable( + "table3", + table3Col1, + table3Col2) diff --git a/types.go b/types.go new file mode 100644 index 0000000..c9d05ea --- /dev/null +++ b/types.go @@ -0,0 +1,79 @@ +package sqlbuilder + +import ( + "bytes" +) + +type Clause interface { + SerializeSql(out *bytes.Buffer) error +} + +// A clause that can be used in order by +type OrderByClause interface { + Clause + isOrderByClauseInterface +} + +// An expression +type Expression interface { + Clause + isExpressionInterface +} + +type BoolExpression interface { + Clause + isBoolExpressionInterface +} + +// A clause that is selectable. +type Projection interface { + Clause + isProjectionInterface + SerializeSqlForColumnList(out *bytes.Buffer) error +} + +// +// Boiler plates ... +// + +type isOrderByClauseInterface interface { + isOrderByClauseType() +} + +type isOrderByClause struct { +} + +func (o *isOrderByClause) isOrderByClauseType() { +} + +type isExpressionInterface interface { + isExpressionType() +} + +type isExpression struct { + isOrderByClause // can always use expression in order by. +} + +func (e *isExpression) isExpressionType() { +} + +type isBoolExpressionInterface interface { + isExpressionInterface + isBoolExpressionType() +} + +type isBoolExpression struct { +} + +func (e *isBoolExpression) isBoolExpressionType() { +} + +type isProjectionInterface interface { + isProjectionType() +} + +type isProjection struct { +} + +func (p *isProjection) isProjectionType() { +}