Moved sorting logic into SelectNeighbors

This commit is contained in:
Andrew Kane
2023-10-06 12:56:15 -07:00
parent cae162ffc6
commit 8085d3e538

View File

@@ -693,6 +693,34 @@ HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *pro
return w;
}
/*
* Compare candidate distances
*/
static int
#if PG_VERSION_NUM >= 130000
CompareCandidateDistances(const ListCell *a, const ListCell *b)
#else
CompareCandidateDistances(const void *a, const void *b)
#endif
{
HnswCandidate *hca = lfirst((ListCell *) a);
HnswCandidate *hcb = lfirst((ListCell *) b);
if (hca->distance < hcb->distance)
return 1;
if (hca->distance > hcb->distance)
return -1;
if (hca->element < hcb->element)
return 1;
if (hca->element > hcb->element)
return -1;
return 0;
}
/*
* Calculate the distance between elements
*/
@@ -749,7 +777,7 @@ CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid c
* Algorithm 4 from paper
*/
static List *
SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned)
SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool needsSorting)
{
List *r = NIL;
List *w = list_copy(c);
@@ -763,6 +791,10 @@ SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswE
wd = pairingheap_allocate(CompareNearestCandidates, NULL);
/* Ensure order is deterministic for closer caching */
if (needsSorting)
list_sort(w, CompareCandidateDistances);
while (list_length(w) > 0 && list_length(r) < m)
{
/* Assumes w is already ordered desc */
@@ -868,34 +900,6 @@ AddConnections(HnswElement element, List *neighbors, int m, int lc)
a->items[a->length++] = *((HnswCandidate *) lfirst(lc2));
}
/*
* Compare candidate distances
*/
static int
#if PG_VERSION_NUM >= 130000
CompareCandidateDistances(const ListCell *a, const ListCell *b)
#else
CompareCandidateDistances(const void *a, const void *b)
#endif
{
HnswCandidate *hca = lfirst((ListCell *) a);
HnswCandidate *hcb = lfirst((ListCell *) b);
if (hca->distance < hcb->distance)
return 1;
if (hca->distance > hcb->distance)
return -1;
if (hca->element < hcb->element)
return 1;
if (hca->element > hcb->element)
return -1;
return 0;
}
/*
* Update connections
*/
@@ -949,13 +953,12 @@ HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int
{
List *c = NIL;
/* Add and sort candidates */
/* Add candidates */
for (int i = 0; i < currentNeighbors->length; i++)
c = lappend(c, &currentNeighbors->items[i]);
c = lappend(c, &hc2);
list_sort(c, CompareCandidateDistances);
SelectNeighbors(c, m, lc, procinfo, collation, hc->element, &hc2, &pruned);
SelectNeighbors(c, m, lc, procinfo, collation, hc->element, &hc2, &pruned, true);
/* Should not happen */
if (pruned == NULL)
@@ -1054,11 +1057,7 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F
else
lw = w;
/* Ensure order is deterministic for SelectNeighbors closer caching */
if (index == NULL)
list_sort(lw, CompareCandidateDistances);
neighbors = SelectNeighbors(lw, lm, lc, procinfo, collation, element, NULL, NULL);
neighbors = SelectNeighbors(lw, lm, lc, procinfo, collation, element, NULL, NULL, index == NULL);
AddConnections(element, neighbors, lm, lc);