Use normalize_l2 for ivfflat

This commit is contained in:
Andrew Kane
2023-10-16 17:56:50 -07:00
parent dd609f200b
commit a5bb59d9f6
8 changed files with 38 additions and 35 deletions

View File

@@ -268,15 +268,15 @@ CREATE OPERATOR CLASS vector_ip_ops
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 3 vector_spherical_distance(vector, vector),
FUNCTION 4 vector_norm(vector);
FUNCTION 6 normalize_l2(vector);
CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING ivfflat AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector),
FUNCTION 3 vector_spherical_distance(vector, vector),
FUNCTION 4 vector_norm(vector);
FUNCTION 5 normalize_l2(vector),
FUNCTION 6 normalize_l2(vector);
CREATE OPERATOR CLASS vector_l2_ops
FOR TYPE vector USING hnsw AS

View File

@@ -75,11 +75,7 @@ AddSample(Datum *values, IvfflatBuildState * buildstate)
* Normalize with KMEANS_NORM_PROC since the spherical distance function
* expects unit vectors
*/
if (buildstate->kmeansnormprocinfo != NULL)
{
if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec))
return;
}
IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->kmeansnormalizeprocinfo, buildstate->collation, &value, buildstate->normvec);
if (samples->length < targsamples)
{
@@ -176,11 +172,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));
/* Normalize if needed */
if (buildstate->normprocinfo != NULL)
{
if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec))
return;
}
IvfflatNormValue(buildstate->normprocinfo, buildstate->normalizeprocinfo, buildstate->collation, &value, buildstate->normvec);
/* Find the list that minimizes the distance */
for (int i = 0; i < centers->length; i++)
@@ -368,6 +360,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC);
buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
buildstate->kmeansnormalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORMALIZE_PROC);
buildstate->collation = index->rd_indcollation[0];
/* Require more than one dimension for spherical k-means */

View File

@@ -194,7 +194,7 @@ ivfflathandler(PG_FUNCTION_ARGS)
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0;
amroutine->amsupport = 4;
amroutine->amsupport = 6;
#if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0;
#endif

View File

@@ -31,6 +31,8 @@
#define IVFFLAT_NORM_PROC 2
#define IVFFLAT_KMEANS_DISTANCE_PROC 3
#define IVFFLAT_KMEANS_NORM_PROC 4
#define IVFFLAT_NORMALIZE_PROC 5
#define IVFFLAT_KMEANS_NORMALIZE_PROC 6
#define IVFFLAT_VERSION 1
#define IVFFLAT_MAGIC_NUMBER 0x14FF1A7
@@ -172,6 +174,8 @@ typedef struct IvfflatBuildState
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
FmgrInfo *kmeansnormprocinfo;
FmgrInfo *normalizeprocinfo;
FmgrInfo *kmeansnormalizeprocinfo;
Oid collation;
/* Variables */
@@ -253,6 +257,7 @@ typedef struct IvfflatScanOpaqueData
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
Oid collation;
/* Lists */
@@ -273,7 +278,7 @@ void VectorArrayFree(VectorArray arr);
void PrintVectorArray(char *msg, VectorArray arr);
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);
bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result);
void IvfflatNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result);
int IvfflatGetLists(Relation index);
void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum);
void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state);

View File

@@ -68,6 +68,7 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel
IndexTuple itup;
Datum value;
FmgrInfo *normprocinfo;
FmgrInfo *normalizeprocinfo;
Buffer buf;
Page page;
GenericXLogState *state;
@@ -81,11 +82,8 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel
/* Normalize if needed */
normprocinfo = IvfflatOptionalProcInfo(rel, IVFFLAT_NORM_PROC);
if (normprocinfo != NULL)
{
if (!IvfflatNormValue(normprocinfo, rel->rd_indcollation[0], &value, NULL))
return;
}
normalizeprocinfo = IvfflatOptionalProcInfo(rel, IVFFLAT_NORMALIZE_PROC);
IvfflatNormValue(normprocinfo, normalizeprocinfo, rel->rd_indcollation[0], &value, NULL);
/* Find the insert page - sets the page and list info */
FindInsertPage(rel, values, &insertPage, &listInfo);

View File

@@ -232,6 +232,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
/* Set support functions */
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC);
so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC);
so->collation = index->rd_indcollation[0];
/* Create tuple description for sorting */
@@ -319,8 +320,7 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));
/* Fine if normalization fails */
if (so->normprocinfo != NULL)
IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL);
IvfflatNormValue(so->normprocinfo, so->normalizeprocinfo, so->collation, &value, NULL);
}
IvfflatBench("GetScanLists", GetScanLists(scan, value));

View File

@@ -66,17 +66,26 @@ IvfflatOptionalProcInfo(Relation rel, uint16 procnum)
}
/*
* Divide by the norm
*
* Returns false if value should not be indexed
* Normalize a vector
*
* The caller needs to free the pointer stored in value
* if it differs from the original value
*/
bool
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
void
IvfflatNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result)
{
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
double norm;
if (normalizeprocinfo != NULL)
{
*value = FunctionCall1Coll(normalizeprocinfo, collation, *value);
return;
}
if (procinfo == NULL)
return;
norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
if (norm > 0)
{
@@ -89,11 +98,7 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * resul
result->x[i] = v->x[i] / norm;
*value = PointerGetDatum(result);
return true;
}
return false;
}
/*

View File

@@ -9,18 +9,19 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]';
[1,1,1]
[1,2,3]
[1,2,4]
(3 rows)
[0,0,0]
(4 rows)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2;
count
-------
3
4
(1 row)
SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2;
count
-------
3
4
(1 row)
DROP TABLE t;