From 5215c289231f7a87e859548231099ec40d2be84b Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 15 Apr 2024 15:32:08 -0700 Subject: [PATCH] Moved norm check to separate function --- src/hnsw.h | 3 ++- src/hnswbuild.c | 4 +++- src/hnswinsert.c | 4 +++- src/hnswscan.c | 2 +- src/hnswutils.c | 45 ++++++++++++++++++++------------------------- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/hnsw.h b/src/hnsw.h index 772b228..c1a7f12 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -377,7 +377,8 @@ int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); HnswType HnswGetType(Relation index); -bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type); +Datum HnswNormValue(Datum value, HnswType type); +bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value); void HnswCheckValue(Datum value, HnswType type); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 2300127..bd32e06 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -493,8 +493,10 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { - if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type)) + if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value)) return false; + + value = HnswNormValue(value, buildstate->type); } /* Get datum size */ diff --git a/src/hnswinsert.c b/src/hnswinsert.c index c5ea1fd..88fc6c3 100644 --- a/src/hnswinsert.c +++ b/src/hnswinsert.c @@ -626,8 +626,10 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); if (normprocinfo != NULL) { - if (!HnswNormValue(normprocinfo, collation, &value, type)) + if (!HnswCheckNorm(normprocinfo, collation, value)) return; + + value = HnswNormValue(value, type); } HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false); diff --git a/src/hnswscan.c b/src/hnswscan.c index 3bd0b8f..ad00852 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan) /* Fine if normalization fails */ if (so->normprocinfo != NULL) - HnswNormValue(so->normprocinfo, so->collation, &value, HnswGetType(scan->indexRelation)); + value = HnswNormValue(value, HnswGetType(scan->indexRelation)); } return value; diff --git a/src/hnswutils.c b/src/hnswutils.c index bd6dd5d..1e0f238 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -193,34 +193,29 @@ HnswGetType(Relation index) } /* - * Divide by the norm - * - * Returns false if value should not be indexed - * - * The caller needs to free the pointer stored in value - * if it's different than the original value + * Normalize value + */ +Datum +HnswNormValue(Datum value, HnswType type) +{ + /* TODO Remove type-specific code */ + if (type == HNSW_TYPE_VECTOR) + return DirectFunctionCall1(l2_normalize, value); + else if (type == HNSW_TYPE_HALFVEC) + return DirectFunctionCall1(halfvec_l2_normalize, value); + else if (type == HNSW_TYPE_SPARSEVEC) + return DirectFunctionCall1(sparsevec_l2_normalize, value); + else + elog(ERROR, "Unsupported type"); +} + +/* + * Check if non-zero norm */ bool -HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type) +HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value) { - double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); - - if (norm > 0) - { - /* TODO Remove type-specific code */ - if (type == HNSW_TYPE_VECTOR) - *value = DirectFunctionCall1(l2_normalize, *value); - else if (type == HNSW_TYPE_HALFVEC) - *value = DirectFunctionCall1(halfvec_l2_normalize, *value); - else if (type == HNSW_TYPE_SPARSEVEC) - *value = DirectFunctionCall1(sparsevec_l2_normalize, *value); - else - elog(ERROR, "Unsupported type"); - - return true; - } - - return false; + return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0; } /*