Skip to content

Commit

Permalink
Optimize sub streams when there is a concat op
Browse files Browse the repository at this point in the history
  • Loading branch information
jhchabran committed Jun 4, 2021
1 parent b057e2e commit 5e1a6d2
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 6 deletions.
15 changes: 15 additions & 0 deletions internal/planner/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ var optimizerRules = []func(s *stream.Stream, tx *database.Transaction) (*stream
func Optimize(s *stream.Stream, tx *database.Transaction) (*stream.Stream, error) {
var err error

if firstNode, ok := s.First().(*stream.ConcatOperator); ok {
// If the first operation is a concat, optimize both streams individually.
s1, err := Optimize(firstNode.S1, tx)
if err != nil {
return nil, err
}
s2, err := Optimize(firstNode.S2, tx)
if err != nil {
return nil, err
}

firstNode.S1, firstNode.S2 = s1, s2
return s, nil
}

for _, rule := range optimizerRules {
s, err = rule(s, tx)
if err != nil {
Expand Down
118 changes: 118 additions & 0 deletions internal/planner/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,3 +811,121 @@ func TestUseIndexBasedOnSelectionNodeRule_Composite(t *testing.T) {
}
})
}

func TestOptimize(t *testing.T) {
t.Run("concat operator operands are optimized", func(t *testing.T) {
t.Run("PrecalculateExprRule", func(t *testing.T) {
_, tx, cleanup := testutil.NewTestTx(t)
defer cleanup()
testutil.MustExec(t, tx, `
CREATE TABLE foo;
CREATE TABLE bar;
`)

got, err := planner.Optimize(
st.New(st.Concat(
st.New(st.SeqScan("foo")).Pipe(st.Filter(parser.MustParseExpr("a = 1 + 2"))),
st.New(st.SeqScan("bar")).Pipe(st.Filter(parser.MustParseExpr("b = 1 + 2"))),
)),
tx)

want := st.New(st.Concat(
st.New(st.SeqScan("foo")).Pipe(st.Filter(parser.MustParseExpr("a = 3"))),
st.New(st.SeqScan("bar")).Pipe(st.Filter(parser.MustParseExpr("b = 3"))),
))

require.NoError(t, err)
require.Equal(t, want.String(), got.String())
})

t.Run("RemoveUnnecessarySelectionNodesRule", func(t *testing.T) {
_, tx, cleanup := testutil.NewTestTx(t)
defer cleanup()
testutil.MustExec(t, tx, `
CREATE TABLE foo;
CREATE TABLE bar;
`)

got, err := planner.Optimize(
st.New(st.Concat(
st.New(st.SeqScan("foo")).Pipe(st.Filter(parser.MustParseExpr("10"))),
st.New(st.Concat(
st.New(st.SeqScan("bar")).Pipe(st.Filter(parser.MustParseExpr("11"))),
st.New(st.SeqScan("bar")).Pipe(st.Filter(parser.MustParseExpr("12"))),
)),
)),
tx)

want := st.New(st.Concat(
st.New(st.SeqScan("foo")),
st.New(st.Concat(
st.New(st.SeqScan("bar")),
st.New(st.SeqScan("bar")),
)),
))

require.NoError(t, err)
require.Equal(t, want.String(), got.String())
})

t.Run("RemoveUnnecessaryDedupNodeRule", func(t *testing.T) {
_, tx, cleanup := testutil.NewTestTx(t)
defer cleanup()
testutil.MustExec(t, tx, `
CREATE TABLE foo(a integer PRIMARY KEY);
CREATE TABLE bar(a integer PRIMARY KEY);
`)

got, err := planner.Optimize(
st.New(st.Concat(
st.New(st.SeqScan("foo")).
Pipe(st.Project(parser.MustParseExpr("a"))).
Pipe(st.Distinct()),
st.New(st.SeqScan("bar")).
Pipe(st.Project(parser.MustParseExpr("a"))).
Pipe(st.Distinct()),
)),
tx)

want := st.New(st.Concat(
st.New(st.SeqScan("foo")).
Pipe(st.Project(parser.MustParseExpr("a"))),
st.New(st.SeqScan("bar")).
Pipe(st.Project(parser.MustParseExpr("a"))),
))

require.NoError(t, err)
require.Equal(t, want.String(), got.String())
})
})

t.Run("UseIndexBasedOnSelectionNodeRule", func(t *testing.T) {
_, tx, cleanup := testutil.NewTestTx(t)
defer cleanup()
testutil.MustExec(t, tx, `
CREATE TABLE foo;
CREATE TABLE bar;
CREATE INDEX idx_foo_a_d ON foo(a, d);
CREATE INDEX idx_bar_a_d ON bar(a, d);
`)

got, err := planner.Optimize(
st.New(st.Concat(
st.New(st.SeqScan("foo")).
Pipe(st.Filter(parser.MustParseExpr("a = 1"))).
Pipe(st.Filter(parser.MustParseExpr("d = 2"))),
st.New(st.SeqScan("bar")).
Pipe(st.Filter(parser.MustParseExpr("a = 1"))).
Pipe(st.Filter(parser.MustParseExpr("d = 2"))),
)),
tx)

want := st.New(st.Concat(
st.New(st.IndexScan("idx_foo_a_d", st.IndexRange{Min: testutil.ExprList(t, `[1, 2]`), Exact: true})),
st.New(st.IndexScan("idx_bar_a_d", st.IndexRange{Min: testutil.ExprList(t, `[1, 2]`), Exact: true})),
))

require.NoError(t, err)
require.Equal(t, want.String(), got.String())
})
}
12 changes: 6 additions & 6 deletions internal/stream/concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,30 @@ import (
// A ConcatOperator concatenates two streams.
type ConcatOperator struct {
baseOperator
s1 *Stream
s2 *Stream
S1 *Stream
S2 *Stream
}

// Concat turns two individual streams into one.
func Concat(s1 *Stream, s2 *Stream) *ConcatOperator {
return &ConcatOperator{s1: s1, s2: s2}
return &ConcatOperator{S1: s1, S2: s2}
}

func (op *ConcatOperator) Iterate(in *expr.Environment, fn func(*expr.Environment) error) error {
err := op.s1.Iterate(in, func(out *expr.Environment) error {
err := op.S1.Iterate(in, func(out *expr.Environment) error {
fn(out)
return nil
})
if err != nil {
return err
}

return op.s2.Iterate(in, func(out *expr.Environment) error {
return op.S2.Iterate(in, func(out *expr.Environment) error {
fn(out)
return nil
})
}

func (op *ConcatOperator) String() string {
return stringutil.Sprintf("concat(%s, %s)", op.s1, op.s2)
return stringutil.Sprintf("concat(%s, %s)", op.S1, op.S2)
}

0 comments on commit 5e1a6d2

Please sign in to comment.