diff --git a/statistics/BUILD.bazel b/statistics/BUILD.bazel index e6992020197c3..8dccd523fc887 100644 --- a/statistics/BUILD.bazel +++ b/statistics/BUILD.bazel @@ -13,6 +13,7 @@ go_library( "fmsketch.go", "histogram.go", "index.go", + "interact_with_storage.go", "merge_worker.go", "row_sampler.go", "sample.go", diff --git a/statistics/handle/dump.go b/statistics/handle/dump.go index 81e982881ee83..da8603ea90573 100644 --- a/statistics/handle/dump.go +++ b/statistics/handle/dump.go @@ -263,7 +263,7 @@ func (h *Handle) tableHistoricalStatsToJSON(physicalID int64, snapshot uint64) ( }() // get meta version - rows, _, err := reader.read("select distinct version from mysql.stats_meta_history where table_id = %? and version <= %? order by version desc limit 1", physicalID, snapshot) + rows, _, err := reader.Read("select distinct version from mysql.stats_meta_history where table_id = %? and version <= %? order by version desc limit 1", physicalID, snapshot) if err != nil { return nil, errors.AddStack(err) } @@ -272,14 +272,14 @@ func (h *Handle) tableHistoricalStatsToJSON(physicalID int64, snapshot uint64) ( } statsMetaVersion := rows[0].GetInt64(0) // get stats meta - rows, _, err = reader.read("select modify_count, count from mysql.stats_meta_history where table_id = %? and version = %?", physicalID, statsMetaVersion) + rows, _, err = reader.Read("select modify_count, count from mysql.stats_meta_history where table_id = %? and version = %?", physicalID, statsMetaVersion) if err != nil { return nil, errors.AddStack(err) } modifyCount, count := rows[0].GetInt64(0), rows[0].GetInt64(1) // get stats version - rows, _, err = reader.read("select distinct version from mysql.stats_history where table_id = %? and version <= %? order by version desc limit 1", physicalID, snapshot) + rows, _, err = reader.Read("select distinct version from mysql.stats_history where table_id = %? and version <= %? order by version desc limit 1", physicalID, snapshot) if err != nil { return nil, errors.AddStack(err) } @@ -289,7 +289,7 @@ func (h *Handle) tableHistoricalStatsToJSON(physicalID int64, snapshot uint64) ( statsVersion := rows[0].GetInt64(0) // get stats - rows, _, err = reader.read("select stats_data from mysql.stats_history where table_id = %? and version = %? order by seq_no", physicalID, statsVersion) + rows, _, err = reader.Read("select stats_data from mysql.stats_history where table_id = %? and version = %? order by seq_no", physicalID, statsVersion) if err != nil { return nil, errors.AddStack(err) } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index fc4f86dc54fb8..0f46a1f74f395 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -1067,7 +1067,7 @@ func (h *Handle) LoadNeededHistograms() (err error) { return nil } -func (h *Handle) loadNeededColumnHistograms(reader *statsReader, col model.TableItemID, loadFMSketch bool) (err error) { +func (h *Handle) loadNeededColumnHistograms(reader *statistics.StatsReader, col model.TableItemID, loadFMSketch bool) (err error) { oldCache := h.statsCache.Load().(statsCache) tbl, ok := oldCache.Get(col.TableID) if !ok { @@ -1093,7 +1093,7 @@ func (h *Handle) loadNeededColumnHistograms(reader *statsReader, col model.Table return errors.Trace(err) } } - rows, _, err := reader.read("select stats_ver from mysql.stats_histograms where is_index = 0 and table_id = %? and hist_id = %?", col.TableID, col.ID) + rows, _, err := reader.Read("select stats_ver from mysql.stats_histograms where is_index = 0 and table_id = %? and hist_id = %?", col.TableID, col.ID) if err != nil { return errors.Trace(err) } @@ -1134,7 +1134,7 @@ func (h *Handle) loadNeededColumnHistograms(reader *statsReader, col model.Table return nil } -func (h *Handle) loadNeededIndexHistograms(reader *statsReader, idx model.TableItemID, loadFMSketch bool) (err error) { +func (h *Handle) loadNeededIndexHistograms(reader *statistics.StatsReader, idx model.TableItemID, loadFMSketch bool) (err error) { oldCache := h.statsCache.Load().(statsCache) tbl, ok := oldCache.Get(idx.TableID) if !ok { @@ -1160,7 +1160,7 @@ func (h *Handle) loadNeededIndexHistograms(reader *statsReader, idx model.TableI return errors.Trace(err) } } - rows, _, err := reader.read("select stats_ver from mysql.stats_histograms where is_index = 1 and table_id = %? and hist_id = %?", idx.TableID, idx.ID) + rows, _, err := reader.Read("select stats_ver from mysql.stats_histograms where is_index = 1 and table_id = %? and hist_id = %?", idx.TableID, idx.ID) if err != nil { return errors.Trace(err) } @@ -1214,12 +1214,12 @@ func (h *Handle) FlushStats() { } } -func (h *Handle) cmSketchAndTopNFromStorage(reader *statsReader, tblID int64, isIndex, histID int64) (_ *statistics.CMSketch, _ *statistics.TopN, err error) { - topNRows, _, err := reader.read("select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) +func (h *Handle) cmSketchAndTopNFromStorage(reader *statistics.StatsReader, tblID int64, isIndex, histID int64) (_ *statistics.CMSketch, _ *statistics.TopN, err error) { + topNRows, _, err := reader.Read("select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) if err != nil { return nil, nil, err } - rows, _, err := reader.read("select cm_sketch from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) + rows, _, err := reader.Read("select cm_sketch from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) if err != nil { return nil, nil, err } @@ -1229,15 +1229,15 @@ func (h *Handle) cmSketchAndTopNFromStorage(reader *statsReader, tblID int64, is return statistics.DecodeCMSketchAndTopN(rows[0].GetBytes(0), topNRows) } -func (h *Handle) fmSketchFromStorage(reader *statsReader, tblID int64, isIndex, histID int64) (_ *statistics.FMSketch, err error) { - rows, _, err := reader.read("select value from mysql.stats_fm_sketch where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) +func (h *Handle) fmSketchFromStorage(reader *statistics.StatsReader, tblID int64, isIndex, histID int64) (_ *statistics.FMSketch, err error) { + rows, _, err := reader.Read("select value from mysql.stats_fm_sketch where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) if err != nil || len(rows) == 0 { return nil, err } return statistics.DecodeFMSketch(rows[0].GetBytes(0)) } -func (h *Handle) indexStatsFromStorage(reader *statsReader, row chunk.Row, table *statistics.Table, tableInfo *model.TableInfo) error { +func (h *Handle) indexStatsFromStorage(reader *statistics.StatsReader, row chunk.Row, table *statistics.Table, tableInfo *model.TableInfo) error { histID := row.GetInt64(2) distinct := row.GetInt64(3) histVer := row.GetUint64(4) @@ -1247,7 +1247,7 @@ func (h *Handle) indexStatsFromStorage(reader *statsReader, row chunk.Row, table errorRate := statistics.ErrorRate{} flag := row.GetInt64(8) lastAnalyzePos := row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)) - if statistics.IsAnalyzed(flag) && !reader.isHistory() { + if statistics.IsAnalyzed(flag) && !reader.IsHistory() { h.mu.rateMap.clear(table.PhysicalID, histID, true) } else if idx != nil { errorRate = idx.ErrorRate @@ -1295,7 +1295,7 @@ func (h *Handle) indexStatsFromStorage(reader *statsReader, row chunk.Row, table return nil } -func (h *Handle) columnStatsFromStorage(reader *statsReader, row chunk.Row, table *statistics.Table, tableInfo *model.TableInfo, loadAll bool) error { +func (h *Handle) columnStatsFromStorage(reader *statistics.StatsReader, row chunk.Row, table *statistics.Table, tableInfo *model.TableInfo, loadAll bool) error { histID := row.GetInt64(2) distinct := row.GetInt64(3) histVer := row.GetUint64(4) @@ -1307,7 +1307,7 @@ func (h *Handle) columnStatsFromStorage(reader *statsReader, row chunk.Row, tabl col := table.Columns[histID] errorRate := statistics.ErrorRate{} flag := row.GetInt64(8) - if statistics.IsAnalyzed(flag) && !reader.isHistory() { + if statistics.IsAnalyzed(flag) && !reader.IsHistory() { h.mu.rateMap.clear(table.PhysicalID, histID, false) } else if col != nil { errorRate = col.ErrorRate @@ -1439,14 +1439,14 @@ func (h *Handle) TableStatsFromStorage(tableInfo *model.TableInfo, physicalID in } table.Pseudo = false - rows, _, err := reader.read("select modify_count, count from mysql.stats_meta where table_id = %?", physicalID) + rows, _, err := reader.Read("select modify_count, count from mysql.stats_meta where table_id = %?", physicalID) if err != nil || len(rows) == 0 { return nil, err } table.ModifyCount = rows[0].GetInt64(0) table.Count = rows[0].GetInt64(1) - rows, _, err = reader.read("select table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %?", physicalID) + rows, _, err = reader.Read("select table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %?", physicalID) // Check deleted table. if err != nil || len(rows) == 0 { return nil, nil @@ -1464,7 +1464,7 @@ func (h *Handle) TableStatsFromStorage(tableInfo *model.TableInfo, physicalID in return h.extendedStatsFromStorage(reader, table, physicalID, loadAll) } -func (h *Handle) extendedStatsFromStorage(reader *statsReader, table *statistics.Table, physicalID int64, loadAll bool) (*statistics.Table, error) { +func (h *Handle) extendedStatsFromStorage(reader *statistics.StatsReader, table *statistics.Table, physicalID int64, loadAll bool) (*statistics.Table, error) { failpoint.Inject("injectExtStatsLoadErr", func() { failpoint.Return(nil, errors.New("gofail extendedStatsFromStorage error")) }) @@ -1474,7 +1474,7 @@ func (h *Handle) extendedStatsFromStorage(reader *statsReader, table *statistics } else { table.ExtendedStats = statistics.NewExtendedStatsColl() } - rows, _, err := reader.read("select name, status, type, column_ids, stats, version from mysql.stats_extended where table_id = %? and status in (%?, %?, %?) and version > %?", physicalID, StatsStatusInited, StatsStatusAnalyzed, StatsStatusDeleted, lastVersion) + rows, _, err := reader.Read("select name, status, type, column_ids, stats, version from mysql.stats_extended where table_id = %? and status in (%?, %?, %?) and version > %?", physicalID, StatsStatusInited, StatsStatusAnalyzed, StatsStatusDeleted, lastVersion) if err != nil || len(rows) == 0 { return table, nil } @@ -1525,7 +1525,7 @@ func (h *Handle) StatsMetaCountAndModifyCount(tableID int64) (int64, int64, erro err = err1 } }() - rows, _, err := reader.read("select count, modify_count from mysql.stats_meta where table_id = %?", tableID) + rows, _, err := reader.Read("select count, modify_count from mysql.stats_meta where table_id = %?", tableID) if err != nil { return 0, 0, err } @@ -1913,8 +1913,8 @@ func (h *Handle) SaveMetaToStorage(tableID, count, modifyCount int64, source str return err } -func (h *Handle) histogramFromStorage(reader *statsReader, tableID int64, colID int64, tp *types.FieldType, distinct int64, isIndex int, ver uint64, nullCount int64, totColSize int64, corr float64) (_ *statistics.Histogram, err error) { - rows, fields, err := reader.read("select count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %? order by bucket_id", tableID, isIndex, colID) +func (h *Handle) histogramFromStorage(reader *statistics.StatsReader, tableID int64, colID int64, tp *types.FieldType, distinct int64, isIndex int, ver uint64, nullCount int64, totColSize int64, corr float64) (_ *statistics.Histogram, err error) { + rows, fields, err := reader.Read("select count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %? order by bucket_id", tableID, isIndex, colID) if err != nil { return nil, errors.Trace(err) } @@ -1961,9 +1961,9 @@ func (h *Handle) histogramFromStorage(reader *statsReader, tableID int64, colID return hg, nil } -func (h *Handle) columnCountFromStorage(reader *statsReader, tableID, colID, statsVer int64) (int64, error) { +func (h *Handle) columnCountFromStorage(reader *statistics.StatsReader, tableID, colID, statsVer int64) (int64, error) { count := int64(0) - rows, _, err := reader.read("select sum(count) from mysql.stats_buckets where table_id = %? and is_index = 0 and hist_id = %?", tableID, colID) + rows, _, err := reader.Read("select sum(count) from mysql.stats_buckets where table_id = %? and is_index = 0 and hist_id = %?", tableID, colID) if err != nil { return 0, errors.Trace(err) } @@ -1979,7 +1979,7 @@ func (h *Handle) columnCountFromStorage(reader *statsReader, tableID, colID, sta // Before stats ver 2, histogram represents all data in this column. // In stats ver 2, histogram + TopN represent all data in this column. // So we need to add TopN total count here. - rows, _, err = reader.read("select sum(count) from mysql.stats_top_n where table_id = %? and is_index = 0 and hist_id = %?", tableID, colID) + rows, _, err = reader.Read("select sum(count) from mysql.stats_top_n where table_id = %? and is_index = 0 and hist_id = %?", tableID, colID) if err != nil { return 0, errors.Trace(err) } @@ -2014,26 +2014,7 @@ func (h *Handle) statsMetaByTableIDFromStorage(tableID int64, snapshot uint64) ( return } -// statsReader is used for simplify code that needs to read system tables in different sqls -// but requires the same transactions. -type statsReader struct { - ctx sqlexec.RestrictedSQLExecutor - snapshot uint64 -} - -func (sr *statsReader) read(sql string, args ...interface{}) (rows []chunk.Row, fields []*ast.ResultField, err error) { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - if sr.snapshot > 0 { - return sr.ctx.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseSessionPool, sqlexec.ExecOptionWithSnapshot(sr.snapshot)}, sql, args...) - } - return sr.ctx.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, sql, args...) -} - -func (sr *statsReader) isHistory() bool { - return sr.snapshot > 0 -} - -func (h *Handle) getGlobalStatsReader(snapshot uint64) (reader *statsReader, err error) { +func (h *Handle) getGlobalStatsReader(snapshot uint64) (reader *statistics.StatsReader, err error) { h.mu.Lock() defer func() { if r := recover(); r != nil { @@ -2043,44 +2024,12 @@ func (h *Handle) getGlobalStatsReader(snapshot uint64) (reader *statsReader, err h.mu.Unlock() } }() - return h.getStatsReader(snapshot, h.mu.ctx.(sqlexec.RestrictedSQLExecutor)) + return statistics.GetStatsReader(snapshot, h.mu.ctx.(sqlexec.RestrictedSQLExecutor)) } -func (h *Handle) releaseGlobalStatsReader(reader *statsReader) error { +func (h *Handle) releaseGlobalStatsReader(reader *statistics.StatsReader) error { defer h.mu.Unlock() - return h.releaseStatsReader(reader, h.mu.ctx.(sqlexec.RestrictedSQLExecutor)) -} - -func (h *Handle) getStatsReader(snapshot uint64, exec sqlexec.RestrictedSQLExecutor) (reader *statsReader, err error) { - failpoint.Inject("mockGetStatsReaderFail", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("gofail genStatsReader error")) - } - }) - if snapshot > 0 { - return &statsReader{ctx: exec, snapshot: snapshot}, nil - } - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("getStatsReader panic %v", r) - } - }() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - failpoint.Inject("mockGetStatsReaderPanic", nil) - _, err = exec.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "begin") - if err != nil { - return nil, err - } - return &statsReader{ctx: exec}, nil -} - -func (h *Handle) releaseStatsReader(reader *statsReader, exec sqlexec.RestrictedSQLExecutor) error { - if reader.snapshot > 0 { - return nil - } - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - _, err := exec.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "commit") - return err + return reader.Close() } const ( diff --git a/statistics/handle/handle_hist.go b/statistics/handle/handle_hist.go index ad04d946e3f22..1d41e14791446 100644 --- a/statistics/handle/handle_hist.go +++ b/statistics/handle/handle_hist.go @@ -177,7 +177,7 @@ var errExit = errors.New("Stop loading since domain is closed") // StatsReaderContext exported for testing type StatsReaderContext struct { - reader *statsReader + reader *statistics.StatsReader createdTime time.Time } @@ -188,7 +188,7 @@ func (h *Handle) SubLoadWorker(ctx sessionctx.Context, exit chan struct{}, exitW exitWg.Done() logutil.BgLogger().Info("SubLoadWorker exited.") if readerCtx.reader != nil { - err := h.releaseStatsReader(readerCtx.reader, ctx.(sqlexec.RestrictedSQLExecutor)) + err := readerCtx.reader.Close() if err != nil { logutil.BgLogger().Error("Fail to release stats loader: ", zap.Error(err)) } @@ -295,13 +295,13 @@ func (h *Handle) handleOneItemTask(task *NeededItemTask, readerCtx *StatsReaderC func (h *Handle) loadFreshStatsReader(readerCtx *StatsReaderContext, ctx sqlexec.RestrictedSQLExecutor) { if readerCtx.reader == nil || readerCtx.createdTime.Add(h.Lease()).Before(time.Now()) { if readerCtx.reader != nil { - err := h.releaseStatsReader(readerCtx.reader, ctx) + err := readerCtx.reader.Close() if err != nil { logutil.BgLogger().Warn("Fail to release stats loader: ", zap.Error(err)) } } for { - newReader, err := h.getStatsReader(0, ctx) + newReader, err := statistics.GetStatsReader(0, ctx) if err != nil { logutil.BgLogger().Error("Fail to new stats loader, retry after a while.", zap.Error(err)) time.Sleep(h.Lease() / 10) @@ -317,7 +317,7 @@ func (h *Handle) loadFreshStatsReader(readerCtx *StatsReaderContext, ctx sqlexec } // readStatsForOneItem reads hist for one column/index, TODO load data via kv-get asynchronously -func (h *Handle) readStatsForOneItem(item model.TableItemID, w *statsWrapper, reader *statsReader) (*statsWrapper, error) { +func (h *Handle) readStatsForOneItem(item model.TableItemID, w *statsWrapper, reader *statistics.StatsReader) (*statsWrapper, error) { failpoint.Inject("mockReadStatsForOnePanic", nil) failpoint.Inject("mockReadStatsForOneFail", func(val failpoint.Value) { if val.(bool) { @@ -357,7 +357,7 @@ func (h *Handle) readStatsForOneItem(item model.TableItemID, w *statsWrapper, re return nil, errors.Trace(err) } } - rows, _, err := reader.read("select stats_ver from mysql.stats_histograms where table_id = %? and hist_id = %? and is_index = %?", item.TableID, item.ID, int(isIndexFlag)) + rows, _, err := reader.Read("select stats_ver from mysql.stats_histograms where table_id = %? and hist_id = %? and is_index = %?", item.TableID, item.ID, int(isIndexFlag)) if err != nil { return nil, errors.Trace(err) } diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 2b0669033f8c9..dc399a87fcad3 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -622,16 +622,16 @@ func TestLoadStats(t *testing.T) { require.True(t, idx.IsFullLoad()) // Following test tests whether the LoadNeededHistograms would panic. - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/statistics/handle/mockGetStatsReaderFail", `return(true)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/statistics/mockGetStatsReaderFail", `return(true)`)) err = h.LoadNeededHistograms() require.Error(t, err) - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/statistics/handle/mockGetStatsReaderFail")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/statistics/mockGetStatsReaderFail")) - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/statistics/handle/mockGetStatsReaderPanic", "panic")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/statistics/mockGetStatsReaderPanic", "panic")) err = h.LoadNeededHistograms() require.Error(t, err) require.Regexp(t, ".*getStatsReader panic.*", err.Error()) - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/statistics/handle/mockGetStatsReaderPanic")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/statistics/mockGetStatsReaderPanic")) err = h.LoadNeededHistograms() require.NoError(t, err) } diff --git a/statistics/interact_with_storage.go b/statistics/interact_with_storage.go new file mode 100644 index 0000000000000..478b845937067 --- /dev/null +++ b/statistics/interact_with_storage.go @@ -0,0 +1,86 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statistics + +import ( + "context" + "fmt" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sqlexec" +) + +// StatsReader is used for simplifying code that needs to read statistics from system tables(mysql.stats_xxx) in different sqls +// but requires the same transactions. +// +// Note that: +// 1. Remember to call (*StatsReader).Close after reading all statistics. +// 2. StatsReader is not thread-safe. Different goroutines cannot call (*StatsReader).Read concurrently. +type StatsReader struct { + ctx sqlexec.RestrictedSQLExecutor + snapshot uint64 +} + +// GetStatsReader returns a StatsReader. +func GetStatsReader(snapshot uint64, exec sqlexec.RestrictedSQLExecutor) (reader *StatsReader, err error) { + failpoint.Inject("mockGetStatsReaderFail", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("gofail genStatsReader error")) + } + }) + if snapshot > 0 { + return &StatsReader{ctx: exec, snapshot: snapshot}, nil + } + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("getStatsReader panic %v", r) + } + }() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + failpoint.Inject("mockGetStatsReaderPanic", nil) + _, err = exec.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "begin") + if err != nil { + return nil, err + } + return &StatsReader{ctx: exec}, nil +} + +// Read is a thin wrapper reading statistics from storage by sql command. +func (sr *StatsReader) Read(sql string, args ...interface{}) (rows []chunk.Row, fields []*ast.ResultField, err error) { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + if sr.snapshot > 0 { + return sr.ctx.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseSessionPool, sqlexec.ExecOptionWithSnapshot(sr.snapshot)}, sql, args...) + } + return sr.ctx.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, sql, args...) +} + +// IsHistory indicates whether to read history statistics. +func (sr *StatsReader) IsHistory() bool { + return sr.snapshot > 0 +} + +// Close closes the StatsReader. +func (sr *StatsReader) Close() error { + if sr.IsHistory() || sr.ctx == nil { + return nil + } + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + _, err := sr.ctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "commit") + return err +}