From bebb3dabb4c8bde0a9c8b04d609f954881368c4e Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 7 Dec 2021 19:20:21 +0800 Subject: [PATCH 1/8] done Signed-off-by: wjhuang2016 --- expression/integration_test.go | 5 +++++ planner/core/expression_rewriter.go | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/expression/integration_test.go b/expression/integration_test.go index 61a72a9ce49f1..7ae9b2d4a6665 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -6663,6 +6663,11 @@ func (s *testIntegrationSerialSuite) TestCollationBasic(c *C) { tk.MustQuery("select * from t1 where col1 >= 0xc484 and col1 <= 0xc3b3;").Check(testkit.Rows("Ȇ")) tk.MustQuery("select collation(IF('a' < 'B' collate utf8mb4_general_ci, 'smaller', 'greater' collate utf8mb4_unicode_ci));").Check(testkit.Rows("utf8mb4_unicode_ci")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a char(10))") + tk.MustExec("insert into t values ('a')") + tk.MustQuery("select * from t where a in ('b' collate utf8mb4_general_ci, 'A', 3)").Check(testkit.Rows("a")) } func (s *testIntegrationSerialSuite) TestWeightString(c *C) { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index ab57b29c2096d..b550b03a8aa99 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1481,6 +1481,11 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field if allSameType && l == 1 && lLen > 1 { function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...) } else { + coll, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) 
+ er.err = err + if er.err != nil { + return + } eqFunctions := make([]expression.Expression, 0, lLen) for i := stkLen - lLen; i < stkLen; i++ { expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ) @@ -1488,6 +1493,7 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field er.err = err return } + expr.SetCharsetAndCollation(coll.Charset, coll.Collation) eqFunctions = append(eqFunctions, expr) } function = expression.ComposeDNFCondition(er.sctx, eqFunctions...) From b3740d2d746f14d216e8fe8936e26fb781c42c7c Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 7 Dec 2021 20:08:29 +0800 Subject: [PATCH 2/8] fix Signed-off-by: wjhuang2016 --- planner/core/expression_rewriter.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index b550b03a8aa99..54b3370d15505 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1488,12 +1488,17 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field } eqFunctions := make([]expression.Expression, 0, lLen) for i := stkLen - lLen; i < stkLen; i++ { + if er.ctxStack[i].GetType().EvalType() == types.ETString { + tp := er.ctxStack[i].GetType().Clone() + tp.Charset, tp.Collate = coll.Charset, coll.Collation + er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) + er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) + } expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ) if err != nil { er.err = err return } - expr.SetCharsetAndCollation(coll.Charset, coll.Collation) eqFunctions = append(eqFunctions, expr) } function = expression.ComposeDNFCondition(er.sctx, eqFunctions...) 
From 7c443f397e79b3836352d5345acc193cf6ce0b7d Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Wed, 8 Dec 2021 15:26:14 +0800 Subject: [PATCH 3/8] try fix Signed-off-by: wjhuang2016 --- planner/core/expression_rewriter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 54b3370d15505..7ec6d9ba64c66 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1488,7 +1488,7 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field } eqFunctions := make([]expression.Expression, 0, lLen) for i := stkLen - lLen; i < stkLen; i++ { - if er.ctxStack[i].GetType().EvalType() == types.ETString { + if er.ctxStack[i].GetType().EvalType() == types.ETString && l == 1 { tp := er.ctxStack[i].GetType().Clone() tp.Charset, tp.Collate = coll.Charset, coll.Collation er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) From a43ef03c71c7de95e60002d9b65556f116b3bdfe Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Wed, 8 Dec 2021 17:14:05 +0800 Subject: [PATCH 4/8] save Signed-off-by: wjhuang2016 --- expression/builtin_cast.go | 2 + expression/builtin_convert_charset.go | 2 +- planner/core/expression_rewriter.go | 63 +++++++++++++++++++++++---- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 0ec4ad904dee4..318afd77a181b 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -283,6 +283,8 @@ func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []E if err != nil { return nil, err } + + bf.args[0] = HandleBinaryLiteral(ctx, args[0], &ExprCollation{Charset: c.tp.Charset, Collation: c.tp.Collate}, c.funcName) bf.tp = c.tp if args[0].GetType().Hybrid() || IsBinaryLiteral(args[0]) { sig = &builtinCastStringAsStringSig{bf} diff --git a/expression/builtin_convert_charset.go 
b/expression/builtin_convert_charset.go index 4101d00a3b66b..da2d1d82e430e 100644 --- a/expression/builtin_convert_charset.go +++ b/expression/builtin_convert_charset.go @@ -263,7 +263,7 @@ func HandleBinaryLiteral(ctx sessionctx.Context, expr Expression, ec *ExprCollat ast.Left, ast.Right, ast.Repeat, ast.Trim, ast.LTrim, ast.RTrim, ast.Substr, ast.SubstringIndex, ast.Replace, ast.Substring, ast.Mid, ast.Translate, ast.InsertFunc, ast.Lpad, ast.Rpad, ast.Elt, ast.ExportSet, ast.MakeSet, ast.FindInSet, ast.Regexp, ast.Field, ast.Locate, ast.Instr, ast.Position, ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, - ast.NE, ast.NullEQ, ast.Strcmp, ast.If, ast.Ifnull, ast.Like, ast.In, ast.DateFormat, ast.TimeFormat: + ast.NE, ast.NullEQ, ast.Strcmp, ast.If, ast.Ifnull, ast.Like, ast.In, ast.DateFormat, ast.TimeFormat, ast.Cast: if ec.Charset == charset.CharsetBin && expr.GetType().Charset != charset.CharsetBin { return BuildToBinaryFunction(ctx, expr) } else if ec.Charset != charset.CharsetBin && expr.GetType().Charset == charset.CharsetBin { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 4deb03929be23..dd024419973ae 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1481,19 +1481,14 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field if allSameType && l == 1 && lLen > 1 { function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...) } else { - coll, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) - er.err = err + // If we rewrite IN to EQ, we need to decide what's the collation EQ uses. 
+ coll := er.deriveCollationForIn(l, lLen, stkLen, args) if er.err != nil { return } + er.castCollationForIn(l, lLen, stkLen, coll) eqFunctions := make([]expression.Expression, 0, lLen) for i := stkLen - lLen; i < stkLen; i++ { - if er.ctxStack[i].GetType().EvalType() == types.ETString && l == 1 { - tp := er.ctxStack[i].GetType().Clone() - tp.Charset, tp.Collate = coll.Charset, coll.Collation - er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) - er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) - } expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ) if err != nil { er.err = err @@ -1515,6 +1510,58 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field er.ctxStackAppend(function, types.EmptyName) } +// deriveCollationForIn derive collation for in expression. +func (er *expressionRewriter) deriveCollationForIn(l int, lLen int, stkLen int, args []expression.Expression) []*expression.ExprCollation { + coll := make([]*expression.ExprCollation, 0, l) + if l == 1 { + // a in (x, y, z) => coll[0] + coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) + er.err = err + if er.err != nil { + return nil + } + coll = append(coll, coll2) + } else { + // (a, b, c) in ((x1, x2, x3), (y1, y2, y3), (z1, z2, z3)) => coll[0], coll[1], coll[2] + for i := 0; i < lLen; i++ { + args := make([]expression.Expression, 0, lLen) + for j := stkLen - lLen - 1; j < stkLen; j++ { + rowFunc, _ := er.ctxStack[j].(*expression.ScalarFunction) + args = append(args, rowFunc.GetArgs()[i]) + } + coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) 
+ er.err = err + if er.err != nil { + return nil + } + coll = append(coll, coll2) + } + } + return coll +} + +func (er *expressionRewriter) castCollationForIn(l int, lLen int, stkLen int, coll []*expression.ExprCollation) { + for i := stkLen - lLen; i < stkLen; i++ { + if l == 1 && er.ctxStack[i].GetType().EvalType() == types.ETString { + tp := er.ctxStack[i].GetType().Clone() + tp.Charset, tp.Collate = coll[0].Charset, coll[0].Collation + er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) + er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) + } else { + rowFunc, _ := er.ctxStack[i].(*expression.ScalarFunction) + for j := 0; j < lLen; j++ { + if er.ctxStack[i].GetType().EvalType() != types.ETString { + continue + } + tp := rowFunc.GetArgs()[j].GetType().Clone() + tp.Charset, tp.Collate = coll[j].Charset, coll[j].Collation + rowFunc.GetArgs()[j] = expression.BuildCastFunction(er.sctx, rowFunc.GetArgs()[j], tp) + rowFunc.GetArgs()[j].SetCoercibility(expression.CoercibilityExplicit) + } + } + } +} + func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) { stkLen := len(er.ctxStack) argsLen := 2 * len(v.WhenClauses) From 9829562e6189aa402724aef146891ef50ee8be91 Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 21 Dec 2021 15:44:24 +0800 Subject: [PATCH 5/8] refine Signed-off-by: wjhuang2016 --- expression/integration_serial_test.go | 5 +++++ planner/core/expression_rewriter.go | 22 ++++++++++++---------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/expression/integration_serial_test.go b/expression/integration_serial_test.go index 3be189d90246a..3077e2f1b33a7 100644 --- a/expression/integration_serial_test.go +++ b/expression/integration_serial_test.go @@ -180,6 +180,11 @@ func TestCollationBasic(t *testing.T) { tk.MustExec("create table t(a char(10))") tk.MustExec("insert into t values ('a')") tk.MustQuery("select * from t where a in ('b' collate utf8mb4_general_ci, 'A', 3)").Check(testkit.Rows("a")) 
+ // These test cases may not behave the same as in MySQL, but the results here are more reasonable. + tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci));").Check(testkit.Rows("1")) + tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin));").Check(testkit.Rows("0")) + tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci), ('b', 'b'));").Check(testkit.Rows("1")) + tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin), ('b', 'b'));").Check(testkit.Rows("0")) } func TestWeightString(t *testing.T) { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index e2f657c054b2b..f0f26ba73c56d 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1522,9 +1522,9 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field } // deriveCollationForIn derive collation for in expression. -func (er *expressionRewriter) deriveCollationForIn(l int, lLen int, stkLen int, args []expression.Expression) []*expression.ExprCollation { - coll := make([]*expression.ExprCollation, 0, l) - if l == 1 { +func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkLen int, args []expression.Expression) []*expression.ExprCollation { + coll := make([]*expression.ExprCollation, 0, colLen) + if colLen == 1 { // a in (x, y, z) => coll[0] coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...) 
er.err = err @@ -1534,9 +1534,9 @@ func (er *expressionRewriter) deriveCollationForIn(l int, lLen int, stkLen int, coll = append(coll, coll2) } else { // (a, b, c) in ((x1, x2, x3), (y1, y2, y3), (z1, z2, z3)) => coll[0], coll[1], coll[2] - for i := 0; i < lLen; i++ { - args := make([]expression.Expression, 0, lLen) - for j := stkLen - lLen - 1; j < stkLen; j++ { + for i := 0; i < elemCnt; i++ { + args := make([]expression.Expression, 0, elemCnt) + for j := stkLen - elemCnt - 1; j < stkLen; j++ { rowFunc, _ := er.ctxStack[j].(*expression.ScalarFunction) args = append(args, rowFunc.GetArgs()[i]) } @@ -1551,16 +1551,18 @@ func (er *expressionRewriter) deriveCollationForIn(l int, lLen int, stkLen int, return coll } -func (er *expressionRewriter) castCollationForIn(l int, lLen int, stkLen int, coll []*expression.ExprCollation) { - for i := stkLen - lLen; i < stkLen; i++ { - if l == 1 && er.ctxStack[i].GetType().EvalType() == types.ETString { +// castCollationForIn casts collation info for arguments in the `in clause` to make sure the used collation is correct after we +// rewrite it to equal expression. 
+func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll []*expression.ExprCollation) { + for i := stkLen - elemCnt; i < stkLen; i++ { + if colLen == 1 && er.ctxStack[i].GetType().EvalType() == types.ETString { tp := er.ctxStack[i].GetType().Clone() tp.Charset, tp.Collate = coll[0].Charset, coll[0].Collation er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp) er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) } else { rowFunc, _ := er.ctxStack[i].(*expression.ScalarFunction) - for j := 0; j < lLen; j++ { + for j := 0; j < elemCnt; j++ { if er.ctxStack[i].GetType().EvalType() != types.ETString { continue } From a6ce5df8c22dc63131e2c2fd3708c2f6322c449c Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 21 Dec 2021 16:38:24 +0800 Subject: [PATCH 6/8] fix Signed-off-by: wjhuang2016 --- planner/core/expression_rewriter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index f0f26ba73c56d..2c9f55b80122c 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1534,7 +1534,7 @@ func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkL coll = append(coll, coll2) } else { // (a, b, c) in ((x1, x2, x3), (y1, y2, y3), (z1, z2, z3)) => coll[0], coll[1], coll[2] - for i := 0; i < elemCnt; i++ { + for i := 0; i < colLen; i++ { args := make([]expression.Expression, 0, elemCnt) for j := stkLen - elemCnt - 1; j < stkLen; j++ { rowFunc, _ := er.ctxStack[j].(*expression.ScalarFunction) From c1feda63f5c254c6727aca652dc4715281cafa41 Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 21 Dec 2021 16:57:28 +0800 Subject: [PATCH 7/8] fix Signed-off-by: wjhuang2016 --- planner/core/expression_rewriter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 
2c9f55b80122c..e1619785ec1a4 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1562,7 +1562,7 @@ func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit) } else { rowFunc, _ := er.ctxStack[i].(*expression.ScalarFunction) - for j := 0; j < elemCnt; j++ { + for j := 0; j < colLen; j++ { if er.ctxStack[i].GetType().EvalType() != types.ETString { continue } From 1d86a9c7451ccb1d82ee327c8779f97dfa7dc7aa Mon Sep 17 00:00:00 2001 From: wjhuang2016 Date: Tue, 21 Dec 2021 17:54:58 +0800 Subject: [PATCH 8/8] fix grammar Signed-off-by: wjhuang2016 --- planner/core/expression_rewriter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index e1619785ec1a4..08e0262613cb9 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1521,7 +1521,7 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field er.ctxStackAppend(function, types.EmptyName) } -// deriveCollationForIn derive collation for in expression. +// deriveCollationForIn derives collation for in expression. func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkLen int, args []expression.Expression) []*expression.ExprCollation { coll := make([]*expression.ExprCollation, 0, colLen) if colLen == 1 {