From 8cb2a7212fcc610b5a600550e97cdc211f6161a6 Mon Sep 17 00:00:00 2001 From: Qishuai Liu Date: Wed, 29 Mar 2023 15:12:03 +0900 Subject: [PATCH] add isBool flag to optimize logical expression --- delete_test.go | 2 +- expression.go | 42 +++++++++++++++++++++++++++--------------- expression_test.go | 13 +++++++++---- select.go | 4 ++-- 4 files changed, 39 insertions(+), 22 deletions(-) diff --git a/delete_test.go b/delete_test.go index cca643c..165f68b 100644 --- a/delete_test.go +++ b/delete_test.go @@ -12,7 +12,7 @@ func TestDelete(t *testing.T) { }, } db := newMockDatabase() - if _, err := db.DeleteFrom(Table1).Where(staticExpression("##", 1)).Execute(); err != nil { + if _, err := db.DeleteFrom(Table1).Where(staticExpression("##", 1, false)).Execute(); err != nil { t.Error(err) } assertLastSql(t, "DELETE FROM `table1` WHERE ##") diff --git a/expression.go b/expression.go index 139b2e4..4144075 100644 --- a/expression.go +++ b/expression.go @@ -116,6 +116,7 @@ type expression struct { priority priority isTrue bool isFalse bool + isBool bool } func (e expression) GetTable() Table { @@ -128,10 +129,11 @@ type scope struct { lastJoin *join } -func staticExpression(sql string, priority priority) expression { +func staticExpression(sql string, priority priority, isBool bool) expression { return expression{ sql: sql, priority: priority, + isBool: isBool, } } @@ -139,6 +141,7 @@ func True() BooleanExpression { return expression{ sql: "1", isTrue: true, + isBool: true, } } @@ -146,6 +149,7 @@ func False() BooleanExpression { return expression{ sql: "0", isFalse: true, + isBool: true, } } @@ -401,6 +405,8 @@ func toBooleanExpression(value interface{}) BooleanExpression { return True() case e.isFalse: return False() + case e.isBool: + return e default: return nil } @@ -522,38 +528,44 @@ func (e expression) binaryOperation(operator string, value interface{}, priority }, priority: priority} } -func (e expression) prefixSuffixExpression(prefix string, suffix string, priority priority) expression { +func (e expression) prefixSuffixExpression(prefix string, suffix string, priority priority, isBool bool) expression { if e.sql != "" { return expression{ sql: prefix + e.sql + suffix, priority: priority, + isBool: isBool, } } - return expression{builder: func(scope scope) (string, error) { - exprSql, err := e.GetSQL(scope) - if err != nil { - return "", err - } - return prefix + exprSql + suffix, nil - }, priority: priority} + return expression{ + builder: func(scope scope) (string, error) { + exprSql, err := e.GetSQL(scope) + if err != nil { + return "", err + } + return prefix + exprSql + suffix, nil + }, + priority: priority, + isBool: isBool, + } } func (e expression) IsNull() BooleanExpression { - return e.prefixSuffixExpression("", " IS NULL", 11) + return e.prefixSuffixExpression("", " IS NULL", 11, true) } func (e expression) Not() BooleanExpression { - if e.isTrue { + switch { + case e.isTrue: return False() - } - if e.isFalse { + case e.isFalse: return True() + default: + return e.prefixSuffixExpression("NOT ", "", 13, true) } - return e.prefixSuffixExpression("NOT ", "", 13) } func (e expression) IsNotNull() BooleanExpression { - return e.prefixSuffixExpression("", " IS NOT NULL", 11) + return e.prefixSuffixExpression("", " IS NOT NULL", 11, true) } func expandSliceValue(value reflect.Value) (result []interface{}) { diff --git a/expression_test.go b/expression_test.go index 8e25812..9fb24e5 100644 --- a/expression_test.go +++ b/expression_test.go @@ -162,10 +162,10 @@ func TestMisc(t *testing.T) { assertValue(t, True(), "1") assertValue(t, False(), "0") - assertValue(t, command("COMMAND", staticExpression("", 0)), "COMMAND ") + assertValue(t, command("COMMAND", staticExpression("", 0, false)), "COMMAND ") - assertValue(t, staticExpression("", 1). - prefixSuffixExpression("", "", 1), "") + assertValue(t, staticExpression("", 1, false). + prefixSuffixExpression("", "", 1, false), "") } func TestLogicalExpression(t *testing.T) { @@ -186,7 +186,8 @@ func TestLogicalExpression(t *testing.T) { func TestLogicalOptimizer(t *testing.T) { trueValue := True() falseValue := False() - otherValue := staticExpression("<>", 0) + otherValue := staticExpression("<>", 0, false) + otherBoolValue := staticExpression("<>", 0, true) assertValue(t, trueValue.Or(trueValue), "1") assertValue(t, trueValue.Or(falseValue), "1") @@ -203,4 +204,8 @@ func TestLogicalOptimizer(t *testing.T) { assertValue(t, trueValue.And(otherValue), "1 AND <>") assertValue(t, trueValue.And(123), "1 AND 123") + assertValue(t, falseValue.Or(otherValue), "0 OR <>") + + assertValue(t, trueValue.And(otherBoolValue), "<>") + assertValue(t, falseValue.Or(otherBoolValue), "<>") } diff --git a/select.go b/select.go index 3224d7b..088d445 100644 --- a/select.go +++ b/select.go @@ -345,12 +345,12 @@ func (s selectStatus) Count() (count int, err error) { }}} _, err = s.FetchFirst(&count) } else { - s.base.fields = []Field{staticExpression("COUNT(1)", 0)} + s.base.fields = []Field{staticExpression("COUNT(1)", 0, false)} _, err = s.FetchFirst(&count) } } else { if !s.base.distinct { - s.base.fields = []Field{staticExpression("1", 0)} + s.base.fields = []Field{staticExpression("1", 0, false)} } _, err = s.base.scope.Database.Select(Function("COUNT", 1)). From(s.asDerivedTable("t")).