From a5bb59d9f6aa5c06a06a699169fdf22f36eb8808 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 16 Oct 2023 17:56:50 -0700 Subject: [PATCH] Use normalize_l2 for ivfflat --- sql/vector.sql | 6 +++--- src/ivfbuild.c | 14 ++++---------- src/ivfflat.c | 2 +- src/ivfflat.h | 7 ++++++- src/ivfinsert.c | 8 +++----- src/ivfscan.c | 4 ++-- src/ivfutils.c | 25 +++++++++++++++---------- test/expected/ivfflat_cosine.out | 7 ++++--- 8 files changed, 38 insertions(+), 35 deletions(-) diff --git a/sql/vector.sql b/sql/vector.sql index 071b444..e179df0 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -268,15 +268,15 @@ CREATE OPERATOR CLASS vector_ip_ops OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 3 vector_spherical_distance(vector, vector), - FUNCTION 4 vector_norm(vector); + FUNCTION 6 normalize_l2(vector); CREATE OPERATOR CLASS vector_cosine_ops FOR TYPE vector USING ivfflat AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), - FUNCTION 2 vector_norm(vector), FUNCTION 3 vector_spherical_distance(vector, vector), - FUNCTION 4 vector_norm(vector); + FUNCTION 5 normalize_l2(vector), + FUNCTION 6 normalize_l2(vector); CREATE OPERATOR CLASS vector_l2_ops FOR TYPE vector USING hnsw AS diff --git a/src/ivfbuild.c b/src/ivfbuild.c index cc4f7a3..022d611 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -75,11 +75,7 @@ AddSample(Datum *values, IvfflatBuildState * buildstate) * Normalize with KMEANS_NORM_PROC since spherical distance function * expects unit vectors */ - if (buildstate->kmeansnormprocinfo != NULL) - { - if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec)) - return; - } + IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->kmeansnormalizeprocinfo, buildstate->collation, &value, buildstate->normvec); if (samples->length < targsamples) { @@ -176,11 +172,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Normalize if needed */ - if (buildstate->normprocinfo != NULL) - { - if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec)) - return; - } + IvfflatNormValue(buildstate->normprocinfo, buildstate->normalizeprocinfo, buildstate->collation, &value, buildstate->normvec); /* Find the list that minimizes the distance */ for (int i = 0; i < centers->length; i++) @@ -368,6 +360,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); + buildstate->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); + buildstate->kmeansnormalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORMALIZE_PROC); buildstate->collation = index->rd_indcollation[0]; /* Require more than one dimension for spherical k-means */ diff --git a/src/ivfflat.c b/src/ivfflat.c index 3753a74..706f716 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -194,7 +194,7 @@ ivfflathandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 4; + amroutine->amsupport = 6; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/ivfflat.h b/src/ivfflat.h index 2c18fd4..1b24c8b 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -31,6 +31,8 @@ #define IVFFLAT_NORM_PROC 2 #define IVFFLAT_KMEANS_DISTANCE_PROC 3 #define IVFFLAT_KMEANS_NORM_PROC 4 +#define IVFFLAT_NORMALIZE_PROC 5 +#define IVFFLAT_KMEANS_NORMALIZE_PROC 6 #define IVFFLAT_VERSION 1 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7 @@ -172,6 +174,8 @@ typedef struct IvfflatBuildState FmgrInfo *procinfo; FmgrInfo *normprocinfo; FmgrInfo *kmeansnormprocinfo; + FmgrInfo *normalizeprocinfo; + FmgrInfo *kmeansnormalizeprocinfo; Oid collation; /* Variables */ @@ -253,6 +257,7 @@ typedef struct IvfflatScanOpaqueData /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; + FmgrInfo *normalizeprocinfo; Oid collation; /* Lists */ @@ -273,7 +278,7 @@ void VectorArrayFree(VectorArray arr); void PrintVectorArray(char *msg, VectorArray arr); void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum); -bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); +void IvfflatNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result); int IvfflatGetLists(Relation index); void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum); void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state); diff --git a/src/ivfinsert.c b/src/ivfinsert.c index cc744af..2306d32 100644 --- a/src/ivfinsert.c +++ b/src/ivfinsert.c @@ -68,6 +68,7 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel IndexTuple itup; Datum value; FmgrInfo *normprocinfo; + FmgrInfo *normalizeprocinfo; Buffer buf; Page page; GenericXLogState *state; @@ -81,11 +82,8 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel /* Normalize if needed */ normprocinfo = IvfflatOptionalProcInfo(rel, IVFFLAT_NORM_PROC); - if (normprocinfo != NULL) - { - if (!IvfflatNormValue(normprocinfo, rel->rd_indcollation[0], &value, NULL)) - return; - } + normalizeprocinfo = IvfflatOptionalProcInfo(rel, IVFFLAT_NORMALIZE_PROC); + IvfflatNormValue(normprocinfo, normalizeprocinfo, rel->rd_indcollation[0], &value, NULL); /* Find the insert page - sets the page and list info */ FindInsertPage(rel, values, &insertPage, &listInfo); diff --git a/src/ivfscan.c b/src/ivfscan.c index 6703120..c4591f3 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -232,6 +232,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); + so->normalizeprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORMALIZE_PROC); so->collation = index->rd_indcollation[0]; /* Create tuple description for sorting */ @@ -319,8 +320,7 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); /* Fine if normalization fails */ - if (so->normprocinfo != NULL) - IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL); + IvfflatNormValue(so->normprocinfo, so->normalizeprocinfo, so->collation, &value, NULL); } IvfflatBench("GetScanLists", GetScanLists(scan, value)); diff --git a/src/ivfutils.c b/src/ivfutils.c index a37f51f..8d1983c 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -66,17 +66,26 @@ IvfflatOptionalProcInfo(Relation rel, uint16 procnum) } /* - * Divide by the norm - * - * Returns false if value should not be indexed + * Normalize a vector * * The caller needs to free the pointer stored in value * if it's different than the original value */ -bool -IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result) +void +IvfflatNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result) { - double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); + double norm; + + if (normalizeprocinfo != NULL) + { + *value = FunctionCall1Coll(normalizeprocinfo, collation, *value); + return; + } + + if (procinfo == NULL) + return; + + norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); if (norm > 0) { @@ -89,11 +98,7 @@ IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * resul result->x[i] = v->x[i] / norm; *value = PointerGetDatum(result); - - return true; } - - return false; } /* diff --git a/test/expected/ivfflat_cosine.out b/test/expected/ivfflat_cosine.out index 8584d95..d647057 100644 --- a/test/expected/ivfflat_cosine.out +++ b/test/expected/ivfflat_cosine.out @@ -9,18 +9,19 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]'; [1,1,1] [1,2,3] [1,2,4] -(3 rows) + [0,0,0] +(4 rows) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; count ------- - 3 + 4 (1 row) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; count ------- - 3 + 4 (1 row) DROP TABLE t;