mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Use normalize_l2 for ivfflat
This commit is contained in:
@@ -268,15 +268,15 @@ CREATE OPERATOR CLASS vector_ip_ops
|
||||
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 vector_negative_inner_product(vector, vector),
|
||||
FUNCTION 3 vector_spherical_distance(vector, vector),
|
||||
FUNCTION 4 vector_norm(vector);
|
||||
FUNCTION 6 normalize_l2(vector);
|
||||
|
||||
CREATE OPERATOR CLASS vector_cosine_ops
|
||||
FOR TYPE vector USING ivfflat AS
|
||||
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
|
||||
FUNCTION 1 vector_negative_inner_product(vector, vector),
|
||||
FUNCTION 2 vector_norm(vector),
|
||||
FUNCTION 3 vector_spherical_distance(vector, vector),
|
||||
FUNCTION 4 vector_norm(vector);
|
||||
FUNCTION 5 normalize_l2(vector),
|
||||
FUNCTION 6 normalize_l2(vector);
|
||||
|
||||
CREATE OPERATOR CLASS vector_l2_ops
|
||||
FOR TYPE vector USING hnsw AS
|
||||
|
||||
@@ -75,11 +75,7 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
|
||||
* Normalize with KMEANS_NORM_PROC since spherical distance function
|
||||
* expects unit vectors
|
||||
*/
|
||||
if (buildstate->kmeansnormprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec))
|
||||
return;
|
||||
}
|
||||
IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->kmeansnormalizeprocinfo, buildstate->collation, &value, buildstate->normvec);
|
||||
|
||||
if (samples->length < targsamples)
|
||||
{
|
||||
@@ -176,11 +172,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState
|
||||
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
|
||||
|
||||
/* Normalize if needed */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec))
|
||||
return;
|
||||
}
|
||||
IvfflatNormValue(buildstate->normprocinfo, buildstate->normalizeprocinfo, buildstate->collation, &value, buildstate->normvec);
|
||||
|
||||
/* Find the list that minimizes the distance */
|
||||
for (int i = 0; i < centers->length; i++)
|
||||
@@ -368,6 +360,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
|
||||
buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
|
||||
buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
|
||||
buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
buildstate->kmeansnormalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORMALIZE_PROC);
|
||||
buildstate->collation = index->rd_indcollation[0];
|
||||
|
||||
/* Require more than one dimension for spherical k-means */
|
||||
|
||||
@@ -194,7 +194,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
|
||||
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
|
||||
|
||||
amroutine->amstrategies = 0;
|
||||
amroutine->amsupport = 4;
|
||||
amroutine->amsupport = 6;
|
||||
#if PG_VERSION_NUM >= 130000
|
||||
amroutine->amoptsprocnum = 0;
|
||||
#endif
|
||||
|
||||
@@ -31,6 +31,8 @@
|
||||
#define IVFFLAT_NORM_PROC 2
|
||||
#define IVFFLAT_KMEANS_DISTANCE_PROC 3
|
||||
#define IVFFLAT_KMEANS_NORM_PROC 4
|
||||
#define IVFFLAT_NORMALIZE_PROC 5
|
||||
#define IVFFLAT_KMEANS_NORMALIZE_PROC 6
|
||||
|
||||
#define IVFFLAT_VERSION 1
|
||||
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
|
||||
@@ -172,6 +174,8 @@ typedef struct IvfflatBuildState
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *kmeansnormprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
FmgrInfo *kmeansnormalizeprocinfo;
|
||||
Oid collation;
|
||||
|
||||
/* Variables */
|
||||
@@ -253,6 +257,7 @@ typedef struct IvfflatScanOpaqueData
|
||||
/* Support functions */
|
||||
FmgrInfo *procinfo;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Oid collation;
|
||||
|
||||
/* Lists */
|
||||
@@ -273,7 +278,7 @@ void VectorArrayFree(VectorArray arr);
|
||||
void PrintVectorArray(char *msg, VectorArray arr);
|
||||
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
|
||||
FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);
|
||||
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
|
||||
void IvfflatNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result);
|
||||
int IvfflatGetLists(Relation index);
|
||||
void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum);
|
||||
void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state);
|
||||
|
||||
@@ -68,6 +68,7 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel
|
||||
IndexTuple itup;
|
||||
Datum value;
|
||||
FmgrInfo *normprocinfo;
|
||||
FmgrInfo *normalizeprocinfo;
|
||||
Buffer buf;
|
||||
Page page;
|
||||
GenericXLogState *state;
|
||||
@@ -81,11 +82,8 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel
|
||||
|
||||
/* Normalize if needed */
|
||||
normprocinfo = IvfflatOptionalProcInfo(rel, IVFFLAT_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
if (!IvfflatNormValue(normprocinfo, rel->rd_indcollation[0], &value, NULL))
|
||||
return;
|
||||
}
|
||||
normalizeprocinfo = IvfflatOptionalProcInfo(rel, IVFFLAT_NORMALIZE_PROC);
|
||||
IvfflatNormValue(normprocinfo, normalizeprocinfo, rel->rd_indcollation[0], &value, NULL);
|
||||
|
||||
/* Find the insert page - sets the page and list info */
|
||||
FindInsertPage(rel, values, &insertPage, &listInfo);
|
||||
|
||||
@@ -232,6 +232,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
|
||||
/* Set support functions */
|
||||
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
|
||||
so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
|
||||
so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
|
||||
so->collation = index->rd_indcollation[0];
|
||||
|
||||
/* Create tuple description for sorting */
|
||||
@@ -319,8 +320,7 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
|
||||
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
|
||||
|
||||
/* Fine if normalization fails */
|
||||
if (so->normprocinfo != NULL)
|
||||
IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL);
|
||||
IvfflatNormValue(so->normprocinfo, so->normalizeprocinfo, so->collation, &value, NULL);
|
||||
}
|
||||
|
||||
IvfflatBench("GetScanLists", GetScanLists(scan, value));
|
||||
|
||||
@@ -66,17 +66,26 @@ IvfflatOptionalProcInfo(Relation rel, uint16 procnum)
|
||||
}
|
||||
|
||||
/*
|
||||
* Divide by the norm
|
||||
*
|
||||
* Returns false if value should not be indexed
|
||||
* Normalize a vector
|
||||
*
|
||||
* The caller needs to free the pointer stored in value
|
||||
* if it's different than the original value
|
||||
*/
|
||||
bool
|
||||
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
|
||||
void
|
||||
IvfflatNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result)
|
||||
{
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
double norm;
|
||||
|
||||
if (normalizeprocinfo != NULL)
|
||||
{
|
||||
*value = FunctionCall1Coll(normalizeprocinfo, collation, *value);
|
||||
return;
|
||||
}
|
||||
|
||||
if (procinfo == NULL)
|
||||
return;
|
||||
|
||||
norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
|
||||
if (norm > 0)
|
||||
{
|
||||
@@ -89,11 +98,7 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * resul
|
||||
result->x[i] = v->x[i] / norm;
|
||||
|
||||
*value = PointerGetDatum(result);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -9,18 +9,19 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
|
||||
[1,1,1]
|
||||
[1,2,3]
|
||||
[1,2,4]
|
||||
(3 rows)
|
||||
[0,0,0]
|
||||
(4 rows)
|
||||
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
|
||||
count
|
||||
-------
|
||||
3
|
||||
4
|
||||
(1 row)
|
||||
|
||||
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2;
|
||||
count
|
||||
-------
|
||||
3
|
||||
4
|
||||
(1 row)
|
||||
|
||||
DROP TABLE t;
|
||||
|
||||
Reference in New Issue
Block a user