mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
DRY HNSW distance calculations
This commit is contained in:
@@ -482,6 +482,15 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Calculate the distance between values
|
||||
*/
|
||||
static inline float
|
||||
HnswGetDistance(Datum a, Datum b, FmgrInfo *procinfo, Oid collation)
|
||||
{
|
||||
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, a, b));
|
||||
}
|
||||
|
||||
/*
|
||||
* Load an element and optionally get its distance from q
|
||||
*/
|
||||
@@ -507,7 +516,7 @@ HnswLoadElementImpl(BlockNumber blkno, OffsetNumber offno, double *distance, Dat
|
||||
if (DatumGetPointer(*q) == NULL)
|
||||
*distance = 0;
|
||||
else
|
||||
*distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data)));
|
||||
*distance = HnswGetDistance(*q, PointerGetDatum(&etup->data), procinfo, collation);
|
||||
}
|
||||
|
||||
/* Load element */
|
||||
@@ -539,7 +548,7 @@ GetElementDistance(char *base, HnswElement element, Datum q, FmgrInfo *procinfo,
|
||||
{
|
||||
Datum value = HnswGetValue(base, element);
|
||||
|
||||
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value));
|
||||
return HnswGetDistance(q, value, procinfo, collation);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -921,18 +930,6 @@ CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b)
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Calculate the distance between elements
|
||||
*/
|
||||
static float
|
||||
HnswGetDistance(char *base, HnswElement a, HnswElement b, FmgrInfo *procinfo, Oid collation)
|
||||
{
|
||||
Datum aValue = HnswGetValue(base, a);
|
||||
Datum bValue = HnswGetValue(base, b);
|
||||
|
||||
return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, aValue, bValue));
|
||||
}
|
||||
|
||||
/*
|
||||
* Check if an element is closer to q than any element from R
|
||||
*/
|
||||
@@ -940,13 +937,15 @@ static bool
|
||||
CheckElementCloser(char *base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation)
|
||||
{
|
||||
HnswElement eElement = HnswPtrAccess(base, e->element);
|
||||
Datum eValue = HnswGetValue(base, eElement);
|
||||
ListCell *lc2;
|
||||
|
||||
foreach(lc2, r)
|
||||
{
|
||||
HnswCandidate *ri = lfirst(lc2);
|
||||
HnswElement riElement = HnswPtrAccess(base, ri->element);
|
||||
float distance = HnswGetDistance(base, eElement, riElement, procinfo, collation);
|
||||
Datum riValue = HnswGetValue(base, riElement);
|
||||
float distance = HnswGetDistance(eValue, riValue, procinfo, collation);
|
||||
|
||||
if (distance <= e->distance)
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user