From b797290a0999ea98e538d2681d3a75ad9b8a76e1 Mon Sep 17 00:00:00 2001 From: YangKeao Date: Tue, 30 May 2023 18:52:41 +0800 Subject: [PATCH] server: fix memtracker leak with cursor (#44257) (#44280) close pingcap/tidb#44254 --- server/conn_stmt.go | 18 ++++++++------ server/conn_stmt_test.go | 53 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/server/conn_stmt.go b/server/conn_stmt.go index 177fdecfd3d72..145e83aee82b1 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -275,9 +275,19 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm } execStmt.SetText(charset.EncodingUTF8Impl, sql) rs, err := (&cc.ctx).ExecuteStmt(ctx, execStmt) + if rs != nil { + defer terror.Call(rs.Close) + } if err != nil { + // If error is returned during the planner phase or the executor.Open + // phase, the rs will be nil, and StmtCtx.MemTracker StmtCtx.DiskTracker + // will not be detached. We need to detach them manually. + if sv := cc.ctx.GetSessionVars(); sv != nil && sv.StmtCtx != nil { + sv.StmtCtx.DetachMemDiskTracker() + } return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) } + if rs == nil { if useCursor { vars.SetStatusFlag(mysql.ServerStatusCursorExists, false) @@ -331,13 +341,6 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm cl.OnFetchReturned() } - // as the `Next` of `ResultSet` will never be called, all rows have been cached inside it. We could close this - // `ResultSet`. - err = rs.Close() - if err != nil { - return false, err - } - stmt.SetCursorActive(true) // explicitly flush columnInfo to client. @@ -348,7 +351,6 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm return false, cc.flush(ctx) } - defer terror.Call(rs.Close) retryable, err := cc.writeResultset(ctx, rs, true, cc.ctx.Status(), 0) if err != nil { return retryable, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) diff --git a/server/conn_stmt_test.go b/server/conn_stmt_test.go index dff61b203bf5e..4e054ff889b97 100644 --- a/server/conn_stmt_test.go +++ b/server/conn_stmt_test.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "encoding/binary" + "fmt" "testing" "github.com/pingcap/tidb/expression" @@ -372,3 +373,55 @@ func TestCursorWithParams(t *testing.T) { 0x0, 0x1, 0x0, 0x0, ))) } + +func TestCursorDetachMemTracker(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id_1 int, id_2 int)") + tk.MustExec("insert into t values (1, 1), (1, 2)") + tk.MustExec("set global tidb_mem_oom_action = 'CANCEL'") + defer tk.MustExec("set global tidb_mem_oom_action= DEFAULT") + // TODO: find whether it's expected to have one child at the beginning + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) + + // execute a normal statement, it'll success + stmt, _, _, err := c.Context().Prepare("select count(id_2) from t") + require.NoError(t, err) + + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + ))) + maxConsumed := tk.Session().GetSessionVars().MemTracker.MaxConsumed() + + // testkit also uses `PREPARE` related calls to run statement with arguments. + // format the SQL to avoid the interference from testkit. + tk.MustExec(fmt.Sprintf("set tidb_mem_quota_query=%d", maxConsumed/2)) + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) + // This query should exceed the memory limitation during `openExecutor` + require.Error(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + ))) + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) + + // The next query should succeed + tk.MustExec(fmt.Sprintf("set tidb_mem_quota_query=%d", maxConsumed+1)) + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) + // This query should succeed + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + ))) + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) +}