mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Added support function for l2_normalize to ivfflat
This commit is contained in:
@@ -318,7 +318,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
|
||||
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
|
||||
FUNCTION 4 l2_norm(halfvec);
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -326,7 +327,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
|
||||
FUNCTION 2 l2_norm(halfvec),
|
||||
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
|
||||
FUNCTION 4 l2_norm(halfvec);
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
|
||||
@@ -627,7 +627,8 @@ CREATE OPERATOR CLASS halfvec_ip_ops
|
||||
OPERATOR 1 <#> (halfvec, halfvec) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
|
||||
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
|
||||
FUNCTION 4 l2_norm(halfvec);
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FOR TYPE halfvec USING ivfflat AS
|
||||
@@ -635,7 +636,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
|
||||
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
|
||||
FUNCTION 2 l2_norm(halfvec),
|
||||
FUNCTION 3 halfvec_spherical_distance(halfvec, halfvec),
|
||||
FUNCTION 4 l2_norm(halfvec);
|
||||
FUNCTION 4 l2_norm(halfvec),
|
||||
FUNCTION 5 l2_normalize(halfvec);
|
||||
|
||||
CREATE OPERATOR CLASS halfvec_l2_ops
|
||||
FOR TYPE halfvec USING hnsw AS
|
||||
|
||||
@@ -60,8 +60,10 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
|
||||
*/
|
||||
if (buildstate->kmeansnormprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->type))
|
||||
if (!IvfflatCheckNorm(buildstate->kmeansnormprocinfo, buildstate->collation, value))
|
||||
return;
|
||||
|
||||
value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value);
|
||||
}
|
||||
|
||||
if (samples->length < targsamples)
|
||||
@@ -156,8 +158,10 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState
|
||||
/* Normalize if needed */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type))
|
||||
if (!IvfflatCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
|
||||
return;
|
||||
|
||||
value = IvfflatNormValue(buildstate->normalizeprocinfo, buildstate->collation, value);
|
||||
}
|
||||
|
||||
/* Find the list that minimizes the distance */
|
||||
@@ -379,6 +383,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
|
||||
buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
|
||||
buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
buildstate->collation = index->rd_indcollation[0];
|
||||
|
||||
/* Require more than one dimension for spherical k-means */
|
||||
|
||||
@@ -188,7 +188,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
|
||||
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
|
||||
|
||||
amroutine->amstrategies = 0;
|
||||
amroutine->amsupport = 4;
|
||||
amroutine->amsupport = 5;
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
amroutine->amoptsprocnum = 0;
|
||||
#endif
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
#define IVFFLAT_NORM_PROC 2
|
||||
#define IVFFLAT_KMEANS_DISTANCE_PROC 3
|
||||
#define IVFFLAT_KMEANS_NORM_PROC 4
|
||||
#define IVFFLAT_NORMALIZE_PROC 5
|
||||
|
||||
#define IVFFLAT_VERSION 1
|
||||
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
|
||||
@@ -175,6 +176,7 @@ typedef struct IvfflatBuildState
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *kmeansnormprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
|
||||
/* Variables */
|
||||
@@ -255,6 +257,7 @@ typedef struct IvfflatScanOpaqueData
|
||||
/* Support functions */
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
Datum (*distfunc) (FmgrInfo *flinfo, Oid collation, Datum arg1, Datum arg2);
|
||||
|
||||
@@ -276,7 +279,8 @@ void VectorArrayFree(VectorArray arr);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers, IvfflatType type);
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum);
|
||||
IvfflatType IvfflatGetType(Relation index);
|
||||
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType type);
|
||||
Datum IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
bool IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
int IvfflatGetLists(Relation index);
|
||||
void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions);
|
||||
void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum);
|
||||
|
||||
@@ -85,8 +85,12 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, R
|
||||
normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, IvfflatGetType(index)))
|
||||
Oid collation = index->rd_indcollation[0];
|
||||
|
||||
if (!IvfflatCheckNorm(normprocinfo, collation, value))
|
||||
return;
|
||||
|
||||
value = IvfflatNormValue(IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC), collation, value);
|
||||
}
|
||||
|
||||
/* Find the insert page - sets the page and list info */
|
||||
|
||||
@@ -209,9 +209,9 @@ GetScanValue(IndexScanDesc scan)
|
||||
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
|
||||
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
|
||||
|
||||
/* Fine if normalization fails */
|
||||
/* Check normprocinfo since normalizeprocinfo not set for vector */
|
||||
if (so->normprocinfo != NULL)
|
||||
IvfflatNormValue(so->normprocinfo, so->collation, &value, IvfflatGetType(scan->indexRelation));
|
||||
value = IvfflatNormValue(so->normalizeprocinfo, so->collation, value);
|
||||
}
|
||||
|
||||
return value;
|
||||
@@ -249,6 +249,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
|
||||
/* Set support functions */
|
||||
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
|
||||
so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
so->collation = index->rd_indcollation[0];
|
||||
|
||||
/* Create tuple description for sorting */
|
||||
|
||||
@@ -99,31 +99,24 @@ IvfflatGetType(Relation index)
|
||||
}
|
||||
|
||||
/*
|
||||
* Divide by the norm
|
||||
*
|
||||
* Returns false if value should not be indexed
|
||||
*
|
||||
* The caller needs to free the pointer stored in value
|
||||
* if it's different than the original value
|
||||
* Normalize value
|
||||
*/
|
||||
bool
|
||||
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, IvfflatType type)
|
||||
Datum
|
||||
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum value)
|
||||
{
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
if (procinfo == NULL)
|
||||
return DirectFunctionCall1(l2_normalize, value);
|
||||
|
||||
if (norm > 0)
|
||||
{
|
||||
if (type == IVFFLAT_TYPE_VECTOR)
|
||||
*value = DirectFunctionCall1(l2_normalize, *value);
|
||||
else if (type == IVFFLAT_TYPE_HALFVEC)
|
||||
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
return true;
|
||||
return FunctionCall1Coll(procinfo, collation, value);
|
||||
}
|
||||
|
||||
return false;
|
||||
/*
|
||||
* Check if the norm of value is greater than zero
*
* Invokes the support function in procinfo on value, using the given
* collation, and reads the result as a float8. Returns true only when
* that result is > 0, and false otherwise.
*
* procinfo is passed directly to FunctionCall1Coll; no NULL check is
* done here.
|
||||
*/
|
||||
bool
|
||||
IvfflatCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
|
||||
{
|
||||
return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0;
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
Reference in New Issue
Block a user