diff --git a/executor/hash_table.go b/executor/hash_table.go index 2ba840d04fdc9..50acc4447f4df 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -114,7 +114,8 @@ type hashRowContainer struct { memTracker *memory.Tracker // chkBuf buffer the data reads from the disk if rowContainer is spilled. - chkBuf *chunk.Chunk + chkBuf *chunk.Chunk + chkBufSizeForOneProbe int64 } func newHashRowContainer(sCtx sessionctx.Context, hCtx *hashContext, allTypes []*types.FieldType) *hashRowContainer { @@ -213,6 +214,15 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo return matched, nil } +// signalCheckpointForJoinMask indicates the times of row probe that a signal detection will be triggered. +const signalCheckpointForJoinMask int = 1<<14 - 1 + +// rowSize is the size of Row. +const rowSize = int64(unsafe.Sizeof(chunk.Row{})) + +// rowPtrSize is the size of RowPtr. +const rowPtrSize = int64(unsafe.Sizeof(chunk.RowPtr{})) + // GetMatchedRowsAndPtrs get matched rows and Ptrs from probeRow. It can be called // in multiple goroutines while each goroutine should keep its own // h and buf. @@ -225,7 +235,23 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk matched = matched[:0] var matchedRow chunk.Row matchedPtrs = matchedPtrs[:0] - for _, ptr := range innerPtrs { + + // Some variables used for memTracker. + var ( + matchedDataSize = int64(cap(matched))*rowSize + int64(cap(matchedPtrs))*rowPtrSize + lastChunkBufPointer *chunk.Chunk = nil + memDelta int64 = 0 + needTrackMemUsage = cap(innerPtrs) > signalCheckpointForJoinMask + ) + c.chkBuf = nil + c.memTracker.Consume(-c.chkBufSizeForOneProbe) + if needTrackMemUsage { + c.memTracker.Consume(int64(cap(innerPtrs)) * rowPtrSize) + defer c.memTracker.Consume(-int64(cap(innerPtrs))*rowPtrSize + memDelta) + } + c.chkBufSizeForOneProbe = 0 + + for i, ptr := range innerPtrs { matchedRow, c.chkBuf, err = c.rowContainer.GetRowAndAppendToChunk(ptr, c.chkBuf) if err != nil { return nil, nil, err @@ -235,6 +261,19 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk if err != nil { return nil, nil, err } + if needTrackMemUsage && c.chkBuf != lastChunkBufPointer && lastChunkBufPointer != nil { + lastChunkSize := lastChunkBufPointer.MemoryUsage() + c.chkBufSizeForOneProbe += lastChunkSize + memDelta += lastChunkSize + } + lastChunkBufPointer = c.chkBuf + if needTrackMemUsage && (i&signalCheckpointForJoinMask == signalCheckpointForJoinMask) { + // Trigger Consume for checking the OOM Action signal + memDelta += int64(cap(matched))*rowSize + int64(cap(matchedPtrs))*rowPtrSize - matchedDataSize + matchedDataSize = int64(cap(matched))*rowSize + int64(cap(matchedPtrs))*rowPtrSize + c.memTracker.Consume(memDelta + 1) + memDelta = 0 + } if !ok { atomic.AddInt64(&c.stat.probeCollision, 1) continue diff --git a/executor/join_test.go b/executor/join_test.go index a5d5f6efc9fb5..6f56d0a18dc8e 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -2892,3 +2892,20 @@ func TestOuterJoin(t *testing.T) { ), ) } + +func TestCartesianJoinPanic(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t values(1)") + tk.MustExec("set tidb_mem_quota_query = 1 << 30") + tk.MustExec("set global tidb_mem_oom_action = 'CANCEL'") + tk.MustExec("set global tidb_enable_tmp_storage_on_oom = off;") + for i := 0; i < 14; i++ { + tk.MustExec("insert into t select * from t") + } + err := tk.QueryToErr("desc analyze select * from t t1, t t2, t t3, t t4, t t5, t t6;") + require.NotNil(t, err) + require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!")) +}