Skip to content

Commit

Permalink
*: tiny update
Browse files Browse the repository at this point in the history
  • Loading branch information
zimulala committed Sep 9, 2024
1 parent 0fce721 commit ffc8f35
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 8 deletions.
4 changes: 2 additions & 2 deletions pkg/ddl/generated_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipC
return inNode, true
}
if c.genType == typeVectorIndex {
_, ok := ast.IsVectorIndexDistanceMetricSupported[node.FnName.L]
if ok {
_, isFunc4Vec := variable.DistanceMetric4VectorIndex[node.FnName.L]
if isFunc4Vec {
c.hasFunc4VectorIdx = true
if len(node.Args) != 1 {
c.otherErr = expression.ErrIncorrectParameterCount.GenWithStackByArgs(node.FnName)
Expand Down
3 changes: 1 addition & 2 deletions pkg/ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ func buildVectorInfoWithCheck(indexPartSpecifications []*ast.IndexPartSpecificat
return nil, "", dbterror.ErrUnsupportedAddVectorIndex.FastGenByArgs(fmt.Sprintf("unsupported function: %v", idxPart.Expr))
}

_, ok = ast.IsVectorIndexDistanceMetricSupported[f.FnName.L]
distanceMetric, ok := variable.DistanceMetric4VectorIndex[f.FnName.L]
if !ok {
return nil, "", dbterror.ErrUnsupportedAddVectorIndex.FastGenByArgs("unsupported function")
}
Expand All @@ -416,7 +416,6 @@ func buildVectorInfoWithCheck(indexPartSpecifications []*ast.IndexPartSpecificat
return nil, "", infoschema.ErrColumnNotExists.GenWithStackByArgs(colExpr.Name.Name.String())
}

distanceMetric := model.DistanceMetric(f.FnName.L)
// check duplicated function on the same column
for _, idx := range tblInfo.Indices {
if idx.VectorInfo == nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/executor/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,8 @@ func constructResultOfShowCreateTable(ctx sessionctx.Context, dbName *model.CISt
cols = append(cols, colInfo)
}
if idxInfo.VectorInfo != nil {
fmt.Fprintf(buf, "((%s(%s)))", strings.ToUpper(string(idxInfo.VectorInfo.DistanceMetric)), strings.Join(cols, ","))
funcName := variable.Function4VectorIndex[idxInfo.VectorInfo.DistanceMetric]
fmt.Fprintf(buf, "((%s(%s)))", strings.ToUpper(funcName), strings.Join(cols, ","))
} else {
fmt.Fprintf(buf, "(%s)", strings.Join(cols, ","))
}
Expand Down
7 changes: 4 additions & 3 deletions pkg/parser/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -1449,10 +1449,11 @@ type DistanceMetric string

// Note: tipb.VectorDistanceMetric's enum names must be aligned with these constant values.
const (
// DistanceMetricL2 is L2 distance.
DistanceMetricL2 DistanceMetric = "vec_l2_distance"
DistanceMetricL2 DistanceMetric = "L2"
// DistanceMetricCosine is cosine distance.
DistanceMetricCosine DistanceMetric = "vec_cosine_distance"
DistanceMetricCosine DistanceMetric = "COSINE"
// DistanceMetricInnerProduct is inner product.
DistanceMetricInnerProduct DistanceMetric = "INNER_PRODUCT"
)

// VectorIndexInfo is the information of vector index of a column.
Expand Down
13 changes: 13 additions & 0 deletions pkg/sessionctx/variable/varsutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/collate"
Expand Down Expand Up @@ -632,3 +633,15 @@ func parseSchemaCacheSize(s *SessionVars, normalizedValue string, originalValue

return 0, "", ErrTruncatedWrongValue.GenWithStackByArgs(TiDBSchemaCacheSize, originalValue)
}

// DistanceMetric4VectorIndex stores distance metrics for the vector index.
var DistanceMetric4VectorIndex = map[string]model.DistanceMetric{
ast.VecCosineDistance: model.DistanceMetricCosine,
ast.VecL2Distance: model.DistanceMetricL2,
}

// Function4VectorIndex stores functions for the vector index.
var Function4VectorIndex = map[model.DistanceMetric]string{
model.DistanceMetricCosine: ast.VecCosineDistance,
model.DistanceMetricL2: ast.VecL2Distance,
}

0 comments on commit ffc8f35

Please sign in to comment.