Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: refine text protocol multiple query response (#11263) #11290

Merged
merged 1 commit into from
Jul 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import (
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/sqlexec"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -945,18 +946,22 @@ func (cc *clientConn) flush() error {

func (cc *clientConn) writeOK() error {
msg := cc.ctx.LastMessage()
return cc.writeOkWith(msg, cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), cc.ctx.Status(), cc.ctx.WarningCount())
}

func (cc *clientConn) writeOkWith(msg string, affectedRows, lastInsertID uint64, status, warnCnt uint16) error {
enclen := 0
if len(msg) > 0 {
enclen = lengthEncodedIntSize(uint64(len(msg))) + len(msg)
}

data := cc.alloc.AllocWithLen(4, 32+enclen)
data = append(data, mysql.OKHeader)
data = dumpLengthEncodedInt(data, cc.ctx.AffectedRows())
data = dumpLengthEncodedInt(data, cc.ctx.LastInsertID())
data = dumpLengthEncodedInt(data, affectedRows)
data = dumpLengthEncodedInt(data, lastInsertID)
if cc.capability&mysql.ClientProtocol41 > 0 {
data = dumpUint16(data, cc.ctx.Status())
data = dumpUint16(data, cc.ctx.WarningCount())
data = dumpUint16(data, status)
data = dumpUint16(data, warnCnt)
}
if enclen > 0 {
// although MySQL manual says the info message is string<EOF>(https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html),
Expand Down Expand Up @@ -1396,12 +1401,27 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet
}

func (cc *clientConn) writeMultiResultset(ctx context.Context, rss []ResultSet, binary bool) error {
for _, rs := range rss {
if err := cc.writeResultset(ctx, rs, binary, mysql.ServerMoreResultsExists, 0); err != nil {
for i, rs := range rss {
lastRs := i == len(rss)-1
if r, ok := rs.(*tidbResultSet).recordSet.(sqlexec.MultiQueryNoDelayResult); ok {
status := r.Status()
if !lastRs {
status |= mysql.ServerMoreResultsExists
}
if err := cc.writeOkWith(r.LastMessage(), r.AffectedRows(), r.LastInsertID(), status, r.WarnCount()); err != nil {
return err
}
continue
}
status := uint16(0)
if !lastRs {
status |= mysql.ServerMoreResultsExists
}
if err := cc.writeResultset(ctx, rs, binary, status, 0); err != nil {
return err
}
}
return cc.writeOK()
return nil
}

func (cc *clientConn) setConn(conn net.Conn) {
Expand Down
59 changes: 57 additions & 2 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
s.processInfo.Store(&pi)
}

func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet) ([]sqlexec.RecordSet, error) {
func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet, inMulitQuery bool) ([]sqlexec.RecordSet, error) {
s.SetValue(sessionctx.QueryString, stmt.OriginText())
if _, ok := stmtNode.(ast.DDLNode); ok {
s.SetValue(sessionctx.LastExecuteDDL, true)
Expand All @@ -970,6 +970,16 @@ func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode
sessionExecuteRunDurationGeneral.Observe(time.Since(startTime).Seconds())
}

if inMulitQuery && recordSet == nil {
recordSet = &multiQueryNoDelayRecordSet{
affectedRows: s.AffectedRows(),
lastMessage: s.LastMessage(),
warnCount: s.sessionVars.StmtCtx.WarningCount(),
lastInsertID: s.sessionVars.StmtCtx.LastInsertID,
status: s.sessionVars.Status,
}
}

if recordSet != nil {
recordSets = append(recordSets, recordSet)
}
Expand Down Expand Up @@ -1016,6 +1026,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec

var tempStmtNodes []ast.StmtNode
compiler := executor.Compiler{Ctx: s}
multiQuery := len(stmtNodes) > 1
for idx, stmtNode := range stmtNodes {
s.PrepareTxnCtx(ctx)

Expand Down Expand Up @@ -1052,7 +1063,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec
s.currentPlan = stmt.Plan

// Step3: Execute the physical plan.
if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets); err != nil {
if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets, multiQuery); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1889,3 +1900,47 @@ func (s *session) recordTransactionCounter(err error) {
}
}
}

type multiQueryNoDelayRecordSet struct {
affectedRows uint64
lastMessage string
status uint16
warnCount uint16
lastInsertID uint64
}

func (c *multiQueryNoDelayRecordSet) Fields() []*ast.ResultField {
panic("unsupported method")
}

func (c *multiQueryNoDelayRecordSet) Next(ctx context.Context, chk *chunk.Chunk) error {
panic("unsupported method")
}

func (c *multiQueryNoDelayRecordSet) NewChunk() *chunk.Chunk {
panic("unsupported method")
}

func (c *multiQueryNoDelayRecordSet) Close() error {
return nil
}

func (c *multiQueryNoDelayRecordSet) AffectedRows() uint64 {
return c.affectedRows
}

func (c *multiQueryNoDelayRecordSet) LastMessage() string {
return c.lastMessage
}

func (c *multiQueryNoDelayRecordSet) WarnCount() uint16 {
return c.warnCount
}

func (c *multiQueryNoDelayRecordSet) Status() uint16 {
return c.status
}

func (c *multiQueryNoDelayRecordSet) LastInsertID() uint64 {
return c.lastInsertID
}
14 changes: 14 additions & 0 deletions util/sqlexec/restricted_sql_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,17 @@ type RecordSet interface {
// restart the iteration.
Close() error
}

// MultiQueryNoDelayResult is an interface for one no-delay result for one statement in multi-queries.
type MultiQueryNoDelayResult interface {
// AffectedRows return affected row for one statement in multi-queries.
AffectedRows() uint64
// LastMessage return last message for one statement in multi-queries.
LastMessage() string
// WarnCount return warn count for one statement in multi-queries.
WarnCount() uint16
// Status return status when executing one statement in multi-queries.
Status() uint16
// LastInsertID return last insert id for one statement in multi-queries.
LastInsertID() uint64
}