diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index e3aa29aca4668..023800b83df5f 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -320,6 +320,50 @@ Projection_11 10000.00 root 9_aux_0 │ └─TableScan_27 10.00 cop table:t, keep order:false, stats:pseudo └─TableReader_33 1.00 root data:TableScan_32 └─TableScan_32 1.00 cop table:t1, range: decided by [s.c], keep order:false, stats:pseudo +insert into t values(1, 1, 1), (2, 2 ,2), (3, 3, 3), (4, 3, 4),(5,3,5); +analyze table t; +explain select t.c in (select count(*) from t s, t t1 where s.b = t.a and s.b = 3 and s.a = t1.a) from t; +id count task operator info +Projection_11 5.00 root 9_aux_0 +└─Apply_13 5.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) + ├─TableReader_15 5.00 root data:TableScan_14 + │ └─TableScan_14 5.00 cop table:t, range:[-inf,+inf], keep order:false + └─StreamAgg_20 1.00 root funcs:count(1) + └─IndexJoin_49 2.40 root inner join, inner:TableReader_48, outer key:s.a, inner key:t1.a + ├─IndexReader_41 2.40 root index:Selection_40 + │ └─Selection_40 2.40 cop eq(3, test.t.a) + │ └─IndexScan_39 3.00 cop table:s, index:b, range:[3,3], keep order:false + └─TableReader_48 0.80 root data:Selection_47 + └─Selection_47 0.80 cop eq(3, test.t.a) + └─TableScan_46 1.00 cop table:t1, range: decided by [s.a], keep order:false +explain select t.c in (select count(*) from t s left join t t1 on s.a = t1.a where 3 = t.a and s.b = 3) from t; +id count task operator info +Projection_10 5.00 root 9_aux_0 +└─Apply_12 5.00 root left outer semi join, inner:StreamAgg_19, other cond:eq(test.t.c, 7_col_0) + ├─TableReader_14 5.00 root data:TableScan_13 + │ └─TableScan_13 5.00 cop table:t, range:[-inf,+inf], keep order:false + └─StreamAgg_19 1.00 root funcs:count(1) + └─IndexJoin_43 2.40 root left outer join, inner:TableReader_42, outer key:s.a, inner key:t1.a + ├─IndexReader_35 2.40 root index:Selection_34 + │ └─Selection_34 2.40 cop eq(3, test.t.a) + │ └─IndexScan_33 3.00 cop table:s, index:b, range:[3,3], keep order:false + └─TableReader_42 0.80 root data:Selection_41 + └─Selection_41 0.80 cop eq(3, test.t.a) + └─TableScan_40 1.00 cop table:t1, range: decided by [s.a], keep order:false +explain select t.c in (select count(*) from t s right join t t1 on s.a = t1.a where 3 = t.a and t1.b = 3) from t; +id count task operator info +Projection_10 5.00 root 9_aux_0 +└─Apply_12 5.00 root left outer semi join, inner:StreamAgg_19, other cond:eq(test.t.c, 7_col_0) + ├─TableReader_14 5.00 root data:TableScan_13 + │ └─TableScan_13 5.00 cop table:t, range:[-inf,+inf], keep order:false + └─StreamAgg_19 1.00 root funcs:count(1) + └─IndexJoin_43 2.40 root right outer join, inner:TableReader_42, outer key:t1.a, inner key:s.a + ├─TableReader_42 0.80 root data:Selection_41 + │ └─Selection_41 0.80 cop eq(3, test.t.a) + │ └─TableScan_40 1.00 cop table:s, range: decided by [t1.a], keep order:false + └─IndexReader_35 2.40 root index:Selection_34 + └─Selection_34 2.40 cop eq(3, test.t.a) + └─IndexScan_33 3.00 cop table:t1, index:b, range:[3,3], keep order:false drop table if exists t; create table t(a int unsigned); explain select t.a = '123455' from t; diff --git a/cmd/explaintest/t/explain_easy.test b/cmd/explaintest/t/explain_easy.test index 36ae28fb79e58..8c2b399be3115 100644 --- a/cmd/explaintest/t/explain_easy.test +++ b/cmd/explaintest/t/explain_easy.test @@ -52,6 +52,12 @@ explain select t.c in (select count(*) from t s ignore index(idx), t t1 where s. explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t; explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.c = t1.a) from t; +insert into t values(1, 1, 1), (2, 2 ,2), (3, 3, 3), (4, 3, 4),(5,3,5); +analyze table t; +explain select t.c in (select count(*) from t s, t t1 where s.b = t.a and s.b = 3 and s.a = t1.a) from t; +explain select t.c in (select count(*) from t s left join t t1 on s.a = t1.a where 3 = t.a and s.b = 3) from t; +explain select t.c in (select count(*) from t s right join t t1 on s.a = t1.a where 3 = t.a and t1.b = 3) from t; + drop table if exists t; create table t(a int unsigned); explain select t.a = '123455' from t; diff --git a/expression/constant_propagation_test.go b/expression/constant_propagation_test.go index 2598016c54e2f..d457e85d9eb17 100644 --- a/expression/constant_propagation_test.go +++ b/expression/constant_propagation_test.go @@ -208,7 +208,13 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableDual_9 8000.00 root rows:0", )) - tk.MustQuery("explain select * from t1 left join t2 on t1.a =1 and t1.a = 2;").Check(testkit.Rows( + tk.MustQuery("explain select * from t1 right join t2 on false;").Check(testkit.Rows( + "HashRightJoin_6 80000000.00 root right outer join, inner:TableDual_7", + "├─TableDual_7 8000.00 root rows:0", + "└─TableReader_9 10000.00 root data:TableScan_8", + " └─TableScan_8 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain select * from t1 left join t2 on t1.a = 1 and t1.a = 2;").Check(testkit.Rows( "HashLeftJoin_6 80000000.00 root left outer join, inner:TableDual_9", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 719d0dbed892f..3d156b4b830c1 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -180,6 +180,35 @@ func (b *PlanBuilder) buildResultSetNode(node ast.ResultSetNode) (p LogicalPlan, } } +// pushDownConstExpr checks if the condition is from filter condition, if true, push it down to both +// children of join, whatever the join type is; if false, push it down to inner child of outer join, +// and both children of non-outer-join. +func (p *LogicalJoin) pushDownConstExpr(expr expression.Expression, leftCond []expression.Expression, + rightCond []expression.Expression, filterCond bool) ([]expression.Expression, []expression.Expression) { + switch p.JoinType { + case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + if filterCond { + leftCond = append(leftCond, expr) + // Append the expr to right join condition instead of `rightCond`, to make it able to be + // pushed down to children of join. + p.RightConditions = append(p.RightConditions, expr) + } else { + rightCond = append(rightCond, expr) + } + case RightOuterJoin: + if filterCond { + rightCond = append(rightCond, expr) + p.LeftConditions = append(p.LeftConditions, expr) + } else { + leftCond = append(leftCond, expr) + } + case SemiJoin, AntiSemiJoin, InnerJoin: + leftCond = append(leftCond, expr) + rightCond = append(rightCond, expr) + } + return leftCond, rightCond +} + // extractOnCondition divide conditions in CNF of join node into 4 groups. // These conditions can be where conditions, join conditions, or collection of both. // If deriveLeft/deriveRight is set, we would try to derive more conditions for left/right plan. @@ -233,6 +262,12 @@ func (p *LogicalJoin) extractOnCondition(conditions []expression.Expression, der } } columns := expression.ExtractColumns(expr) + // `columns` may be empty, if the condition is like `correlated_column op constant`, or `constant`, + // push this kind of constant condition down according to join type. + if len(columns) == 0 { + leftCond, rightCond = p.pushDownConstExpr(expr, leftCond, rightCond, deriveLeft || deriveRight) + continue + } allFromLeft, allFromRight := true, true for _, col := range columns { if !left.Schema().Contains(col) { diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 5b1c29a18e534..b4d1702228faf 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -906,7 +906,7 @@ func (s *testPlanSuite) TestJoinReOrder(c *C) { }, { sql: "select * from t o where o.b in (select t3.c from t t1, t t2, t t3 where t1.a = t3.a and t2.a = t3.a and t2.a = o.a and t1.a = 1)", - best: "Apply{DataScan(o)->Join{Join{DataScan(t3)->DataScan(t1)}->DataScan(t2)}->Projection}->Projection", + best: "Apply{DataScan(o)->Join{Join{DataScan(t1)->DataScan(t2)}->DataScan(t3)}->Projection}->Projection", }, } for _, tt := range tests {