diff --git a/executor/window.go b/executor/window.go index b1abd402dc799..12cfb1797ee81 100644 --- a/executor/window.go +++ b/executor/window.go @@ -30,137 +30,134 @@ import ( type WindowExec struct { baseExecutor - groupChecker *groupChecker - inputIter *chunk.Iterator4Chunk - inputRow chunk.Row - groupRows []chunk.Row - childResults []*chunk.Chunk - executed bool - meetNewGroup bool - remainingRowsInGroup int - remainingRowsInChunk int - numWindowFuncs int - processor windowProcessor + groupChecker *groupChecker + // inputIter is the iterator of child chunks + inputIter *chunk.Iterator4Chunk + // executed indicates the child executor is drained or something unexpected happened. + executed bool + // resultChunks stores the chunks to return + resultChunks []*chunk.Chunk + // remainingRowsInChunk indicates how many rows the resultChunks[i] is not prepared. + remainingRowsInChunk []int + + numWindowFuncs int + processor windowProcessor } // Close implements the Executor Close interface. func (e *WindowExec) Close() error { - e.childResults = nil return errors.Trace(e.baseExecutor.Close()) } // Next implements the Executor Next interface. func (e *WindowExec) Next(ctx context.Context, chk *chunk.Chunk) error { chk.Reset() - if (e.executed || e.meetNewGroup) && e.remainingRowsInGroup > 0 { - err := e.appendResult2Chunk(chk) + for !e.executed && !e.preparedChunkAvailable() { + err := e.consumeOneGroup(ctx) if err != nil { + e.executed = true return err } } - for !e.executed && (chk.NumRows() == 0 || e.remainingRowsInChunk > 0) { - err := e.consumeOneGroup(ctx, chk) - if err != nil { - e.executed = true - return errors.Trace(err) - } + if len(e.resultChunks) > 0 { + chk.SwapColumns(e.resultChunks[0]) + e.resultChunks[0] = nil // GC it. TODO: Reuse it. + e.resultChunks = e.resultChunks[1:] + e.remainingRowsInChunk = e.remainingRowsInChunk[1:] } return nil } -func (e *WindowExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) error { - var err error - if err = e.fetchChildIfNecessary(ctx, chk); err != nil { - return errors.Trace(err) - } - for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { - e.meetNewGroup, err = e.groupChecker.meetNewGroup(e.inputRow) +func (e *WindowExec) preparedChunkAvailable() bool { + return len(e.resultChunks) > 0 && e.remainingRowsInChunk[0] == 0 +} + +func (e *WindowExec) consumeOneGroup(ctx context.Context) error { + var groupRows []chunk.Row + for { + eof, err := e.fetchChildIfNecessary(ctx) if err != nil { return errors.Trace(err) } - if e.meetNewGroup && e.remainingRowsInGroup > 0 { - err := e.consumeGroupRows() + if eof { + e.executed = true + return e.consumeGroupRows(groupRows) + } + for inputRow := e.inputIter.Current(); inputRow != e.inputIter.End(); inputRow = e.inputIter.Next() { + meetNewGroup, err := e.groupChecker.meetNewGroup(inputRow) if err != nil { return errors.Trace(err) } - err = e.appendResult2Chunk(chk) - return err + if meetNewGroup { + return e.consumeGroupRows(groupRows) + } + groupRows = append(groupRows, inputRow) } - e.remainingRowsInGroup++ - e.groupRows = append(e.groupRows, e.inputRow) } - return nil } -func (e *WindowExec) consumeGroupRows() (err error) { - if len(e.groupRows) == 0 { +func (e *WindowExec) consumeGroupRows(groupRows []chunk.Row) (err error) { + remainingRowsInGroup := len(groupRows) + if remainingRowsInGroup == 0 { return nil } - e.groupRows, err = e.processor.consumeGroupRows(e.ctx, e.groupRows) - if err != nil { - return errors.Trace(err) + for i := 0; i < len(e.resultChunks); i++ { + remained := mathutil.Min(e.remainingRowsInChunk[i], remainingRowsInGroup) + e.remainingRowsInChunk[i] -= remained + remainingRowsInGroup -= remained + + // TODO: Combine these three methods. + // The old implementation needs the processor has these three methods + // but now it does not have to. + groupRows, err = e.processor.consumeGroupRows(e.ctx, groupRows) + if err != nil { + return errors.Trace(err) + } + _, err = e.processor.appendResult2Chunk(e.ctx, groupRows, e.resultChunks[i], remained) + if err != nil { + return errors.Trace(err) + } + if remainingRowsInGroup == 0 { + e.processor.resetPartialResult() + break + } } return nil } -func (e *WindowExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk) (err error) { - if e.inputIter != nil && e.inputRow != e.inputIter.End() { - return nil - } - - // Before fetching a new batch of input, we should consume the last group rows. - err = e.consumeGroupRows() - if err != nil { - return errors.Trace(err) +func (e *WindowExec) fetchChildIfNecessary(ctx context.Context) (EOF bool, err error) { + if e.inputIter != nil && e.inputIter.Current() != e.inputIter.End() { + return false, nil } childResult := newFirstChunk(e.children[0]) err = Next(ctx, e.children[0], childResult) if err != nil { - return errors.Trace(err) + return false, errors.Trace(err) } - e.childResults = append(e.childResults, childResult) // No more data. - if childResult.NumRows() == 0 { - e.executed = true - err = e.appendResult2Chunk(chk) - return errors.Trace(err) + numRows := childResult.NumRows() + if numRows == 0 { + return true, nil } - e.inputIter = chunk.NewIterator4Chunk(childResult) - e.inputRow = e.inputIter.Begin() - return nil -} - -// appendResult2Chunk appends result of the window function to the result chunk. -func (e *WindowExec) appendResult2Chunk(chk *chunk.Chunk) (err error) { - if err := e.copyChk(chk); err != nil { - return err - } - remained := mathutil.Min(e.remainingRowsInChunk, e.remainingRowsInGroup) - e.groupRows, err = e.processor.appendResult2Chunk(e.ctx, e.groupRows, chk, remained) + resultChk := chunk.New(e.retFieldTypes, 0, numRows) + err = e.copyChk(childResult, resultChk) if err != nil { - return err + return false, err } - e.remainingRowsInGroup -= remained - e.remainingRowsInChunk -= remained - if e.remainingRowsInGroup == 0 { - e.processor.resetPartialResult() - e.groupRows = e.groupRows[:0] - } - return nil + e.resultChunks = append(e.resultChunks, resultChk) + e.remainingRowsInChunk = append(e.remainingRowsInChunk, numRows) + + e.inputIter = chunk.NewIterator4Chunk(childResult) + e.inputIter.Begin() + return false, nil } -func (e *WindowExec) copyChk(chk *chunk.Chunk) error { - if len(e.childResults) == 0 || chk.NumRows() > 0 { - return nil - } - childResult := e.childResults[0] - e.childResults = e.childResults[1:] - e.remainingRowsInChunk = childResult.NumRows() +func (e *WindowExec) copyChk(src, dst *chunk.Chunk) error { columns := e.Schema().Columns[:len(e.Schema().Columns)-e.numWindowFuncs] for i, col := range columns { - if err := chk.MakeRefTo(i, childResult, col.Index); err != nil { + if err := dst.MakeRefTo(i, src, col.Index); err != nil { return err } } diff --git a/executor/window_test.go b/executor/window_test.go index 736be720cbbad..6f9a460d9e36d 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -20,6 +20,7 @@ import ( func (s *testSuite4) TestWindowFunctions(c *C) { tk := testkit.NewTestKit(c, s.store) + var result *testkit.Result tk.MustExec("use test") tk.MustExec("drop table if exists t") tk.MustExec("create table t (a int, b int, c int)") @@ -28,7 +29,7 @@ func (s *testSuite4) TestWindowFunctions(c *C) { tk.MustExec("set @@tidb_enable_window_function = 0") }() tk.MustExec("insert into t values (1,2,3),(4,3,2),(2,3,4)") - result := tk.MustQuery("select count(a) over () from t") + result = tk.MustQuery("select count(a) over () from t") result.Check(testkit.Rows("3", "3", "3")) result = tk.MustQuery("select sum(a) over () + count(a) over () from t") result.Check(testkit.Rows("10", "10", "10")) @@ -178,7 +179,8 @@ func (s *testSuite4) TestWindowFunctions(c *C) { result.Check(testkit.Rows("1 1", "1 2", "2 1", "2 2")) } -func (s *testSuite4) TestWindowFunctionsIssue11614(c *C) { +func (s *testSuite4) TestWindowFunctionsDataReference(c *C) { + // see https://github.com/pingcap/tidb/issues/11614 tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -192,4 +194,10 @@ func (s *testSuite4) TestWindowFunctionsIssue11614(c *C) { result.Check(testkit.Rows("2 1 0", "2 2 0.5", "2 3 1")) result = tk.MustQuery("select a, b, CUME_DIST() over (partition by a order by b) from t") result.Check(testkit.Rows("2 1 0.3333333333333333", "2 2 0.6666666666666666", "2 3 1")) + + // see https://github.com/pingcap/tidb/issues/12415 + result = tk.MustQuery("select b, first_value(b) over (order by b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) from t") + result.Check(testkit.Rows("1 1", "2 1", "3 1")) + result = tk.MustQuery("select b, first_value(b) over (order by b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) from t") + result.Check(testkit.Rows("1 1", "2 1", "3 1")) }