diff --git a/server/conn_stmt.go b/server/conn_stmt.go index acb12ae660b94..19e77ce222d51 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -263,6 +263,16 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm BinaryArgs: args, PrepStmt: prepStmt, } + + // For the combination of `ComPrepare` and `ComExecute`, the statement name is stored in the client side, and the + // TiDB only has the ID, so don't try to construct an `EXECUTE SOMETHING`. Use the original prepared statement here + // instead. + sql := "" + planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt) + if ok { + sql = planCacheStmt.StmtText + } + execStmt.SetText(charset.EncodingUTF8Impl, sql) rs, err := (&cc.ctx).ExecuteStmt(ctx, execStmt) if err != nil { return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) diff --git a/server/conn_test.go b/server/conn_test.go index 6571d7efd5352..94461685c33c4 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -1850,3 +1850,35 @@ func TestAuthSha(t *testing.T) { // fastAuthFail and the rest of the auth process. require.Equal(t, authData, []byte{}) } + +func TestProcessInfoForExecuteCommand(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + cc := &clientConn{ + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + } + ctx := context.Background() + + tk.MustExec("use test") + cc.setCtx(&TiDBContext{Session: tk.Session(), stmts: make(map[int]*TiDBStatement)}) + + tk.MustExec("create table t (col1 int)") + + // simple prepare and execute + require.NoError(t, cc.handleStmtPrepare(ctx, "select sum(col1) from t")) + require.NoError(t, cc.handleStmtExecute(ctx, []byte{0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0})) + require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t") + + // prepare and execute with params + require.NoError(t, cc.handleStmtPrepare(ctx, "select sum(col1) from t where col1 < ? and col1 > 100")) + // 1 params, length of nullBitMap is 1, `0x8, 0x0` represents the type, and the following `0x10, 0x0....` is the param + // 10 + require.NoError(t, cc.handleStmtExecute(ctx, []byte{0x2, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, + 0x1, 0x8, 0x0, + 0x0A, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0})) + require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t where col1 < ? and col1 > 100") +}