diff --git a/reader/sql/sql.go b/reader/sql/sql.go index 5ec5ca48b..54d899647 100644 --- a/reader/sql/sql.go +++ b/reader/sql/sql.go @@ -719,7 +719,7 @@ func checkMagic(rawSql string) (valid bool) { } func (r *Reader) Name() string { - return strings.ToUpper(r.dbtype) + "_Reader:" + r.database + "_" + Hash(r.rawsqls) + return strings.ToUpper(r.dbtype) + "_Reader:" + r.rawDatabase + "_" + Hash(r.rawsqls) } func (r *Reader) setStatsError(err string) { @@ -889,7 +889,7 @@ func (r *Reader) run() { } // 开始work逻辑 for { - if atomic.LoadInt32(&r.status) == reader.StatusStopping { + if atomic.LoadInt32(&r.status) == reader.StatusStopping || atomic.LoadInt32(&r.status) == reader.StatusStopped { log.Warnf("Runner[%v] %v stopped from running", r.meta.RunnerName, r.Name()) return } @@ -1015,7 +1015,8 @@ func (r *Reader) exec(connectStr string) (err error) { } err = r.execReadDB(currentDB, now, recordTablesDone) if err != nil { - log.Errorf("Runner[%v] %v exect read db: %v error: %v", r.meta.RunnerName, currentDB, currentDB, err) + log.Errorf("Runner[%v] %v exect read db: %v error: %v,will retry read it", r.meta.RunnerName, currentDB, currentDB, err) + return err } if atomic.LoadInt32(&r.status) == reader.StatusStopping || atomic.LoadInt32(&r.status) == reader.StatusStopped { log.Warnf("Runner[%v] %v stopped from running", r.meta.RunnerName, currentDB) @@ -1049,16 +1050,7 @@ func (r *Reader) execCountDB(curDB string, now time.Time, recordTablesDone Table if err != nil { return err } - db, err := openSql(r.dbtype, connectStr, curDB) - if err != nil { - return err - } - defer func() { - db.Close() - }() - if err = db.Ping(); err != nil { - return err - } + log.Infof("Runner[%v] prepare %v change database success, current database is: %v", r.meta.RunnerName, r.dbtype, curDB) //更新sqls @@ -1066,7 +1058,7 @@ func (r *Reader) execCountDB(curDB string, now time.Time, recordTablesDone Table var sqls string if r.rawsqls == "" { // 获取符合条件的数据表,并且将计算表中记录数的query语句赋给 r.rawsqls - tables, sqls, err = r.getDatas(db, curDB, r.rawTable, now, COUNT) + tables, sqls, err = r.getDatas(connectStr, curDB, r.rawTable, now, COUNT) if err != nil { return err } @@ -1096,7 +1088,7 @@ func (r *Reader) execCountDB(curDB string, now time.Time, recordTablesDone Table // 每张表的记录数 var tableSize int64 - tableSize, err = r.execTableCount(db, idx, curDB, rawSql) + tableSize, err = r.execTableCount(connectStr, idx, curDB, rawSql) if err != nil { return err } @@ -1118,16 +1110,7 @@ func (r *Reader) execReadDB(curDB string, now time.Time, recordTablesDone TableR if err != nil { return err } - db, err := openSql(r.dbtype, connectStr, r.Name()) - if err != nil { - return err - } - defer func() { - db.Close() - }() - if err = db.Ping(); err != nil { - return err - } + log.Infof("Runner[%v] %v prepare %v change database success", r.meta.RunnerName, curDB, r.dbtype) r.database = curDB @@ -1136,7 +1119,7 @@ func (r *Reader) execReadDB(curDB string, now time.Time, recordTablesDone TableR var sqls string if r.rawsqls == "" { // 获取符合条件的数据表,并且将获取表中所有记录的语句赋给 r.rawsqls - tables, sqls, err = r.getDatas(db, curDB, r.rawTable, now, TABLE) + tables, sqls, err = r.getDatas(connectStr, curDB, r.rawTable, now, TABLE) if err != nil { log.Errorf("Runner[%v] %v rawTable: %v get tables and sqls error %v", r.meta.RunnerName, r.Name(), r.rawTable, r.rawsqls, err) if len(tables) == 0 && sqls == "" { @@ -1181,8 +1164,10 @@ func (r *Reader) execReadDB(curDB string, now time.Time, recordTablesDone TableR } } // 执行每条 sql 语句 - exit, isRawSql, readSize = r.execReadSql(db, idx, rawSql, tables) - + exit, isRawSql, readSize, err = r.execReadSql(connectStr, curDB, idx, rawSql, tables) + if err != nil { + return err + } if r.rawsqls == "" { tmpTablesRecords.SetTableInfo(tableName, TableInfo{size: readSize, offset: -1}) r.syncRecords.SetTableRecords(curDB, tmpTablesRecords) @@ -1640,7 +1625,7 @@ type DataQuery struct { sqls string } -func (r *Reader) getValidData(db *sql.DB, curDB, matchData, matchStr string, +func (r *Reader) getValidData(connectStr, curDB, matchData, matchStr string, startIndex, endIndex, timeIndex []int, queryType int) (validData []string, sqls string, err error) { // get all databases and check validate database query, err := r.getQuery(queryType, curDB) @@ -1648,6 +1633,17 @@ func (r *Reader) getValidData(db *sql.DB, curDB, matchData, matchStr string, return validData, sqls, err } + db, err := openSql(r.dbtype, connectStr, r.Name()) + if err != nil { + return nil, "", err + } + defer func() { + db.Close() + }() + if err = db.Ping(); err != nil { + return nil, "", err + } + rowsDBs, err := db.Query(query) if err != nil { log.Errorf("Runner[%v] %v prepare %v <%v> query error %v", r.meta.RunnerName, curDB, r.dbtype, query, err) @@ -1785,7 +1781,7 @@ func getDefaultSql(database, dbtype string) (defaultSql string, err error) { // 根据queryType获取符合要求的数据和需要执行的原始sql语句mr.rawsqls // queryType 可以为TABLE DATABASE COUNT -func (r *Reader) getDatas(db *sql.DB, curDB, rawData string, now time.Time, queryType int) (datas []string, rawsqls string, err error) { +func (r *Reader) getDatas(connectStr, curDB, rawData string, now time.Time, queryType int) (datas []string, rawsqls string, err error) { var startIndex, endIndex, timeIndex []int var matchData string @@ -1796,7 +1792,7 @@ func (r *Reader) getDatas(db *sql.DB, curDB, rawData string, now time.Time, quer } if checkAll { // 导入所有数据 - datas, rawsqls, err = r.getAllDatas(db, curDB, queryType) + datas, rawsqls, err = r.getAllDatas(connectStr, curDB, queryType) if err != nil { return datas, rawsqls, err } @@ -1819,7 +1815,7 @@ func (r *Reader) getDatas(db *sql.DB, curDB, rawData string, now time.Time, quer } matchStr := getRemainStr(matchData, timeIndex) - datas, rawsqls, err = r.getValidData(db, curDB, matchData, matchStr, startIndex, endIndex, timeIndex, queryType) + datas, rawsqls, err = r.getValidData(connectStr, curDB, matchData, matchStr, startIndex, endIndex, timeIndex, queryType) if err != nil { return datas, rawsqls, err } @@ -1875,13 +1871,25 @@ func (r *Reader) getQuery(queryType int, curDB string) (query string, err error) } // 计算每个table的记录条数 -func (r *Reader) execTableCount(db *sql.DB, idx int, curDB, rawSql string) (tableSize int64, err error) { +func (r *Reader) execTableCount(connectStr string, idx int, curDB, rawSql string) (tableSize int64, err error) { execSQL, err := r.getSQL(idx, rawSql) if err != nil { log.Errorf("Runner[%v] get SQL error %v, use raw SQL", r.meta.RunnerName, err) execSQL = rawSql } log.Infof("Runner[%v] reader <%v> exec sql <%v>", r.meta.RunnerName, curDB, execSQL) + + db, err := openSql(r.dbtype, connectStr, curDB) + if err != nil { + return 0, err + } + defer func() { + db.Close() + }() + if err = db.Ping(); err != nil { + return 0, err + } + rows, err := db.Query(execSQL) if err != nil { log.Errorf("Runner[%v] %v prepare %v <%v> query error %v", r.meta.RunnerName, curDB, r.dbtype, execSQL, err) @@ -1909,7 +1917,7 @@ func (r *Reader) execTableCount(db *sql.DB, idx int, curDB, rawSql string) (tabl } // 执行每条 sql 语句 -func (r *Reader) execReadSql(db *sql.DB, idx int, rawSql string, tables []string) (exit bool, isRawSql bool, readSize int64) { +func (r *Reader) execReadSql(connectStr, curDB string, idx int, rawSql string, tables []string) (exit bool, isRawSql bool, readSize int64, err error) { exit = true execSQL, err := r.getSQL(idx, r.syncSQLs[idx]) @@ -1922,13 +1930,24 @@ func (r *Reader) execReadSql(db *sql.DB, idx int, rawSql string, tables []string isRawSql = true } + db, err := openSql(r.dbtype, connectStr, curDB) + if err != nil { + return exit, isRawSql, 0, err + } + defer func() { + db.Close() + }() + if err = db.Ping(); err != nil { + return exit, isRawSql, 0, err + } + log.Infof("Runner[%v] reader <%v> exec sql <%v>", r.meta.RunnerName, r.Name(), execSQL) rows, err := db.Query(execSQL) if err != nil { err = fmt.Errorf("runner[%v] %v prepare %v <%v> query error %v", r.meta.RunnerName, r.Name(), r.dbtype, execSQL, err) log.Error(err) r.sendError(err) - return exit, isRawSql, readSize + return exit, isRawSql, readSize, err } defer rows.Close() // Get column names @@ -1937,7 +1956,7 @@ func (r *Reader) execReadSql(db *sql.DB, idx int, rawSql string, tables []string err = fmt.Errorf("runner[%v] %v prepare %v <%v> columns error %v", r.meta.RunnerName, r.Name(), r.dbtype, execSQL, err) log.Error(err) r.sendError(err) - return exit, isRawSql, readSize + return exit, isRawSql, readSize, err } log.Infof("Runner[%v] SQL :<%v>, schemas: <%v>", r.meta.RunnerName, execSQL, strings.Join(columns, ", ")) scanArgs, nochiced := r.getInitScans(len(columns), rows, r.dbtype) @@ -2053,7 +2072,7 @@ func (r *Reader) execReadSql(db *sql.DB, idx int, rawSql string, tables []string } if atomic.LoadInt32(&r.status) == reader.StatusStopping || atomic.LoadInt32(&r.status) == reader.StatusStopped { log.Warnf("Runner[%v] %v stopped from running", r.meta.RunnerName, r.Name()) - return exit, isRawSql, readSize + return exit, isRawSql, readSize, nil } r.readChan <- readInfo{data, totalBytes} r.CurrentCount++ @@ -2082,12 +2101,12 @@ func (r *Reader) execReadSql(db *sql.DB, idx int, rawSql string, tables []string } } - return exit, isRawSql, readSize + return exit, isRawSql, readSize, rows.Err() } -func (r *Reader) getAllDatas(db *sql.DB, curDB string, queryType int) (datas []string, sqls string, err error) { +func (r *Reader) getAllDatas(connectStr, curDB string, queryType int) (datas []string, sqls string, err error) { // 拿到数据库中所有表及对应的sql语句 - datas, sqls, err = r.getValidData(db, curDB, "", "", []int{}, []int{}, []int{}, queryType) + datas, sqls, err = r.getValidData(connectStr, curDB, "", "", []int{}, []int{}, []int{}, queryType) if err != nil { return datas, sqls, err } @@ -2185,7 +2204,7 @@ func (r *Reader) getDBs(connectStr string, now time.Time) ([]string, error) { if err = db.Ping(); err != nil { return nil, err } - dbsAll, _, err := r.getDatas(db, "", r.rawDatabase, now, DATABASE) + dbsAll, _, err := r.getDatas(connectStr, "", r.rawDatabase, now, DATABASE) if err != nil { return dbsAll, err } diff --git a/reader/sql/sql_test.go b/reader/sql/sql_test.go index 86eb05d86..d60f76cc9 100644 --- a/reader/sql/sql_test.go +++ b/reader/sql/sql_test.go @@ -499,14 +499,15 @@ func TestSQLReader(t *testing.T) { defer os.RemoveAll(MetaDir) database := "TestSQLReaderdatabase" mr := &Reader{ - database: database, - rawsqls: "select * from mysql123 ;select * from mysql345;", - syncSQLs: []string{"select * from mysql123", "select * from mysql345"}, - readBatch: 100, - meta: meta, - offsetKey: "id", - offsets: []int64{123, 456}, - dbtype: "mysql", + rawDatabase: database, + database: database, + rawsqls: "select * from mysql123 ;select * from mysql345;", + syncSQLs: []string{"select * from mysql123", "select * from mysql345"}, + readBatch: 100, + meta: meta, + offsetKey: "id", + offsets: []int64{123, 456}, + dbtype: "mysql", } assert.Equal(t, mr.dbtype+"_"+database, mr.Source())