diff --git a/src/hnsw.h b/src/hnsw.h index cfba560..add240e 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -116,6 +116,7 @@ typedef struct HnswNeighborArray { int length; HnswCandidate *items; + HnswElement firstPruned; } HnswNeighborArray; typedef struct HnswPairingHeapNode diff --git a/src/hnswutils.c b/src/hnswutils.c index 72cf94e..e25fdfb 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -139,6 +139,7 @@ HnswInitNeighbors(HnswElement element, int m) a = &element->neighbors[lc]; a->length = 0; a->items = palloc(sizeof(HnswCandidate) * lm); + a->firstPruned = NULL; } } @@ -748,11 +749,12 @@ CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid c * Algorithm 4 from paper */ static List * -SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswCandidate * *pruned) +SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned) { List *r = NIL; List *w = list_copy(c); pairingheap *wd; + bool mustCalculate = e2->neighbors[lc].firstPruned == NULL; if (list_length(w) <= m) return w; @@ -767,7 +769,13 @@ SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswC w = list_delete_last(w); - closer = CheckElementCloser(e, r, lc, procinfo, collation); + if (!mustCalculate) + mustCalculate = e->element == e2->neighbors[lc].firstPruned || e == newCandidate; + + if (mustCalculate) + closer = CheckElementCloser(e, r, lc, procinfo, collation); + else + closer = true; if (closer) r = lappend(r, e); @@ -775,6 +783,10 @@ SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswC pairingheap_add(wd, &(CreatePairingHeapNode(e)->ph_node)); } + /* Save first pruned */ + if (!pairingheap_is_empty(wd)) + e2->neighbors[lc].firstPruned = ((HnswPairingHeapNode *) pairingheap_first(wd))->inner->element; + /* Keep pruned connections */ while (!pairingheap_is_empty(wd) && list_length(r) < m) r = lappend(r, ((HnswPairingHeapNode *) pairingheap_remove_first(wd))->inner); @@ -909,7 +921,7 @@ HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int c = lappend(c, &hc2); list_sort(c, CompareCandidateDistances); - SelectNeighbors(c, m, lc, procinfo, collation, &pruned); + SelectNeighbors(c, m, lc, procinfo, collation, hc->element, &hc2, &pruned); /* Should not happen */ if (pruned == NULL) @@ -1008,7 +1020,7 @@ HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, F else lw = w; - neighbors = SelectNeighbors(lw, lm, lc, procinfo, collation, NULL); + neighbors = SelectNeighbors(lw, lm, lc, procinfo, collation, element, NULL, NULL); AddConnections(element, neighbors, lm, lc);