From 0fb474b8427a79cb91cbe1f3eea17a2f1ffab41e Mon Sep 17 00:00:00 2001
From: Shenghui Wu <793703860@qq.com>
Date: Thu, 9 Feb 2023 14:40:02 +0800
Subject: [PATCH] This is an automated cherry-pick of #41081

Signed-off-by: ti-chi-bot
---
 executor/hash_table.go | 121 +++++++++++++++++++++++++++++++++++++++++
 executor/join_test.go  |  17 ++++++
 2 files changed, 138 insertions(+)

diff --git a/executor/hash_table.go b/executor/hash_table.go
index b22f98bbef501..acda85e522d6e 100644
--- a/executor/hash_table.go
+++ b/executor/hash_table.go
@@ -83,6 +83,14 @@ type hashRowContainer struct {
 	hashTable    baseHashTable
 	rowContainer *chunk.RowContainer
+<<<<<<< HEAD
+=======
+	memTracker *memory.Tracker
+
+	// chkBuf buffers the data read from disk if rowContainer is spilled.
+	chkBuf                *chunk.Chunk
+	chkBufSizeForOneProbe int64
+>>>>>>> 5cb84186dc (executor: track the memroy usage in HashJoin probe phase (#41081))
 }
 
 func newHashRowContainer(sCtx sessionctx.Context, estCount int, hCtx *hashContext, allTypes []*types.FieldType) *hashRowContainer {
@@ -104,6 +112,88 @@ func (c *hashRowContainer) ShallowCopy() *hashRowContainer {
 	return &newHRC
 }
 
+<<<<<<< HEAD
+=======
+// GetMatchedRows gets matched rows from probeRow. It can be called
+// in multiple goroutines while each goroutine should keep its own
+// h and buf.
+func (c *hashRowContainer) GetMatchedRows(probeKey uint64, probeRow chunk.Row, hCtx *hashContext, matched []chunk.Row) ([]chunk.Row, error) {
+	matchedRows, _, err := c.GetMatchedRowsAndPtrs(probeKey, probeRow, hCtx, matched, nil, false)
+	return matchedRows, err
+}
+
+func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRow chunk.Row,
+	probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
+	// For NAAJ probe rows with a null key, we should match them against all build rows.
+	var (
+		ok        bool
+		err       error
+		innerPtrs []chunk.RowPtr
+	)
+	c.hashTable.Iter(
+		func(_ uint64, e *entry) {
+			entryAddr := e
+			for entryAddr != nil {
+				innerPtrs = append(innerPtrs, entryAddr.ptr)
+				entryAddr = entryAddr.next
+			}
+		})
+	matched = matched[:0]
+	if len(innerPtrs) == 0 {
+		return matched, nil
+	}
+	// All build-side bucket rows come from the hash table, so their bitmaps are all nil (they
+	// contain no nulls); we only need the probe null bits to filter valid rows.
+	if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 {
+		// If len(probeHCtx.naKeyColIdx) == 1,
+		//     the NA-Join probe key is directly (null) <-> (fetch all buckets), and there is nothing to do.
+		// Otherwise, for a key like
+		//     (null, 1, 2), we should use the not-null probe bits to filter rows, fetching only rows like
+		//     ( ? , 1, 2), i.e. rows with exactly the values 1 and 2 in the second and third join key columns.
+		needCheckProbeRowPos = needCheckProbeRowPos[:0]
+		needCheckBuildRowPos = needCheckBuildRowPos[:0]
+		keyColLen := len(c.hCtx.naKeyColIdx)
+		for i := 0; i < keyColLen; i++ {
+			// Since every bucket row comes from the hash table (not null), the buildSideNullBits check is eliminated.
+			if probeKeyNullBits.UnsafeIsSet(i) {
+				continue
+			}
+			needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
+			needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
+		}
+	}
+	var mayMatchedRow chunk.Row
+	for _, ptr := range innerPtrs {
+		mayMatchedRow, c.chkBuf, err = c.rowContainer.GetRowAndAppendToChunk(ptr, c.chkBuf)
+		if err != nil {
+			return nil, err
+		}
+		if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 {
+			// Check the values of the join columns at the positions collected above.
+			ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
+			if err != nil {
+				return nil, err
+			}
+			if !ok {
+				continue
+			}
+			// Once ok, just append the (maybe) valid build row; other conditions, if any, are checked later.
+		}
+		matched = append(matched, mayMatchedRow)
+	}
+	return matched, nil
+}
+
+// signalCheckpointForJoin indicates the number of probed rows after which a signal detection is triggered.
+const signalCheckpointForJoin int = 1 << 14
+
+// rowSize is the size of chunk.Row.
+const rowSize = int64(unsafe.Sizeof(chunk.Row{}))
+
+// rowPtrSize is the size of chunk.RowPtr.
+const rowPtrSize = int64(unsafe.Sizeof(chunk.RowPtr{}))
+
+>>>>>>> 5cb84186dc (executor: track the memroy usage in HashJoin probe phase (#41081))
 // GetMatchedRowsAndPtrs get matched rows and Ptrs from probeRow. It can be called
 // in multiple goroutines while each goroutine should keep its own
 // h and buf.
@@ -114,9 +204,27 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk
 	}
 	matched = make([]chunk.Row, 0, len(innerPtrs))
 	var matchedRow chunk.Row
+<<<<<<< HEAD
 	matchedPtrs = make([]chunk.RowPtr, 0, len(innerPtrs))
 	for _, ptr := range innerPtrs {
 		matchedRow, err = c.rowContainer.GetRow(ptr)
+=======
+	matchedPtrs = matchedPtrs[:0]
+
+	// Some variables used for memTracker.
+	var (
+		matchedDataSize                  = int64(cap(matched))*rowSize + int64(cap(matchedPtrs))*rowPtrSize
+		lastChunkBufPointer *chunk.Chunk = nil
+		memDelta            int64        = 0
+	)
+	c.chkBuf = nil
+	c.memTracker.Consume(-c.chkBufSizeForOneProbe + int64(cap(innerPtrs))*rowPtrSize)
+	defer c.memTracker.Consume(-int64(cap(innerPtrs))*rowPtrSize + memDelta)
+	c.chkBufSizeForOneProbe = 0
+
+	for i, ptr := range innerPtrs {
+		matchedRow, c.chkBuf, err = c.rowContainer.GetRowAndAppendToChunk(ptr, c.chkBuf)
+>>>>>>> 5cb84186dc (executor: track the memroy usage in HashJoin probe phase (#41081))
 		if err != nil {
 			return
 		}
@@ -125,6 +233,19 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk
 		if err != nil {
 			return
 		}
+		if c.chkBuf != lastChunkBufPointer && lastChunkBufPointer != nil {
+			lastChunkSize := lastChunkBufPointer.MemoryUsage()
+			c.chkBufSizeForOneProbe += lastChunkSize
+			memDelta += lastChunkSize
+		}
+		lastChunkBufPointer = c.chkBuf
+		if i&signalCheckpointForJoin == 0 {
+			// Trigger Consume for checking the OOM Action signal
+			memDelta += int64(cap(matched))*rowSize + int64(cap(matchedPtrs))*rowPtrSize - matchedDataSize
+			matchedDataSize = int64(cap(matched))*rowSize + int64(cap(matchedPtrs))*rowPtrSize
+			c.memTracker.Consume(memDelta + 1)
+			memDelta = 0
+		}
 		if !ok {
 			atomic.AddInt64(&c.stat.probeCollision, 1)
 			continue
diff --git a/executor/join_test.go b/executor/join_test.go
index de8b524684ebe..3107c7180ba4b 100644
--- a/executor/join_test.go
+++ b/executor/join_test.go
@@ -2808,3 +2808,20 @@ func (s *testSuiteJoinSerial) TestIssue37932(c *C) {
 	}
 	c.Assert(err, IsNil)
 }
+
+func TestCartesianJoinPanic(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("create table t(a int)")
+	tk.MustExec("insert into t values(1)")
+	tk.MustExec("set tidb_mem_quota_query = 1 << 30")
+	tk.MustExec("set global tidb_mem_oom_action = 'CANCEL'")
+	tk.MustExec("set global tidb_enable_tmp_storage_on_oom = off;")
+	for i := 0; i < 14; i++ {
+		tk.MustExec("insert into t select * from t")
+	}
+	err := tk.QueryToErr("desc analyze select * from t t1, t t2, t t3, t t4, t t5, t t6;")
+	require.NotNil(t, err)
+	require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!"))
+}
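
Below is a minimal, self-contained sketch of the tracking pattern this patch applies in the probe
phase: accumulate a local memDelta while walking matched rows, and hand it to the tracker at fixed
row intervals so a CANCEL-style OOM action can interrupt a long probe (such as the cartesian join
in TestCartesianJoinPanic) instead of letting memory grow unchecked. The types here (tracker,
probe) are hypothetical stand-ins, not the TiDB memory.Tracker API, and the sketch uses a modulo
test where the patch uses a bitmask on signalCheckpointForJoin.

	package main

	import (
		"errors"
		"fmt"
	)

	const signalCheckpoint = 1 << 14 // plays the role of signalCheckpointForJoin

	// tracker is a toy stand-in for a quota-enforcing memory tracker.
	type tracker struct {
		consumed int64
		quota    int64
	}

	// consume adds delta and reports an error once the quota is exceeded,
	// mimicking the OOM action firing inside memTracker.Consume.
	func (t *tracker) consume(delta int64) error {
		t.consumed += delta
		if t.consumed > t.quota {
			return errors.New("Out Of Memory Quota!")
		}
		return nil
	}

	// probe mimics the loop in GetMatchedRowsAndPtrs: each row contributes
	// rowCost bytes, but the tracker only sees the accumulated delta every
	// signalCheckpoint rows, keeping Consume calls cheap while still giving
	// the OOM action a chance to cancel the query mid-probe.
	func probe(t *tracker, rows int, rowCost int64) error {
		var memDelta int64
		for i := 0; i < rows; i++ {
			memDelta += rowCost // grow the local, as-yet-untracked delta
			if i%signalCheckpoint == 0 {
				// The +1 guarantees Consume is invoked even when memDelta is 0.
				if err := t.consume(memDelta + 1); err != nil {
					return err
				}
				memDelta = 0
			}
		}
		// Flush the tail, as the deferred Consume does in the patch.
		return t.consume(memDelta)
	}

	func main() {
		t := &tracker{quota: 64 << 20} // 64 MiB quota (illustrative, not the test's 1 GiB)
		// Roughly 128 MiB of matched-row bookkeeping exceeds the quota.
		if err := probe(t, 1<<22, 32); err != nil {
			fmt.Println("probe canceled:", err)
		}
	}

With a 64 MiB quota and about 128 MiB of probe-side growth, the sketch cancels the probe partway
through; this is the same observable behavior the new TestCartesianJoinPanic asserts by expecting
"Out Of Memory Quota!" from the six-way cartesian join.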