Skip to content

Commit

Permalink
add isBool flag to optimize logical expression
Browse files Browse the repository at this point in the history
  • Loading branch information
lqs committed Mar 29, 2023
1 parent 85f09c5 commit 8cb2a72
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 22 deletions.
2 changes: 1 addition & 1 deletion delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func TestDelete(t *testing.T) {
},
}
db := newMockDatabase()
if _, err := db.DeleteFrom(Table1).Where(staticExpression("##", 1)).Execute(); err != nil {
if _, err := db.DeleteFrom(Table1).Where(staticExpression("##", 1, false)).Execute(); err != nil {
t.Error(err)
}
assertLastSql(t, "DELETE FROM `table1` WHERE ##")
Expand Down
42 changes: 27 additions & 15 deletions expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ type expression struct {
priority priority
isTrue bool
isFalse bool
isBool bool
}

func (e expression) GetTable() Table {
Expand All @@ -128,24 +129,27 @@ type scope struct {
lastJoin *join
}

func staticExpression(sql string, priority priority) expression {
func staticExpression(sql string, priority priority, isBool bool) expression {
return expression{
sql: sql,
priority: priority,
isBool: isBool,
}
}

func True() BooleanExpression {
return expression{
sql: "1",
isTrue: true,
isBool: true,
}
}

func False() BooleanExpression {
return expression{
sql: "0",
isFalse: true,
isBool: true,
}
}

Expand Down Expand Up @@ -401,6 +405,8 @@ func toBooleanExpression(value interface{}) BooleanExpression {
return True()
case e.isFalse:
return False()
case e.isBool:
return e
default:
return nil
}
Expand Down Expand Up @@ -522,38 +528,44 @@ func (e expression) binaryOperation(operator string, value interface{}, priority
}, priority: priority}
}

func (e expression) prefixSuffixExpression(prefix string, suffix string, priority priority) expression {
func (e expression) prefixSuffixExpression(prefix string, suffix string, priority priority, isBool bool) expression {
if e.sql != "" {
return expression{
sql: prefix + e.sql + suffix,
priority: priority,
isBool: isBool,
}
}
return expression{builder: func(scope scope) (string, error) {
exprSql, err := e.GetSQL(scope)
if err != nil {
return "", err
}
return prefix + exprSql + suffix, nil
}, priority: priority}
return expression{
builder: func(scope scope) (string, error) {
exprSql, err := e.GetSQL(scope)
if err != nil {
return "", err
}
return prefix + exprSql + suffix, nil
},
priority: priority,
isBool: isBool,
}
}

func (e expression) IsNull() BooleanExpression {
return e.prefixSuffixExpression("", " IS NULL", 11)
return e.prefixSuffixExpression("", " IS NULL", 11, true)
}

func (e expression) Not() BooleanExpression {
if e.isTrue {
switch {
case e.isTrue:
return False()
}
if e.isFalse {
case e.isFalse:
return True()
default:
return e.prefixSuffixExpression("NOT ", "", 13, true)
}
return e.prefixSuffixExpression("NOT ", "", 13)
}

func (e expression) IsNotNull() BooleanExpression {
return e.prefixSuffixExpression("", " IS NOT NULL", 11)
return e.prefixSuffixExpression("", " IS NOT NULL", 11, true)
}

func expandSliceValue(value reflect.Value) (result []interface{}) {
Expand Down
13 changes: 9 additions & 4 deletions expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ func TestMisc(t *testing.T) {
assertValue(t, True(), "1")
assertValue(t, False(), "0")

assertValue(t, command("COMMAND", staticExpression("<arg>", 0)), "COMMAND <arg>")
assertValue(t, command("COMMAND", staticExpression("<arg>", 0, false)), "COMMAND <arg>")

assertValue(t, staticExpression("<expression>", 1).
prefixSuffixExpression("<prefix>", "<suffix>", 1), "<prefix><expression><suffix>")
assertValue(t, staticExpression("<expression>", 1, false).
prefixSuffixExpression("<prefix>", "<suffix>", 1, false), "<prefix><expression><suffix>")
}

func TestLogicalExpression(t *testing.T) {
Expand All @@ -186,7 +186,8 @@ func TestLogicalExpression(t *testing.T) {
func TestLogicalOptimizer(t *testing.T) {
trueValue := True()
falseValue := False()
otherValue := staticExpression("<>", 0)
otherValue := staticExpression("<>", 0, false)
otherBoolValue := staticExpression("<>", 0, true)

assertValue(t, trueValue.Or(trueValue), "1")
assertValue(t, trueValue.Or(falseValue), "1")
Expand All @@ -203,4 +204,8 @@ func TestLogicalOptimizer(t *testing.T) {

assertValue(t, trueValue.And(otherValue), "1 AND <>")
assertValue(t, trueValue.And(123), "1 AND 123")
assertValue(t, falseValue.Or(otherValue), "0 OR <>")

assertValue(t, trueValue.And(otherBoolValue), "<>")
assertValue(t, falseValue.Or(otherBoolValue), "<>")
}
4 changes: 2 additions & 2 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,12 @@ func (s selectStatus) Count() (count int, err error) {
}}}
_, err = s.FetchFirst(&count)
} else {
s.base.fields = []Field{staticExpression("COUNT(1)", 0)}
s.base.fields = []Field{staticExpression("COUNT(1)", 0, false)}
_, err = s.FetchFirst(&count)
}
} else {
if !s.base.distinct {
s.base.fields = []Field{staticExpression("1", 0)}
s.base.fields = []Field{staticExpression("1", 0, false)}
}
_, err = s.base.scope.Database.Select(Function("COUNT", 1)).
From(s.asDerivedTable("t")).
Expand Down

0 comments on commit 8cb2a72

Please sign in to comment.