Skip to content

Commit

Permalink
bugfix: backport of #14974
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Oct 15, 2024
1 parent a0b1861 commit f3ffc30
Show file tree
Hide file tree
Showing 21 changed files with 401 additions and 315 deletions.
13 changes: 13 additions & 0 deletions go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,16 @@ func TestSubqueryInAggregation(t *testing.T) {
// This fails as the planner adds `weight_string` method which make the query fail on MySQL.
// mcmp.Exec(`SELECT max((select min(id2) from t1 where t1.id1 = t.id1)) FROM t1 t`)
}

// TestSubqueryInDerivedTable tests that subqueries and derived tables
// are handled correctly when there are joins inside the derived table
func TestSubqueryInDerivedTable(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 18, "vtgate")
mcmp, closer := start(t)
defer closer()

mcmp.Exec("INSERT INTO t1 (id1, id2) VALUES (1, 100), (2, 200), (3, 300), (4, 400), (5, 500);")
mcmp.Exec("INSERT INTO t2 (id3, id4) VALUES (10, 1), (20, 2), (30, 3), (40, 4), (50, 99)")
mcmp.Exec(`select t.a from (select t1.id2, t2.id3, (select id2 from t1 order by id2 limit 1) as a from t1 join t2 on t1.id1 = t2.id4) t`)
mcmp.Exec(`SELECT COUNT(*) FROM (SELECT DISTINCT t1.id1 FROM t1 JOIN t2 ON t1.id1 = t2.id4) dt`)
}
17 changes: 13 additions & 4 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"slices"
"sort"

"vitess.io/vitess/go/slice"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops"
Expand Down Expand Up @@ -562,25 +564,32 @@ func buildProjection(op *Projection, qb *queryBuilder) error {
}

func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) error {
predicates := slice.Map(op.JoinPredicates, func(jc JoinColumn) sqlparser.Expr {
// since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done
qb.ctx.SkipPredicates[jc.RHSExpr] = nil

return jc.Original.Expr
})
pred := sqlparser.AndExpressions(predicates...)
err := buildQuery(op.LHS, qb)
if err != nil {
return err
}
// If we are going to add the predicate used in join here
// We should not add the predicate's copy of when it was split into
// two parts. To avoid this, we use the SkipPredicates map.
for _, expr := range qb.ctx.JoinPredicates[op.Predicate] {
qb.ctx.SkipPredicates[expr] = nil
for _, pred := range op.JoinPredicates {
qb.ctx.SkipPredicates[pred.RHSExpr] = nil
}
qbR := &queryBuilder{ctx: qb.ctx}
err = buildQuery(op.RHS, qbR)
if err != nil {
return err
}
if op.LeftJoin {
qb.joinOuterWith(qbR, op.Predicate)
qb.joinOuterWith(qbR, pred)
} else {
qb.joinInnerWith(qbR, op.Predicate)
qb.joinInnerWith(qbR, pred)
}
return nil
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ func (a *Aggregator) isDerived() bool {
return a.DT != nil
}

func (a *Aggregator) derivedName() string {
if a.DT == nil {
return ""
}

return a.DT.Alias
}

func (a *Aggregator) FindCol(ctx *plancontext.PlanningContext, in sqlparser.Expr, underRoute bool) (int, error) {
if underRoute && a.isDerived() {
// We don't want to use columns on this operator if it's a derived table under a route.
Expand Down
50 changes: 32 additions & 18 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ type (
// LeftJoin will be true in the case of an outer join
LeftJoin bool

// Before offset planning
Predicate sqlparser.Expr

// JoinColumns keeps track of what AST expression is represented in the Columns array
JoinColumns []JoinColumn

Expand Down Expand Up @@ -86,14 +83,19 @@ type (
}
)

func NewApplyJoin(lhs, rhs ops.Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin {
return &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
Predicate: predicate,
LeftJoin: leftOuterJoin,
func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, predicate sqlparser.Expr, leftOuterJoin bool) (*ApplyJoin, error) {
aj := &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
LeftJoin: leftOuterJoin,
}
err := aj.AddJoinPredicate(ctx, predicate)
if err != nil {
return nil, err
}

return aj, nil
}

// Clone implements the Operator interface
Expand All @@ -105,7 +107,6 @@ func (aj *ApplyJoin) Clone(inputs []ops.Operator) ops.Operator {
kopy.JoinColumns = slices.Clone(aj.JoinColumns)
kopy.JoinPredicates = slices.Clone(aj.JoinPredicates)
kopy.Vars = maps.Clone(aj.Vars)
kopy.Predicate = sqlparser.CloneExpr(aj.Predicate)
kopy.ExtraLHSVars = slices.Clone(aj.ExtraLHSVars)
return &kopy
}
Expand Down Expand Up @@ -149,8 +150,9 @@ func (aj *ApplyJoin) IsInner() bool {
}

func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) error {
aj.Predicate = ctx.SemTable.AndExpressions(expr, aj.Predicate)

if expr == nil {
return nil
}
col, err := BreakExpressionInLHSandRHS(ctx, expr, TableID(aj.LHS))
if err != nil {
return err
Expand Down Expand Up @@ -312,11 +314,15 @@ func (aj *ApplyJoin) addOffset(offset int) {
}

func (aj *ApplyJoin) ShortDescription() string {
pred := sqlparser.String(aj.Predicate)
columns := slice.Map(aj.JoinColumns, func(from JoinColumn) string {
return sqlparser.String(from.Original)
})
firstPart := fmt.Sprintf("on %s columns: %s", pred, strings.Join(columns, ", "))
fn := func(cols []JoinColumn) string {
out := slice.Map(cols, func(jc JoinColumn) string {
return jc.String()
})
return strings.Join(out, ", ")
}

firstPart := fmt.Sprintf("on %s columns: %s", fn(aj.JoinPredicates), fn(aj.JoinColumns))

if len(aj.ExtraLHSVars) == 0 {
return firstPart
}
Expand Down Expand Up @@ -419,6 +425,14 @@ func (jc JoinColumn) IsMixedLeftAndRight() bool {
return len(jc.LHSExprs) > 0 && jc.RHSExpr != nil
}

func (jc JoinColumn) String() string {
rhs := sqlparser.String(jc.RHSExpr)
lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string {
return sqlparser.String(e.Expr)
})
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
}

func (bve BindVarExpr) String() string {
if bve.Name == "" {
return sqlparser.String(bve.Expr)
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func BreakExpressionInLHSandRHS(
expr sqlparser.Expr,
lhs semantics.TableSet,
) (col JoinColumn, err error) {
col.Original = aeWrap(expr)
rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
nodeExpr, ok := cursor.Node().(sqlparser.Expr)
if !ok || !fetchByOffset(nodeExpr) {
Expand Down
10 changes: 7 additions & 3 deletions go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,17 @@ func mergeShardedRouting(r1 *ShardedRouting, r2 *ShardedRouting) *ShardedRouting
return tr
}

func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) *ApplyJoin {
return NewApplyJoin(op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin)
func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) (*ApplyJoin, error) {
return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin)
}

func (jm *joinMerger) merge(ctx *plancontext.PlanningContext, op1, op2 *Route, r Routing) (*Route, error) {
join, err := jm.getApplyJoin(ctx, op1, op2)
if err != nil {
return nil, err
}
return &Route{
Source: jm.getApplyJoin(ctx, op1, op2),
Source: join,
MergedWith: []*Route{op2},
Routing: r,
}, nil
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/planbuilder/operators/offset_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
// planOffsets will walk the tree top down, adding offset information to columns in the tree for use in further optimization,
func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) {
type offsettable interface {
ops.Operator
planOffsets(ctx *plancontext.PlanningContext) error
}

Expand All @@ -40,6 +41,11 @@ func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Opera
return nil, nil, vterrors.VT13001(fmt.Sprintf("should not see %T here", in))
case offsettable:
err = op.planOffsets(ctx)
if rewrite.DebugOperatorTree {
fmt.Println("Planned offsets for:")
fmt.Println(ops.ToTree(op))
}

}
if err != nil {
return nil, nil, err
Expand Down
103 changes: 0 additions & 103 deletions go/vt/vtgate/planbuilder/operators/operator_funcs.go

This file was deleted.

12 changes: 11 additions & 1 deletion go/vt/vtgate/planbuilder/operators/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ func (sp StarProjections) GetSelectExprs() sqlparser.SelectExprs {

func (ap AliasedProjections) GetColumns() ([]*sqlparser.AliasedExpr, error) {
return slice.Map(ap, func(from *ProjExpr) *sqlparser.AliasedExpr {
return aeWrap(from.ColExpr)
return &sqlparser.AliasedExpr{
As: from.Original.As,
Expr: from.ColExpr,
}
}), nil
}

Expand Down Expand Up @@ -247,6 +250,13 @@ func (p *Projection) isDerived() bool {
return p.DT != nil
}

func (p *Projection) derivedName() string {
if p.DT == nil {
return ""
}
return p.DT.Alias
}

func (p *Projection) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) (int, error) {
ap, err := p.GetAliasedProjections()
if err != nil {
Expand Down
Loading

0 comments on commit f3ffc30

Please sign in to comment.