Moved norm check to separate function

This commit is contained in:
Andrew Kane
2024-04-15 15:32:08 -07:00
parent 342d82be65
commit 5215c28923
5 changed files with 29 additions and 29 deletions

View File

@@ -377,7 +377,8 @@ int HnswGetM(Relation index);
int HnswGetEfConstruction(Relation index);
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
HnswType HnswGetType(Relation index);
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type);
Datum HnswNormValue(Datum value, HnswType type);
bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
void HnswCheckValue(Datum value, HnswType type);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);

View File

@@ -493,8 +493,10 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
/* Normalize if needed */
if (buildstate->normprocinfo != NULL)
{
if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type))
if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
return false;
value = HnswNormValue(value, buildstate->type);
}
/* Get datum size */

View File

@@ -626,8 +626,10 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
if (normprocinfo != NULL)
{
if (!HnswNormValue(normprocinfo, collation, &value, type))
if (!HnswCheckNorm(normprocinfo, collation, value))
return;
value = HnswNormValue(value, type);
}
HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false);

View File

@@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan)
/* Fine if normalization fails */
if (so->normprocinfo != NULL)
HnswNormValue(so->normprocinfo, so->collation, &value, HnswGetType(scan->indexRelation));
value = HnswNormValue(value, HnswGetType(scan->indexRelation));
}
return value;

View File

@@ -193,34 +193,29 @@ HnswGetType(Relation index)
}
/*
* Divide by the norm
*
* Returns false if value should not be indexed
*
* The caller needs to free the pointer stored in value
* if it's different than the original value
* Normalize value
*/
Datum
HnswNormValue(Datum value, HnswType type)
{
/* TODO Remove type-specific code */
if (type == HNSW_TYPE_VECTOR)
return DirectFunctionCall1(l2_normalize, value);
else if (type == HNSW_TYPE_HALFVEC)
return DirectFunctionCall1(halfvec_l2_normalize, value);
else if (type == HNSW_TYPE_SPARSEVEC)
return DirectFunctionCall1(sparsevec_l2_normalize, value);
else
elog(ERROR, "Unsupported type");
}
/*
* Check if non-zero norm
*/
bool
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
{
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
if (norm > 0)
{
/* TODO Remove type-specific code */
if (type == HNSW_TYPE_VECTOR)
*value = DirectFunctionCall1(l2_normalize, *value);
else if (type == HNSW_TYPE_HALFVEC)
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
else if (type == HNSW_TYPE_SPARSEVEC)
*value = DirectFunctionCall1(sparsevec_l2_normalize, *value);
else
elog(ERROR, "Unsupported type");
return true;
}
return false;
return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0;
}
/*