From 9cb794b2b3885db5b5642d16c3a03f1807e22cf6 Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Thu, 13 Sep 2018 13:03:56 +0800 Subject: [PATCH 01/10] add arguments in execute statement --- executor/prepared.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/executor/prepared.go b/executor/prepared.go index 95400dd2d2bde..fb7a8cd70e56c 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -14,10 +14,11 @@ package executor import ( + "fmt" "math" "sort" + "strings" - "fmt" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/expression" @@ -245,8 +246,15 @@ func (e *DeallocateExec) Next(ctx context.Context, chk *chunk.Chunk) error { func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...interface{}) (ast.Statement, error) { execStmt := &ast.ExecuteStmt{ExecID: ID} execStmt.UsingVars = make([]ast.ExprNode, len(args)) + argStrs := make([]string, 0, len(args)) for i, val := range args { - execStmt.UsingVars[i] = ast.NewValueExpr(val) + expr := ast.NewValueExpr(val) + execStmt.UsingVars[i] = expr + str, err := expr.ToString() + if err != nil { + return nil, err + } + argStrs = append(argStrs, str) } is := GetInfoSchema(ctx) execPlan, err := plan.Optimize(ctx, execStmt, is) @@ -261,6 +269,10 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter Ctx: ctx, } if prepared, ok := ctx.GetSessionVars().PreparedStmts[ID].(*plan.Prepared); ok { + if len(argStrs) > 0 { + prepared.Stmt.SetText(fmt.Sprintf("%s [arguments: %s]", prepared.Stmt.Text(), + strings.Join(argStrs, ","))) + } stmt.Text = prepared.Stmt.Text() } return stmt, nil From 2f0cceba51f1915fe855b735a70298754e59914e Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Thu, 13 Sep 2018 14:01:35 +0800 Subject: [PATCH 02/10] show in debug and slow query --- executor/adapter.go | 1 - executor/prepared.go | 6 +++--- executor/prepared_test.go | 3 ++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/executor/adapter.go b/executor/adapter.go index 379039dc36a39..31c4b9a39ae07 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -321,7 +321,6 @@ func (a *ExecStmt) buildExecutor(ctx sessionctx.Context) (Executor, error) { if err != nil { return nil, errors.Trace(err) } - a.Text = executorExec.stmt.Text() a.isPreparedStmt = true a.Plan = executorExec.plan e = executorExec.stmtExec diff --git a/executor/prepared.go b/executor/prepared.go index fb7a8cd70e56c..a5168c03151ae 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -269,11 +269,11 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter Ctx: ctx, } if prepared, ok := ctx.GetSessionVars().PreparedStmts[ID].(*plan.Prepared); ok { + argInfo := "" if len(argStrs) > 0 { - prepared.Stmt.SetText(fmt.Sprintf("%s [arguments: %s]", prepared.Stmt.Text(), - strings.Join(argStrs, ","))) + argInfo = fmt.Sprintf(" [arguments: %s]", strings.Join(argStrs, ",")) } - stmt.Text = prepared.Stmt.Text() + stmt.Text = prepared.Stmt.Text() + argInfo } return stmt, nil } diff --git a/executor/prepared_test.go b/executor/prepared_test.go index feb0e10829bab..fc86c03cbc1bb 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -14,6 +14,7 @@ package executor_test import ( + "fmt" "math" "strings" @@ -101,7 +102,7 @@ func (s *testSuite) TestPrepared(c *C) { // Check that ast.Statement created by executor.CompileExecutePreparedStmt has query text. stmt, err := executor.CompileExecutePreparedStmt(tk.Se, stmtId, 1) c.Assert(err, IsNil) - c.Assert(stmt.OriginText(), Equals, query) + c.Assert(stmt.OriginText(), Equals, fmt.Sprintf("%s [arguments: %d]", query, 1)) // Check that rebuild plan works. tk.Se.PrepareTxnCtx(ctx) From 3631775df86bcce4fc6cb87d75890843dde06bd6 Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Thu, 13 Sep 2018 15:19:22 +0800 Subject: [PATCH 03/10] fix CI --- expression/builtin_time_test.go | 15 ++++++++------- types/datum.go | 10 ++++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 437094bce1d44..0c4d694dd1549 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -2201,13 +2201,13 @@ func (s *testEvaluatorSuite) TestConvertTz(c *C) { {"2004-01-01 12:00:00", "GMT", "MET", true, "2004-01-01 13:00:00"}, {"2004-01-01 12:00:00", "-01:00", "-12:00", true, "2004-01-01 01:00:00"}, {"2004-01-01 12:00:00", "-00:00", "+13:00", true, "2004-01-02 01:00:00"}, - {"2004-01-01 12:00:00", "-00:00", "-13:00", true, ""}, - {"2004-01-01 12:00:00", "-00:00", "-12:88", true, ""}, - {"2004-01-01 12:00:00", "+10:82", "GMT", false, ""}, - {"2004-01-01 12:00:00", "+00:00", "GMT", true, ""}, - {"2004-01-01 12:00:00", "GMT", "+00:00", true, ""}, + {"2004-01-01 12:00:00", "-00:00", "-13:00", true, ""}, + {"2004-01-01 12:00:00", "-00:00", "-12:88", true, ""}, + {"2004-01-01 12:00:00", "+10:82", "GMT", false, ""}, + {"2004-01-01 12:00:00", "+00:00", "GMT", true, ""}, + {"2004-01-01 12:00:00", "GMT", "+00:00", true, ""}, {20040101, "+00:00", "+10:32", true, "2004-01-01 10:32:00"}, - {3.14159, "+00:00", "+10:32", false, ""}, + {3.14159, "+00:00", "+10:32", false, ""}, } fc := funcs[ast.ConvertTz] for _, test := range tests { @@ -2224,7 +2224,8 @@ func (s *testEvaluatorSuite) TestConvertTz(c *C) { } else { c.Assert(err, NotNil) } - result, _ := d.ToString() + result, err := d.ToString() + c.Assert(err, IsNil) c.Assert(result, Equals, test.expect, Commentf("convert_tz(\"%v\", \"%s\", \"%s\")", test.t, test.fromTz, test.toTz)) } } diff --git a/types/datum.go b/types/datum.go index 641cd889d32d9..d51a877716c24 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1501,6 +1501,8 @@ func (d *Datum) ToFloat64(sc *stmtctx.StatementContext) (float64, error) { // ToString gets the string representation of the datum. func (d *Datum) ToString() (string, error) { switch d.Kind() { + case KindNull: + return "", nil case KindInt64: return strconv.FormatInt(d.GetInt64(), 10), nil case KindUint64: @@ -1513,16 +1515,16 @@ func (d *Datum) ToString() (string, error) { return d.GetString(), nil case KindBytes: return d.GetString(), nil - case KindMysqlTime: - return d.GetMysqlTime().String(), nil - case KindMysqlDuration: - return d.GetMysqlDuration().String(), nil case KindMysqlDecimal: return d.GetMysqlDecimal().String(), nil + case KindMysqlDuration: + return d.GetMysqlDuration().String(), nil case KindMysqlEnum: return d.GetMysqlEnum().String(), nil case KindMysqlSet: return d.GetMysqlSet().String(), nil + case KindMysqlTime: + return d.GetMysqlTime().String(), nil case KindMysqlJSON: return d.GetMysqlJSON().String(), nil case KindBinaryLiteral, KindMysqlBit: From 96f9a295a1f16d480dab926952f43bab2499a3aa Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Thu, 13 Sep 2018 19:13:56 +0800 Subject: [PATCH 04/10] address comments --- executor/prepared.go | 12 ++++++++---- expression/builtin_time_test.go | 15 +++++++-------- types/datum.go | 10 ++++------ 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/executor/prepared.go b/executor/prepared.go index a5168c03151ae..de18b57d2e50a 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -250,11 +250,15 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter for i, val := range args { expr := ast.NewValueExpr(val) execStmt.UsingVars[i] = expr - str, err := expr.ToString() - if err != nil { - return nil, err + if expr.GetDatum().IsNull() { + argStrs = append(argStrs, "") + } else { + str, err := expr.ToString() + if err != nil { + return nil, err + } + argStrs = append(argStrs, str) } - argStrs = append(argStrs, str) } is := GetInfoSchema(ctx) execPlan, err := plan.Optimize(ctx, execStmt, is) diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 0c4d694dd1549..437094bce1d44 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -2201,13 +2201,13 @@ func (s *testEvaluatorSuite) TestConvertTz(c *C) { {"2004-01-01 12:00:00", "GMT", "MET", true, "2004-01-01 13:00:00"}, {"2004-01-01 12:00:00", "-01:00", "-12:00", true, "2004-01-01 01:00:00"}, {"2004-01-01 12:00:00", "-00:00", "+13:00", true, "2004-01-02 01:00:00"}, - {"2004-01-01 12:00:00", "-00:00", "-13:00", true, ""}, - {"2004-01-01 12:00:00", "-00:00", "-12:88", true, ""}, - {"2004-01-01 12:00:00", "+10:82", "GMT", false, ""}, - {"2004-01-01 12:00:00", "+00:00", "GMT", true, ""}, - {"2004-01-01 12:00:00", "GMT", "+00:00", true, ""}, + {"2004-01-01 12:00:00", "-00:00", "-13:00", true, ""}, + {"2004-01-01 12:00:00", "-00:00", "-12:88", true, ""}, + {"2004-01-01 12:00:00", "+10:82", "GMT", false, ""}, + {"2004-01-01 12:00:00", "+00:00", "GMT", true, ""}, + {"2004-01-01 12:00:00", "GMT", "+00:00", true, ""}, {20040101, "+00:00", "+10:32", true, "2004-01-01 10:32:00"}, - {3.14159, "+00:00", "+10:32", false, ""}, + {3.14159, "+00:00", "+10:32", false, ""}, } fc := funcs[ast.ConvertTz] for _, test := range tests { @@ -2224,8 +2224,7 @@ func (s *testEvaluatorSuite) TestConvertTz(c *C) { } else { c.Assert(err, NotNil) } - result, err := d.ToString() - c.Assert(err, IsNil) + result, _ := d.ToString() c.Assert(result, Equals, test.expect, Commentf("convert_tz(\"%v\", \"%s\", \"%s\")", test.t, test.fromTz, test.toTz)) } } diff --git a/types/datum.go b/types/datum.go index d51a877716c24..641cd889d32d9 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1501,8 +1501,6 @@ func (d *Datum) ToFloat64(sc *stmtctx.StatementContext) (float64, error) { // ToString gets the string representation of the datum. func (d *Datum) ToString() (string, error) { switch d.Kind() { - case KindNull: - return "", nil case KindInt64: return strconv.FormatInt(d.GetInt64(), 10), nil case KindUint64: @@ -1515,16 +1513,16 @@ func (d *Datum) ToString() (string, error) { return d.GetString(), nil case KindBytes: return d.GetString(), nil - case KindMysqlDecimal: - return d.GetMysqlDecimal().String(), nil + case KindMysqlTime: + return d.GetMysqlTime().String(), nil case KindMysqlDuration: return d.GetMysqlDuration().String(), nil + case KindMysqlDecimal: + return d.GetMysqlDecimal().String(), nil case KindMysqlEnum: return d.GetMysqlEnum().String(), nil case KindMysqlSet: return d.GetMysqlSet().String(), nil - case KindMysqlTime: - return d.GetMysqlTime().String(), nil case KindMysqlJSON: return d.GetMysqlJSON().String(), nil case KindBinaryLiteral, KindMysqlBit: From f046a8daf106fa341d051852c90effa710fdf73f Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Fri, 14 Sep 2018 21:36:11 +0800 Subject: [PATCH 05/10] convert prepared params to a datum slice --- expression/builtin_other.go | 3 +-- plan/common_plans.go | 3 ++- sessionctx/variable/session.go | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/expression/builtin_other.go b/expression/builtin_other.go index e7224f4031bd0..e9832c88c2165 100644 --- a/expression/builtin_other.go +++ b/expression/builtin_other.go @@ -771,8 +771,7 @@ func (b *builtinGetParamStringSig) evalString(row chunk.Row) (string, bool, erro } v := sessionVars.PreparedParams[idx] - dt := v.(types.Datum) - str, err := (&dt).ToString() + str, err := v.ToString() if err != nil { return "", true, nil } diff --git a/plan/common_plans.go b/plan/common_plans.go index 8b41cae415bd4..1fb1ead69e265 100644 --- a/plan/common_plans.go +++ b/plan/common_plans.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/kvcache" @@ -152,7 +153,7 @@ func (e *Execute) optimizePreparedPlan(ctx sessionctx.Context, is infoschema.Inf } if cap(vars.PreparedParams) < len(e.UsingVars) { - vars.PreparedParams = make([]interface{}, len(e.UsingVars)) + vars.PreparedParams = make([]types.Datum, len(e.UsingVars)) } for i, usingVar := range e.UsingVars { val, err := usingVar.Eval(chunk.Row{}) diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index b0454332fe625..25f84d8c823d4 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -168,7 +168,7 @@ type SessionVars struct { // preparedStmtID is id of prepared statement. preparedStmtID uint32 // params for prepared statements - PreparedParams []interface{} + PreparedParams []types.Datum // retry information RetryInfo *RetryInfo @@ -302,7 +302,7 @@ func NewSessionVars() *SessionVars { systems: make(map[string]string), PreparedStmts: make(map[uint32]interface{}), PreparedStmtNameToID: make(map[string]uint32), - PreparedParams: make([]interface{}, 10), + PreparedParams: make([]types.Datum, 10), TxnCtx: &TransactionContext{}, KVVars: kv.NewVariables(), RetryInfo: &RetryInfo{}, From 7f4dbf7c7be41e9e41037fa63bf6ee3f03f6e1cb Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Fri, 14 Sep 2018 22:18:09 +0800 Subject: [PATCH 06/10] get args when needed --- executor/adapter.go | 4 ++-- executor/prepared.go | 20 ++------------------ executor/prepared_test.go | 3 +-- session/session.go | 2 +- sessionctx/variable/session.go | 20 ++++++++++++++++++++ 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/executor/adapter.go b/executor/adapter.go index 31c4b9a39ae07..963b353abc4cb 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -346,9 +346,9 @@ func (a *ExecStmt) logSlowQuery(txnTS uint64, succ bool) { if len(sql) > int(cfg.Log.QueryLogMaxLen) { sql = fmt.Sprintf("%.*q(len:%d)", cfg.Log.QueryLogMaxLen, sql, len(a.Text)) } - sql = QueryReplacer.Replace(sql) - sessVars := a.Ctx.GetSessionVars() + sql = QueryReplacer.Replace(sql) + sessVars.GetExecuteArgumentsInfo() + connID := sessVars.ConnectionID currentDB := sessVars.CurrentDB var tableIDs, indexIDs string diff --git a/executor/prepared.go b/executor/prepared.go index de18b57d2e50a..9b541fe7c46e7 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -17,7 +17,6 @@ import ( "fmt" "math" "sort" - "strings" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/config" @@ -246,19 +245,8 @@ func (e *DeallocateExec) Next(ctx context.Context, chk *chunk.Chunk) error { func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...interface{}) (ast.Statement, error) { execStmt := &ast.ExecuteStmt{ExecID: ID} execStmt.UsingVars = make([]ast.ExprNode, len(args)) - argStrs := make([]string, 0, len(args)) for i, val := range args { - expr := ast.NewValueExpr(val) - execStmt.UsingVars[i] = expr - if expr.GetDatum().IsNull() { - argStrs = append(argStrs, "") - } else { - str, err := expr.ToString() - if err != nil { - return nil, err - } - argStrs = append(argStrs, str) - } + execStmt.UsingVars[i] = ast.NewValueExpr(val) } is := GetInfoSchema(ctx) execPlan, err := plan.Optimize(ctx, execStmt, is) @@ -273,11 +261,7 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter Ctx: ctx, } if prepared, ok := ctx.GetSessionVars().PreparedStmts[ID].(*plan.Prepared); ok { - argInfo := "" - if len(argStrs) > 0 { - argInfo = fmt.Sprintf(" [arguments: %s]", strings.Join(argStrs, ",")) - } - stmt.Text = prepared.Stmt.Text() + argInfo + stmt.Text = prepared.Stmt.Text() } return stmt, nil } diff --git a/executor/prepared_test.go b/executor/prepared_test.go index fc86c03cbc1bb..feb0e10829bab 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -14,7 +14,6 @@ package executor_test import ( - "fmt" "math" "strings" @@ -102,7 +101,7 @@ func (s *testSuite) TestPrepared(c *C) { // Check that ast.Statement created by executor.CompileExecutePreparedStmt has query text. stmt, err := executor.CompileExecutePreparedStmt(tk.Se, stmtId, 1) c.Assert(err, IsNil) - c.Assert(stmt.OriginText(), Equals, fmt.Sprintf("%s [arguments: %d]", query, 1)) + c.Assert(stmt.OriginText(), Equals, query) // Check that rebuild plan works. tk.Se.PrepareTxnCtx(ctx) diff --git a/session/session.go b/session/session.go index 4c47455187439..83e871a3580a3 100644 --- a/session/session.go +++ b/session/session.go @@ -1394,7 +1394,7 @@ func logStmt(node ast.StmtNode, vars *variable.SessionVars) { func logQuery(query string, vars *variable.SessionVars) { if atomic.LoadUint32(&variable.ProcessGeneralLog) != 0 && !vars.InRestrictedSQL { - query = executor.QueryReplacer.Replace(query) + query = executor.QueryReplacer.Replace(query) + vars.GetExecuteArgumentsInfo() log.Infof("[GENERAL_LOG] con:%d user:%s schema_ver:%d start_ts:%d sql:%s", vars.ConnectionID, vars.User, vars.TxnCtx.SchemaVersion, vars.TxnCtx.StartTS, query) } } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 25f84d8c823d4..19758730b371f 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -15,6 +15,7 @@ package variable import ( "crypto/tls" + "fmt" "strings" "sync" "sync/atomic" @@ -443,6 +444,25 @@ func (s *SessionVars) ResetPrevAffectedRows() { } } +func (s *SessionVars) GetExecuteArgumentsInfo() string { + if len(s.PreparedParams) == 0 { + return "" + } + args := make([]string, 0, len(s.PreparedParams)) + for _, v := range s.PreparedParams { + if v.IsNull() { + args = append(args, "") + } else { + str, err := v.ToString() + if err != nil { + terror.Log(err) + } + args = append(args, str) + } + } + return fmt.Sprintf("[arguments: %s]", strings.Join(args, ", ")) +} + // GetSystemVar gets the string value of a system variable. func (s *SessionVars) GetSystemVar(name string) (string, bool) { val, ok := s.systems[name] From 93fd6705b7ebea74eaf5367b96c99a62a8c90811 Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Fri, 14 Sep 2018 22:46:07 +0800 Subject: [PATCH 07/10] print log when needed --- sessionctx/variable/session.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 19758730b371f..3b902e7a98602 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -444,6 +444,7 @@ func (s *SessionVars) ResetPrevAffectedRows() { } } +// GetExecuteArgumentsInfo gets the argument list as a string of execute statement. func (s *SessionVars) GetExecuteArgumentsInfo() string { if len(s.PreparedParams) == 0 { return "" From 463500537f19fa5fbcafac5aef0db33d52f376ab Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Sat, 15 Sep 2018 00:39:44 +0800 Subject: [PATCH 08/10] fix some more --- ddl/db_change_test.go | 2 +- executor/executor.go | 101 +++++++++++++++++++++++++++++++ executor/prepared.go | 106 +-------------------------------- plan/common_plans.go | 6 +- session/session.go | 4 +- session/session_test.go | 4 +- sessionctx/variable/session.go | 4 +- 7 files changed, 112 insertions(+), 115 deletions(-) diff --git a/ddl/db_change_test.go b/ddl/db_change_test.go index 33eb71eee0bf0..5412b5e676f98 100644 --- a/ddl/db_change_test.go +++ b/ddl/db_change_test.go @@ -296,7 +296,7 @@ func (t *testExecInfo) compileSQL(idx int) (err error) { ctx := context.TODO() se.PrepareTxnCtx(ctx) sctx := se.(sessionctx.Context) - if err = executor.ResetStmtCtx(sctx, c.rawStmt); err != nil { + if err = executor.ResetContextOfStmt(sctx, c.rawStmt); err != nil { return errors.Trace(err) } c.stmt, err = compiler.Compile(ctx, c.rawStmt) diff --git a/executor/executor.go b/executor/executor.go index 812eda2bcd02e..7e8bcfc1f2183 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -21,18 +21,21 @@ import ( "github.com/cznic/mathutil" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/admin" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/memory" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "golang.org/x/net/context" @@ -1060,3 +1063,101 @@ func (e *UnionExec) Close() error { e.resourcePools = nil return errors.Trace(e.baseExecutor.Close()) } + +// ResetContextOfStmt resets the StmtContext and session variables. +// Before every execution, we must clear statement context. +func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { + sessVars := ctx.GetSessionVars() + sc := new(stmtctx.StatementContext) + sc.TimeZone = sessVars.Location() + sc.MemTracker = memory.NewTracker(s.Text(), sessVars.MemQuotaQuery) + switch config.GetGlobalConfig().OOMAction { + case config.OOMActionCancel: + sc.MemTracker.SetActionOnExceed(&memory.PanicOnExceed{}) + case config.OOMActionLog: + sc.MemTracker.SetActionOnExceed(&memory.LogOnExceed{}) + default: + sc.MemTracker.SetActionOnExceed(&memory.LogOnExceed{}) + } + + // TODO: Many same bool variables here. + // We should set only two variables ( + // IgnoreErr and StrictSQLMode) to avoid setting the same bool variables and + // pushing them down to TiKV as flags. + switch stmt := s.(type) { + case *ast.UpdateStmt: + sc.InUpdateOrDeleteStmt = true + sc.DupKeyAsWarning = stmt.IgnoreErr + sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.Priority = stmt.Priority + case *ast.DeleteStmt: + sc.InUpdateOrDeleteStmt = true + sc.DupKeyAsWarning = stmt.IgnoreErr + sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.Priority = stmt.Priority + case *ast.InsertStmt: + sc.InInsertStmt = true + sc.DupKeyAsWarning = stmt.IgnoreErr + sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.Priority = stmt.Priority + case *ast.CreateTableStmt, *ast.AlterTableStmt: + // Make sure the sql_mode is strict when checking column default value. + case *ast.LoadDataStmt: + sc.DupKeyAsWarning = true + sc.BadNullAsWarning = true + sc.TruncateAsWarning = !sessVars.StrictSQLMode + case *ast.SelectStmt: + sc.InSelectStmt = true + + // see https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html#sql-mode-strict + // said "For statements such as SELECT that do not change data, invalid values + // generate a warning in strict mode, not an error." + // and https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html + sc.OverflowAsWarning = true + + // Return warning for truncate error in selection. + sc.TruncateAsWarning = true + sc.IgnoreZeroInDate = true + if opts := stmt.SelectStmtOpts; opts != nil { + sc.Priority = opts.Priority + sc.NotFillCache = !opts.SQLCache + } + sc.PadCharToFullLength = ctx.GetSessionVars().SQLMode.HasPadCharToFullLengthMode() + case *ast.ShowStmt: + sc.IgnoreTruncate = true + sc.IgnoreZeroInDate = true + if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors { + sc.InShowWarning = true + sc.SetWarnings(sessVars.StmtCtx.GetWarnings()) + } + default: + sc.IgnoreTruncate = true + sc.IgnoreZeroInDate = true + } + sessVars.PreparedParams = sessVars.PreparedParams[:0] + if sessVars.LastInsertID > 0 { + sessVars.PrevLastInsertID = sessVars.LastInsertID + sessVars.LastInsertID = 0 + } + sessVars.ResetPrevAffectedRows() + err = sessVars.SetSystemVar("warning_count", fmt.Sprintf("%d", sessVars.StmtCtx.NumWarnings(false))) + if err != nil { + return errors.Trace(err) + } + err = sessVars.SetSystemVar("error_count", fmt.Sprintf("%d", sessVars.StmtCtx.NumWarnings(true))) + if err != nil { + return errors.Trace(err) + } + sessVars.InsertID = 0 + sessVars.StmtCtx = sc + return +} diff --git a/executor/prepared.go b/executor/prepared.go index 9b541fe7c46e7..f94e77e6ac27c 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -14,21 +14,17 @@ package executor import ( - "fmt" "math" "sort" "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/sqlexec" "github.com/pkg/errors" "golang.org/x/net/context" @@ -214,9 +210,6 @@ func (e *ExecuteExec) Build() error { return errors.Trace(b.err) } e.stmtExec = stmtExec - if err = ResetStmtCtx(e.ctx, e.stmt); err != nil { - return err - } CountStmtNode(e.stmt, e.ctx.GetSessionVars().InRestrictedSQL) logExpensiveQuery(e.stmt, e.plan) return nil @@ -244,6 +237,9 @@ func (e *DeallocateExec) Next(ctx context.Context, chk *chunk.Chunk) error { // CompileExecutePreparedStmt compiles a session Execute command to a stmt.Statement. func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...interface{}) (ast.Statement, error) { execStmt := &ast.ExecuteStmt{ExecID: ID} + if err := ResetContextOfStmt(ctx, execStmt); err != nil { + return nil, err + } execStmt.UsingVars = make([]ast.ExprNode, len(args)) for i, val := range args { execStmt.UsingVars[i] = ast.NewValueExpr(val) @@ -266,99 +262,3 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter return stmt, nil } -// ResetStmtCtx resets the StmtContext. -// Before every execution, we must clear statement context. -func ResetStmtCtx(ctx sessionctx.Context, s ast.StmtNode) (err error) { - sessVars := ctx.GetSessionVars() - sc := new(stmtctx.StatementContext) - sc.TimeZone = sessVars.Location() - sc.MemTracker = memory.NewTracker(s.Text(), sessVars.MemQuotaQuery) - switch config.GetGlobalConfig().OOMAction { - case config.OOMActionCancel: - sc.MemTracker.SetActionOnExceed(&memory.PanicOnExceed{}) - case config.OOMActionLog: - sc.MemTracker.SetActionOnExceed(&memory.LogOnExceed{}) - default: - sc.MemTracker.SetActionOnExceed(&memory.LogOnExceed{}) - } - - // TODO: Many same bool variables here. - // We should set only two variables ( - // IgnoreErr and StrictSQLMode) to avoid setting the same bool variables and - // pushing them down to TiKV as flags. - switch stmt := s.(type) { - case *ast.UpdateStmt: - sc.InUpdateOrDeleteStmt = true - sc.DupKeyAsWarning = stmt.IgnoreErr - sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.Priority = stmt.Priority - case *ast.DeleteStmt: - sc.InUpdateOrDeleteStmt = true - sc.DupKeyAsWarning = stmt.IgnoreErr - sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.Priority = stmt.Priority - case *ast.InsertStmt: - sc.InInsertStmt = true - sc.DupKeyAsWarning = stmt.IgnoreErr - sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.Priority = stmt.Priority - case *ast.CreateTableStmt, *ast.AlterTableStmt: - // Make sure the sql_mode is strict when checking column default value. - case *ast.LoadDataStmt: - sc.DupKeyAsWarning = true - sc.BadNullAsWarning = true - sc.TruncateAsWarning = !sessVars.StrictSQLMode - case *ast.SelectStmt: - sc.InSelectStmt = true - - // see https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html#sql-mode-strict - // said "For statements such as SELECT that do not change data, invalid values - // generate a warning in strict mode, not an error." - // and https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html - sc.OverflowAsWarning = true - - // Return warning for truncate error in selection. - sc.TruncateAsWarning = true - sc.IgnoreZeroInDate = true - if opts := stmt.SelectStmtOpts; opts != nil { - sc.Priority = opts.Priority - sc.NotFillCache = !opts.SQLCache - } - sc.PadCharToFullLength = ctx.GetSessionVars().SQLMode.HasPadCharToFullLengthMode() - case *ast.ShowStmt: - sc.IgnoreTruncate = true - sc.IgnoreZeroInDate = true - if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors { - sc.InShowWarning = true - sc.SetWarnings(sessVars.StmtCtx.GetWarnings()) - } - default: - sc.IgnoreTruncate = true - sc.IgnoreZeroInDate = true - } - if sessVars.LastInsertID > 0 { - sessVars.PrevLastInsertID = sessVars.LastInsertID - sessVars.LastInsertID = 0 - } - sessVars.ResetPrevAffectedRows() - err = sessVars.SetSystemVar("warning_count", fmt.Sprintf("%d", sessVars.StmtCtx.NumWarnings(false))) - if err != nil { - return errors.Trace(err) - } - err = sessVars.SetSystemVar("error_count", fmt.Sprintf("%d", sessVars.StmtCtx.NumWarnings(true))) - if err != nil { - return errors.Trace(err) - } - sessVars.InsertID = 0 - sessVars.StmtCtx = sc - return -} diff --git a/plan/common_plans.go b/plan/common_plans.go index 1fb1ead69e265..edbe618940268 100644 --- a/plan/common_plans.go +++ b/plan/common_plans.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/kvcache" @@ -152,16 +151,13 @@ func (e *Execute) optimizePreparedPlan(ctx sessionctx.Context, is infoschema.Inf return errors.Trace(ErrWrongParamCount) } - if cap(vars.PreparedParams) < len(e.UsingVars) { - vars.PreparedParams = make([]types.Datum, len(e.UsingVars)) - } for i, usingVar := range e.UsingVars { val, err := usingVar.Eval(chunk.Row{}) if err != nil { return errors.Trace(err) } prepared.Params[i].SetDatum(val) - vars.PreparedParams[i] = val + vars.PreparedParams = append(vars.PreparedParams, val) } if prepared.SchemaVersion != is.SchemaMetaVersion() { // If the schema version has changed we need to preprocess it again, diff --git a/session/session.go b/session/session.go index 83e871a3580a3..8dd8ab4113294 100644 --- a/session/session.go +++ b/session/session.go @@ -53,7 +53,7 @@ import ( "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/kvcache" - binlog "github.com/pingcap/tipb/go-binlog" + "github.com/pingcap/tipb/go-binlog" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "golang.org/x/net/context" @@ -785,7 +785,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []ast.Rec // Step2: Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). startTS = time.Now() // Some executions are done in compile stage, so we reset them before compile. - if err := executor.ResetStmtCtx(s, stmtNode); err != nil { + if err := executor.ResetContextOfStmt(s, stmtNode); err != nil { return nil, errors.Trace(err) } stmt, err := compiler.Compile(ctx, stmtNode) diff --git a/session/session_test.go b/session/session_test.go index d95d7e1a58032..e5c2b4d7ef2be 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -42,7 +42,7 @@ import ( "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tidb/util/testutil" - binlog "github.com/pingcap/tipb/go-binlog" + "github.com/pingcap/tipb/go-binlog" "golang.org/x/net/context" "google.golang.org/grpc" ) @@ -423,7 +423,7 @@ func (s *testSessionSuite) TestRetryCleanTxn(c *C) { stmtNode, err := parser.New().ParseOneStmt("insert retrytxn values (2, 'a')", "", "") c.Assert(err, IsNil) stmt, _ := session.Compile(context.TODO(), tk.Se, stmtNode) - executor.ResetStmtCtx(tk.Se, stmtNode) + executor.ResetContextOfStmt(tk.Se, stmtNode) history.Add(0, stmt, tk.Se.GetSessionVars().StmtCtx) _, err = tk.Exec("commit") c.Assert(err, NotNil) diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 3b902e7a98602..c2e98780d79e6 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -303,7 +303,7 @@ func NewSessionVars() *SessionVars { systems: make(map[string]string), PreparedStmts: make(map[uint32]interface{}), PreparedStmtNameToID: make(map[string]uint32), - PreparedParams: make([]types.Datum, 10), + PreparedParams: make([]types.Datum, 0, 10), TxnCtx: &TransactionContext{}, KVVars: kv.NewVariables(), RetryInfo: &RetryInfo{}, @@ -461,7 +461,7 @@ func (s *SessionVars) GetExecuteArgumentsInfo() string { args = append(args, str) } } - return fmt.Sprintf("[arguments: %s]", strings.Join(args, ", ")) + return fmt.Sprintf(" [arguments: %s]", strings.Join(args, ", ")) } // GetSystemVar gets the string value of a system variable. From 617df56425469fa58badea436d0b761a026a98ee Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Sat, 15 Sep 2018 11:35:13 +0800 Subject: [PATCH 09/10] fix reset statement --- executor/executor.go | 55 +++++++++++++++++++++++--------------------- executor/prepared.go | 17 ++++++++++++++ 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/executor/executor.go b/executor/executor.go index 7e8bcfc1f2183..e4dcff437fc63 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1067,10 +1067,10 @@ func (e *UnionExec) Close() error { // ResetContextOfStmt resets the StmtContext and session variables. // Before every execution, we must clear statement context. func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { - sessVars := ctx.GetSessionVars() + vars := ctx.GetSessionVars() sc := new(stmtctx.StatementContext) - sc.TimeZone = sessVars.Location() - sc.MemTracker = memory.NewTracker(s.Text(), sessVars.MemQuotaQuery) + sc.TimeZone = vars.Location() + sc.MemTracker = memory.NewTracker(s.Text(), vars.MemQuotaQuery) switch config.GetGlobalConfig().OOMAction { case config.OOMActionCancel: sc.MemTracker.SetActionOnExceed(&memory.PanicOnExceed{}) @@ -1080,6 +1080,9 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.MemTracker.SetActionOnExceed(&memory.LogOnExceed{}) } + if execStmt, ok := s.(*ast.ExecuteStmt); ok { + s, err = getPreparedStmt(execStmt, vars) + } // TODO: Many same bool variables here. // We should set only two variables ( // IgnoreErr and StrictSQLMode) to avoid setting the same bool variables and @@ -1088,33 +1091,33 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { case *ast.UpdateStmt: sc.InUpdateOrDeleteStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr - sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.IgnoreZeroInDate = !vars.StrictSQLMode || stmt.IgnoreErr sc.Priority = stmt.Priority case *ast.DeleteStmt: sc.InUpdateOrDeleteStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr - sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.IgnoreZeroInDate = !vars.StrictSQLMode || stmt.IgnoreErr sc.Priority = stmt.Priority case *ast.InsertStmt: sc.InInsertStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr - sc.BadNullAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.DividedByZeroAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr - sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr + sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr + sc.IgnoreZeroInDate = !vars.StrictSQLMode || stmt.IgnoreErr sc.Priority = stmt.Priority case *ast.CreateTableStmt, *ast.AlterTableStmt: // Make sure the sql_mode is strict when checking column default value. case *ast.LoadDataStmt: sc.DupKeyAsWarning = true sc.BadNullAsWarning = true - sc.TruncateAsWarning = !sessVars.StrictSQLMode + sc.TruncateAsWarning = !vars.StrictSQLMode case *ast.SelectStmt: sc.InSelectStmt = true @@ -1137,27 +1140,27 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreZeroInDate = true if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors { sc.InShowWarning = true - sc.SetWarnings(sessVars.StmtCtx.GetWarnings()) + sc.SetWarnings(vars.StmtCtx.GetWarnings()) } default: sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true } - sessVars.PreparedParams = sessVars.PreparedParams[:0] - if sessVars.LastInsertID > 0 { - sessVars.PrevLastInsertID = sessVars.LastInsertID - sessVars.LastInsertID = 0 + vars.PreparedParams = vars.PreparedParams[:0] + if vars.LastInsertID > 0 { + vars.PrevLastInsertID = vars.LastInsertID + vars.LastInsertID = 0 } - sessVars.ResetPrevAffectedRows() - err = sessVars.SetSystemVar("warning_count", fmt.Sprintf("%d", sessVars.StmtCtx.NumWarnings(false))) + vars.ResetPrevAffectedRows() + err = vars.SetSystemVar("warning_count", fmt.Sprintf("%d", vars.StmtCtx.NumWarnings(false))) if err != nil { return errors.Trace(err) } - err = sessVars.SetSystemVar("error_count", fmt.Sprintf("%d", sessVars.StmtCtx.NumWarnings(true))) + err = vars.SetSystemVar("error_count", fmt.Sprintf("%d", vars.StmtCtx.NumWarnings(true))) if err != nil { return errors.Trace(err) } - sessVars.InsertID = 0 - sessVars.StmtCtx = sc + vars.InsertID = 0 + vars.StmtCtx = sc return } diff --git a/executor/prepared.go b/executor/prepared.go index f94e77e6ac27c..47d2316ead47c 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" @@ -262,3 +263,19 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter return stmt, nil } +func getPreparedStmt(stmt *ast.ExecuteStmt, vars *variable.SessionVars) (ast.StmtNode, error) { + execID := stmt.ExecID + ok := false + if stmt.Name != "" { + if execID, ok = vars.PreparedStmtNameToID[stmt.Name]; !ok { + return nil, plan.ErrStmtNotFound + } + } + if v, ok := vars.PreparedStmts[execID]; ok { + if prepared, ok := v.(*plan.Prepared); ok { + return prepared.Stmt, nil + } + return nil, plan.ErrStmtNotFound + } + return nil, plan.ErrStmtNotFound +} From 0466a38c18b24e4e34525ba1be1d285a5537d4a6 Mon Sep 17 00:00:00 2001 From: Yu Shuaipeng Date: Sat, 15 Sep 2018 12:12:43 +0800 Subject: [PATCH 10/10] prepare stmts in session var --- ast/misc.go | 8 ++++++++ executor/prepared.go | 11 ++++------- plan/common_plans.go | 15 +++------------ sessionctx/variable/session.go | 5 +++-- 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/ast/misc.go b/ast/misc.go index 2e89eeaec6945..886897d680d6c 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -183,6 +183,14 @@ func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { return v.Leave(n) } +// Prepared represents a prepared statement. +type Prepared struct { + Stmt StmtNode + Params []*ParamMarkerExpr + SchemaVersion int64 + UseCache bool +} + // ExecuteStmt is a statement to execute PreparedStmt. // See https://dev.mysql.com/doc/refman/5.7/en/execute.html type ExecuteStmt struct { diff --git a/executor/prepared.go b/executor/prepared.go index 47d2316ead47c..26829aa855e09 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -144,7 +144,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { for i := 0; i < e.ParamCount; i++ { sorter.markers[i].Order = i } - prepared := &plan.Prepared{ + prepared := &ast.Prepared{ Stmt: stmt, Params: sorter.markers, SchemaVersion: e.is.SchemaMetaVersion(), @@ -257,7 +257,7 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter StmtNode: execStmt, Ctx: ctx, } - if prepared, ok := ctx.GetSessionVars().PreparedStmts[ID].(*plan.Prepared); ok { + if prepared, ok := ctx.GetSessionVars().PreparedStmts[ID]; ok { stmt.Text = prepared.Stmt.Text() } return stmt, nil @@ -271,11 +271,8 @@ func getPreparedStmt(stmt *ast.ExecuteStmt, vars *variable.SessionVars) (ast.Stm return nil, plan.ErrStmtNotFound } } - if v, ok := vars.PreparedStmts[execID]; ok { - if prepared, ok := v.(*plan.Prepared); ok { - return prepared.Stmt, nil - } - return nil, plan.ErrStmtNotFound + if prepared, ok := vars.PreparedStmts[execID]; ok { + return prepared.Stmt, nil } return nil, plan.ErrStmtNotFound } diff --git a/plan/common_plans.go b/plan/common_plans.go index edbe618940268..68958485a011e 100644 --- a/plan/common_plans.go +++ b/plan/common_plans.go @@ -117,14 +117,6 @@ type Prepare struct { SQLText string } -// Prepared represents a prepared statement. -type Prepared struct { - Stmt ast.StmtNode - Params []*ast.ParamMarkerExpr - SchemaVersion int64 - UseCache bool -} - // Execute represents prepare plan. type Execute struct { baseSchemaProducer @@ -141,11 +133,10 @@ func (e *Execute) optimizePreparedPlan(ctx sessionctx.Context, is infoschema.Inf if e.Name != "" { e.ExecID = vars.PreparedStmtNameToID[e.Name] } - v := vars.PreparedStmts[e.ExecID] - if v == nil { + prepared, ok := vars.PreparedStmts[e.ExecID] + if !ok { return errors.Trace(ErrStmtNotFound) } - prepared := v.(*Prepared) if len(prepared.Params) != len(e.UsingVars) { return errors.Trace(ErrWrongParamCount) @@ -177,7 +168,7 @@ func (e *Execute) optimizePreparedPlan(ctx sessionctx.Context, is infoschema.Inf return nil } -func (e *Execute) getPhysicalPlan(ctx sessionctx.Context, is infoschema.InfoSchema, prepared *Prepared) (Plan, error) { +func (e *Execute) getPhysicalPlan(ctx sessionctx.Context, is infoschema.InfoSchema, prepared *ast.Prepared) (Plan, error) { var cacheKey kvcache.Key sessionVars := ctx.GetSessionVars() sessionVars.StmtCtx.UseCache = prepared.UseCache diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index c2e98780d79e6..92abdc7460bf9 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "time" + "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" @@ -164,7 +165,7 @@ type SessionVars struct { // systems variables, don't modify it directly, use GetSystemVar/SetSystemVar method. systems map[string]string // PreparedStmts stores prepared statement. - PreparedStmts map[uint32]interface{} + PreparedStmts map[uint32]*ast.Prepared PreparedStmtNameToID map[string]uint32 // preparedStmtID is id of prepared statement. preparedStmtID uint32 @@ -301,7 +302,7 @@ func NewSessionVars() *SessionVars { vars := &SessionVars{ Users: make(map[string]string), systems: make(map[string]string), - PreparedStmts: make(map[uint32]interface{}), + PreparedStmts: make(map[uint32]*ast.Prepared), PreparedStmtNameToID: make(map[string]uint32), PreparedParams: make([]types.Datum, 0, 10), TxnCtx: &TransactionContext{},