diff --git a/sql/vector.sql b/sql/vector.sql index fb5fbdc..071b444 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -292,4 +292,4 @@ CREATE OPERATOR CLASS vector_cosine_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), - FUNCTION 2 vector_norm(vector); + FUNCTION 3 normalize_l2(vector); diff --git a/src/hnsw.c b/src/hnsw.c index 6248aa8..8bfb5be 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -167,7 +167,7 @@ hnswhandler(PG_FUNCTION_ARGS) IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; - amroutine->amsupport = 2; + amroutine->amsupport = 3; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif diff --git a/src/hnsw.h b/src/hnsw.h index 0b550a3..1f69bac 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -19,6 +19,7 @@ /* Support functions */ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 +#define HNSW_NORMALIZE_PROC 3 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 @@ -147,6 +148,7 @@ typedef struct HnswBuildState /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; + FmgrInfo *normalizeprocinfo; Oid collation; /* Variables */ @@ -220,6 +222,7 @@ typedef struct HnswScanOpaqueData /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; + FmgrInfo *normalizeprocinfo; Oid collation; } HnswScanOpaqueData; @@ -255,7 +258,7 @@ typedef struct HnswVacuumState int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation rel, uint16 procnum); -bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); +bool HnswNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result); void HnswCommitBuffer(Buffer buf, GenericXLogState *state); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 2341cdb..5315c09 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -278,11 +278,8 @@ InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Normalize if needed */ - if (buildstate->normprocinfo != NULL) - { - if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec)) - return false; - } + if (!HnswNormValue(buildstate->normprocinfo, buildstate->normalizeprocinfo, collation, &value, buildstate->normvec)) + return false; /* Copy value to element so accessible outside of memory context */ memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions)); @@ -413,6 +410,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index /* Get support functions */ buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); buildstate->collation = index->rd_indcollation[0]; buildstate->elements = NIL; diff --git a/src/hnswinsert.c b/src/hnswinsert.c index 5dc4fb3..d4b89b3 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -417,6 +417,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti { Datum value; FmgrInfo *normprocinfo; + FmgrInfo *normalizeprocinfo; HnswElement entryPoint; HnswElement element; int m = HnswGetM(index); @@ -432,11 +433,9 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti /* Normalize if needed */ normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); - if (normprocinfo != NULL) - { - if (!HnswNormValue(normprocinfo, collation, &value, NULL)) - return false; - } + normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); + if (!HnswNormValue(normprocinfo, normalizeprocinfo, collation, &value, NULL)) + return false; /* Create an element */ element = HnswInitElement(heap_tid, m, ml, HnswGetMaxLevel(m)); diff --git a/src/hnswscan.c b/src/hnswscan.c index 365c6e3..2aa8d3c 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -78,6 +78,7 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + so->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC); so->collation = index->rd_indcollation[0]; scan->opaque = so; @@ -140,8 +141,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); /* Fine if normalization fails */ - if (so->normprocinfo != NULL) - HnswNormValue(so->normprocinfo, so->collation, &value, NULL); + HnswNormValue(so->normprocinfo, so->normalizeprocinfo, so->collation, &value, NULL); } GetScanItems(scan, value); diff --git a/src/hnswutils.c b/src/hnswutils.c index a7bccc8..36dc46d 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -55,9 +55,20 @@ HnswOptionalProcInfo(Relation rel, uint16 procnum) * if it's different than the original value */ bool -HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result) +HnswNormValue(FmgrInfo *procinfo, FmgrInfo *normalizeprocinfo, Oid collation, Datum *value, Vector * result) { - double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); + double norm; + + if (normalizeprocinfo != NULL) + { + *value = FunctionCall1Coll(normalizeprocinfo, collation, *value); + return true; + } + + if (procinfo == NULL) + return true; + + norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); if (norm > 0) { diff --git a/test/expected/hnsw_cosine.out b/test/expected/hnsw_cosine.out index df9eb81..01542bc 100644 --- a/test/expected/hnsw_cosine.out +++ b/test/expected/hnsw_cosine.out @@ -9,18 +9,19 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]'; [1,1,1] [1,2,3] [1,2,4] -(3 rows) + [0,0,0] +(4 rows) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; count ------- - 3 + 4 (1 row) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; count ------- - 3 + 4 (1 row) DROP TABLE t;