Added support function for l2_normalize [skip ci]

This commit is contained in:
Andrew Kane
2024-04-22 18:36:47 -07:00
parent 2b77005610
commit f14c21748b
8 changed files with 22 additions and 18 deletions

View File

@@ -342,7 +342,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
FOR TYPE halfvec USING hnsw AS
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
-FUNCTION 2 l2_norm(halfvec);
+FUNCTION 2 l2_norm(halfvec),
+FUNCTION 3 l2_normalize(halfvec);
CREATE OPERATOR CLASS halfvec_l1_ops
FOR TYPE halfvec USING hnsw AS
@@ -529,7 +530,8 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops
FOR TYPE sparsevec USING hnsw AS
OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops,
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
-FUNCTION 2 l2_norm(sparsevec);
+FUNCTION 2 l2_norm(sparsevec),
+FUNCTION 3 l2_normalize(sparsevec);
CREATE OPERATOR CLASS sparsevec_l1_ops
FOR TYPE sparsevec USING hnsw AS

View File

@@ -651,7 +651,8 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
FOR TYPE halfvec USING hnsw AS
OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops,
FUNCTION 1 halfvec_negative_inner_product(halfvec, halfvec),
-FUNCTION 2 l2_norm(halfvec);
+FUNCTION 2 l2_norm(halfvec),
+FUNCTION 3 l2_normalize(halfvec);
CREATE OPERATOR CLASS halfvec_l1_ops
FOR TYPE halfvec USING hnsw AS
@@ -852,7 +853,8 @@ CREATE OPERATOR CLASS sparsevec_cosine_ops
FOR TYPE sparsevec USING hnsw AS
OPERATOR 1 <=> (sparsevec, sparsevec) FOR ORDER BY float_ops,
FUNCTION 1 sparsevec_negative_inner_product(sparsevec, sparsevec),
-FUNCTION 2 l2_norm(sparsevec);
+FUNCTION 2 l2_norm(sparsevec),
+FUNCTION 3 l2_normalize(sparsevec);
CREATE OPERATOR CLASS sparsevec_l1_ops
FOR TYPE sparsevec USING hnsw AS

View File

@@ -194,7 +194,7 @@ hnswhandler(PG_FUNCTION_ARGS)
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);
amroutine->amstrategies = 0;
-amroutine->amsupport = 2;
+amroutine->amsupport = 3;
#if PG_VERSION_NUM >= 130000
amroutine->amoptsprocnum = 0;
#endif

View File

@@ -22,6 +22,7 @@
/* Support functions */
#define HNSW_DISTANCE_PROC 1
#define HNSW_NORM_PROC 2
+#define HNSW_NORMALIZE_PROC 3
#define HNSW_VERSION 1
#define HNSW_MAGIC_NUMBER 0xA953A953
@@ -265,6 +266,7 @@ typedef struct HnswBuildState
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
+FmgrInfo *normalizeprocinfo;
Oid collation;
/* Variables */
@@ -341,6 +343,7 @@ typedef struct HnswScanOpaqueData
/* Support functions */
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
+FmgrInfo *normalizeprocinfo;
Oid collation;
} HnswScanOpaqueData;
@@ -377,7 +380,7 @@ int HnswGetM(Relation index);
int HnswGetEfConstruction(Relation index);
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
HnswType HnswGetType(Relation index);
-Datum HnswNormValue(Datum value, HnswType type);
+Datum HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value);
bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
void HnswCheckValue(Datum value, HnswType type);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);

View File

@@ -496,7 +496,7 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
return false;
-value = HnswNormValue(value, buildstate->type);
+value = HnswNormValue(buildstate->normalizeprocinfo, buildstate->collation, value);
}
/* Get datum size */
@@ -725,6 +725,7 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
/* Get support functions */
buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
+buildstate->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
buildstate->collation = index->rd_indcollation[0];
InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L);

View File

@@ -629,7 +629,7 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
if (!HnswCheckNorm(normprocinfo, collation, value))
return;
-value = HnswNormValue(value, type);
+value = HnswNormValue(HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC), collation, value);
}
HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false);

View File

@@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan)
/* Fine if normalization fails */
if (so->normprocinfo != NULL)
-value = HnswNormValue(value, HnswGetType(scan->indexRelation));
+value = HnswNormValue(so->normalizeprocinfo, so->collation, value);
}
return value;
@@ -87,6 +87,7 @@ hnswbeginscan(Relation index, int nkeys, int norderbys)
/* Set support functions */
so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC);
so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
+so->normalizeprocinfo = HnswOptionalProcInfo(index, HNSW_NORMALIZE_PROC);
so->collation = index->rd_indcollation[0];
scan->opaque = so;

View File

@@ -195,17 +195,12 @@ HnswGetType(Relation index)
* Normalize value
*/
Datum
-HnswNormValue(Datum value, HnswType type)
+HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum value)
{
/* TODO Remove type-specific code */
-if (type == HNSW_TYPE_VECTOR)
+if (procinfo == NULL)
return DirectFunctionCall1(l2_normalize, value);
-else if (type == HNSW_TYPE_HALFVEC)
-return DirectFunctionCall1(halfvec_l2_normalize, value);
-else if (type == HNSW_TYPE_SPARSEVEC)
-return DirectFunctionCall1(sparsevec_l2_normalize, value);
-else
-elog(ERROR, "Unsupported type");
+return FunctionCall1Coll(procinfo, collation, value);
}
/*