Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix wrong collation when rewrite in condition #30492

Merged
merged 13 commits into from
Dec 21, 2021
10 changes: 10 additions & 0 deletions expression/integration_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ func TestCollationBasic(t *testing.T) {
tk.MustQuery("select * from t1 where col1 >= 0xc484 and col1 <= 0xc3b3;").Check(testkit.Rows("Ȇ"))

tk.MustQuery("select collation(IF('a' < 'B' collate utf8mb4_general_ci, 'smaller', 'greater' collate utf8mb4_unicode_ci));").Check(testkit.Rows("utf8mb4_unicode_ci"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10))")
tk.MustExec("insert into t values ('a')")
tk.MustQuery("select * from t where a in ('b' collate utf8mb4_general_ci, 'A', 3)").Check(testkit.Rows("a"))
// These test cases may not behave the same as MySQL, but the results are more reasonable.
tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci));").Check(testkit.Rows("1"))
tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin));").Check(testkit.Rows("0"))
tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci), ('b', 'b'));").Check(testkit.Rows("1"))
tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin), ('b', 'b'));").Check(testkit.Rows("0"))
}

func TestWeightString(t *testing.T) {
Expand Down
60 changes: 60 additions & 0 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1492,6 +1492,12 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field
if allSameType && l == 1 && lLen > 1 {
function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...)
} else {
// If we rewrite IN to EQ, we need to decide what's the collation EQ uses.
coll := er.deriveCollationForIn(l, lLen, stkLen, args)
if er.err != nil {
return
}
er.castCollationForIn(l, lLen, stkLen, coll)
eqFunctions := make([]expression.Expression, 0, lLen)
for i := stkLen - lLen; i < stkLen; i++ {
expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ)
Expand All @@ -1515,6 +1521,60 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field
er.ctxStackAppend(function, types.EmptyName)
}

// deriveCollationForIn computes one collation per column position of an IN expression
// by calling expression.CheckAndDeriveCollationFromExprs on the operands of that column.
//
//   - colLen is the number of columns on each side (1 for `a in (x, y, z)`).
//   - elemCnt is the number of right-hand elements.
//   - stkLen is the stack length; the row form reads er.ctxStack[stkLen-elemCnt-1 : stkLen],
//     i.e. the left-hand row followed by the right-hand rows.
//   - args is the operand list used directly in the single-column form.
//
// The outcome of every derivation is stored in er.err. On failure the function
// returns nil; otherwise it returns a slice with colLen entries.
func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkLen int, args []expression.Expression) []*expression.ExprCollation {
	derived := make([]*expression.ExprCollation, 0, colLen)

	// Single-column form: a in (x, y, z) => derived[0].
	if colLen == 1 {
		ec, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...)
		er.err = err
		if err != nil {
			return nil
		}
		return append(derived, ec)
	}

	// Row form: (a, b, c) in ((x1, x2, x3), (y1, y2, y3), (z1, z2, z3))
	// => derived[0], derived[1], derived[2], one per column.
	firstRow := stkLen - elemCnt - 1
	for col := 0; col < colLen; col++ {
		// Gather the col-th operand of every row, the left-hand row included.
		column := make([]expression.Expression, 0, elemCnt)
		for pos := firstRow; pos < stkLen; pos++ {
			// NOTE(review): the assertion result is not checked; this presumably relies on
			// the caller only taking this path when every entry is a row function — confirm.
			row, _ := er.ctxStack[pos].(*expression.ScalarFunction)
			column = append(column, row.GetArgs()[col])
		}

		ec, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, column...)
		er.err = err
		if err != nil {
			return nil
		}
		derived = append(derived, ec)
	}
	return derived
}

// castCollationForIn wraps the string operands on the right-hand side of an IN clause in a
// cast that carries the collation derived for their column, so that the collation stays
// correct after the IN clause is rewritten into equality expressions.
//
//   - colLen is the number of columns per element (1 for `a in (x, y, z)`).
//   - elemCnt is the number of right-hand elements; they are read from
//     er.ctxStack[stkLen-elemCnt : stkLen].
//   - stkLen is the stack length marking the end of those elements.
//   - coll holds the derived collations: coll[0] in the single-column form, coll[j] for
//     column j in the row form.
//
// Only operands whose own eval type is types.ETString are cast; every cast is marked with
// expression.CoercibilityExplicit. Non-string operands are left untouched.
func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll []*expression.ExprCollation) {
	for i := stkLen - elemCnt; i < stkLen; i++ {
		if colLen == 1 {
			// Single-column form: the stack entry itself is the operand.
			// A non-string operand (e.g. the 3 in `a in ('b', 'A', 3)`) needs no cast and
			// must not be treated as a row.
			if er.ctxStack[i].GetType().EvalType() != types.ETString {
				continue
			}
			tp := er.ctxStack[i].GetType().Clone()
			tp.Charset, tp.Collate = coll[0].Charset, coll[0].Collation
			er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp)
			er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit)
			continue
		}

		// Row form: the stack entry is expected to be a row function whose arguments are the
		// per-column operands. Skip anything else instead of dereferencing a nil function.
		rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction)
		if !ok {
			continue
		}
		// The arguments are replaced in place through the slice returned by GetArgs,
		// exactly as the per-element writes did before.
		rowArgs := rowFunc.GetArgs()
		for j := 0; j < colLen; j++ {
			// Decide per column: test the type of operand j itself, not of the enclosing
			// row, so that only string columns receive a collation cast.
			if rowArgs[j].GetType().EvalType() != types.ETString {
				continue
			}
			tp := rowArgs[j].GetType().Clone()
			tp.Charset, tp.Collate = coll[j].Charset, coll[j].Collation
			rowArgs[j] = expression.BuildCastFunction(er.sctx, rowArgs[j], tp)
			rowArgs[j].SetCoercibility(expression.CoercibilityExplicit)
		}
	}
}

func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) {
stkLen := len(er.ctxStack)
argsLen := 2 * len(v.WhenClauses)
Expand Down