Skip to content

Commit

Permalink
recursivly trim expression tree
Browse files Browse the repository at this point in the history
  • Loading branch information
zhexuany committed Oct 17, 2018
1 parent 4623f36 commit 7000f42
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 14 deletions.
11 changes: 11 additions & 0 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -410,4 +410,15 @@ id count task operator info
Projection_3 10000.00 root test.t.nb, test.t.nb
└─TableReader_5 10000.00 root data:TableScan_4
└─TableScan_4 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo
explain select 1+ifnull(nb, 0) from t;
id count task operator info
Projection_3 10000.00 root plus(1, test.t.nb)
└─TableReader_5 10000.00 root data:TableScan_4
└─TableScan_4 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo
explain select 1+ifnull(a, 0) from t;
id count task operator info
Projection_3 10000.00 root plus(1, ifnull(test.t.a, 0))
└─TableReader_5 10000.00 root data:TableScan_4
└─TableScan_4 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo
drop table if exists t;
drop table if exists t;
2 changes: 2 additions & 0 deletions cmd/explaintest/t/explain_easy.test
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,6 @@ explain select ifnull(nb, 0) from t;
explain select ifnull(nb, 0), ifnull(nc, 0) from t;
explain select ifnull(a, 0), ifnull(nb, 0) from t;
explain select ifnull(nb, 0), ifnull(nb, 0) from t;
explain select 1+ifnull(nb, 0) from t;
explain select 1+ifnull(a, 0) from t;
drop table if exists t;
6 changes: 6 additions & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ func (b *baseBuiltinFunc) getArgs() []Expression {
return b.args
}

func (b *baseBuiltinFunc) setArgs(args []Expression) {
b.args = args
}

func (b *baseBuiltinFunc) evalInt(row chunk.Row) (int64, bool, error) {
panic("baseBuiltinFunc.evalInt() should never be called.")
}
Expand Down Expand Up @@ -274,6 +278,8 @@ type builtinFunc interface {
evalJSON(row chunk.Row) (val json.BinaryJSON, isNull bool, err error)
// getArgs returns the arguments expressions.
getArgs() []Expression
// setArgs set the arguments expressions to builtFunc.
setArgs(args []Expression)
// equal check if this function equals to another function.
equal(builtinFunc) bool
// getCtx returns this function's context.
Expand Down
5 changes: 5 additions & 0 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ func (sf *ScalarFunction) GetArgs() []Expression {
return sf.Function.getArgs()
}

// SetArgs sets arguments of function.
func (sf *ScalarFunction) SetArgs(args []Expression) {
sf.Function.setArgs(args)
}

// GetCtx gets the context of function.
func (sf *ScalarFunction) GetCtx() sessionctx.Context {
return sf.Function.getCtx()
Expand Down
43 changes: 29 additions & 14 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,21 +576,36 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi
}
}

func eliminateIfNullOnNonNullColumn(p LogicalPlan, expr expression.Expression) expression.Expression {
if scalarExpr, ok := expr.(*expression.ScalarFunction); ok {
if scalarExpr.FuncName.L == ast.Ifnull {
if cExpr, ok := scalarExpr.GetArgs()[0].(*expression.Column); ok {
var colFound *model.ColumnInfo
if ds, ok := p.(*DataSource); ok {
colFound = model.FindColumnInfo(ds.Columns, cExpr.ColName.L)
if mysql.HasNotNullFlag(colFound.Flag) {
return scalarExpr.GetArgs()[0]
}
}
}
func eliminateIfNullOnNotNullCol(p LogicalPlan, expr expression.Expression) expression.Expression {
ds, isDs := p.(*DataSource)
if !isDs {
return expr
}

scalarExpr, isScalarFunc := expr.(*expression.ScalarFunction)
if !isScalarFunc {
return expr
}
exprChildren := scalarExpr.GetArgs()
for i := 0; i < len(exprChildren); i++ {
exprChildren[i] = eliminateIfNullOnNotNullCol(p, exprChildren[i])
}

if scalarExpr.FuncName.L == ast.Ifnull {
colRef, isColRef := exprChildren[0].(*expression.Column)
if !isColRef {
return expr
}

colInfo := model.FindColumnInfo(ds.Columns, colRef.ColName.L)
if !mysql.HasNotNullFlag(colInfo.Flag) {
return expr
}

return colRef
}
return expr
scalarExpr.SetArgs(exprChildren)
return scalarExpr
}

// buildProjection returns a Projection plan and non-aux columns length.
Expand All @@ -602,7 +617,7 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
oldLen := 0
for _, field := range fields {
newExpr, np, err := b.rewrite(field.Expr, p, mapper, true)
newExpr = eliminateIfNullOnNonNullColumn(p, newExpr)
newExpr = eliminateIfNullOnNotNullCol(p, newExpr)
if err != nil {
return nil, 0, errors.Trace(err)
}
Expand Down

0 comments on commit 7000f42

Please sign in to comment.