mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-04 11:40:57 +08:00
Moved norm check to separate function
This commit is contained in:
@@ -377,7 +377,8 @@ int HnswGetM(Relation index);
|
||||
int HnswGetEfConstruction(Relation index);
|
||||
FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum);
|
||||
HnswType HnswGetType(Relation index);
|
||||
bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type);
|
||||
Datum HnswNormValue(Datum value, HnswType type);
|
||||
bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
void HnswCheckValue(Datum value, HnswType type);
|
||||
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void HnswInitPage(Buffer buf, Page page);
|
||||
|
||||
@@ -493,8 +493,10 @@ InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, Hn
|
||||
/* Normalize if needed */
|
||||
if (buildstate->normprocinfo != NULL)
|
||||
{
|
||||
if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->type))
|
||||
if (!HnswCheckNorm(buildstate->normprocinfo, buildstate->collation, value))
|
||||
return false;
|
||||
|
||||
value = HnswNormValue(value, buildstate->type);
|
||||
}
|
||||
|
||||
/* Get datum size */
|
||||
|
||||
@@ -626,8 +626,10 @@ HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_ti
|
||||
normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC);
|
||||
if (normprocinfo != NULL)
|
||||
{
|
||||
if (!HnswNormValue(normprocinfo, collation, &value, type))
|
||||
if (!HnswCheckNorm(normprocinfo, collation, value))
|
||||
return;
|
||||
|
||||
value = HnswNormValue(value, type);
|
||||
}
|
||||
|
||||
HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false);
|
||||
|
||||
@@ -61,7 +61,7 @@ GetScanValue(IndexScanDesc scan)
|
||||
|
||||
/* Fine if normalization fails */
|
||||
if (so->normprocinfo != NULL)
|
||||
HnswNormValue(so->normprocinfo, so->collation, &value, HnswGetType(scan->indexRelation));
|
||||
value = HnswNormValue(value, HnswGetType(scan->indexRelation));
|
||||
}
|
||||
|
||||
return value;
|
||||
|
||||
@@ -193,34 +193,29 @@ HnswGetType(Relation index)
|
||||
}
|
||||
|
||||
/*
|
||||
* Divide by the norm
|
||||
*
|
||||
* Returns false if value should not be indexed
|
||||
*
|
||||
* The caller needs to free the pointer stored in value
|
||||
* if it's different than the original value
|
||||
* Normalize value
|
||||
*/
|
||||
Datum
|
||||
HnswNormValue(Datum value, HnswType type)
|
||||
{
|
||||
/* TODO Remove type-specific code */
|
||||
if (type == HNSW_TYPE_VECTOR)
|
||||
return DirectFunctionCall1(l2_normalize, value);
|
||||
else if (type == HNSW_TYPE_HALFVEC)
|
||||
return DirectFunctionCall1(halfvec_l2_normalize, value);
|
||||
else if (type == HNSW_TYPE_SPARSEVEC)
|
||||
return DirectFunctionCall1(sparsevec_l2_normalize, value);
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
}
|
||||
|
||||
/*
|
||||
* Check if non-zero norm
|
||||
*/
|
||||
bool
|
||||
HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, HnswType type)
|
||||
HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value)
|
||||
{
|
||||
double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));
|
||||
|
||||
if (norm > 0)
|
||||
{
|
||||
/* TODO Remove type-specific code */
|
||||
if (type == HNSW_TYPE_VECTOR)
|
||||
*value = DirectFunctionCall1(l2_normalize, *value);
|
||||
else if (type == HNSW_TYPE_HALFVEC)
|
||||
*value = DirectFunctionCall1(halfvec_l2_normalize, *value);
|
||||
else if (type == HNSW_TYPE_SPARSEVEC)
|
||||
*value = DirectFunctionCall1(sparsevec_l2_normalize, *value);
|
||||
else
|
||||
elog(ERROR, "Unsupported type");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
return DatumGetFloat8(FunctionCall1Coll(procinfo, collation, value)) > 0;
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
Reference in New Issue
Block a user