diff --git a/src/hnsw.h b/src/hnsw.h index 232fa3c..5aa98b7 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -114,6 +114,7 @@ typedef struct HnswCandidate { HnswElement element; float distance; + bool matches; bool closer; } HnswCandidate; @@ -288,7 +289,7 @@ void HnswInitNeighbors(HnswElement element, int m); bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel); void HnswUpdateNeighborPages(Relation index, FmgrInfo **procinfos, Oid *collations, HnswElement e, int m, bool checkExisting); void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index); -void HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec); +void HnswLoadElement(HnswElement element, float *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec); void HnswSetElementTuple(HnswElementTuple etup, HnswElement element, bool useIndexTuple); void HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo **procinfos, Oid *collations, bool inMemory); void HnswLoadNeighbors(HnswElement element, Relation index, int m); diff --git a/src/hnswutils.c b/src/hnswutils.c index 6cfdab0..e338e32 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -545,10 +545,13 @@ AttributeDistance(double e) * Get the distance */ static double -GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations) +GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool *matches) { double g = DatumGetFloat8(FunctionCall2Coll(procinfos[0], collations[0], q, vec)); + Assert(PointerIsValid(matches)); + *matches = true; + if (IndexRelationGetNumberOfKeyAttributes(index) > 1) { double w = 0.25; @@ -570,7 +573,10 @@ GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *k if (isnull || attnull) { if (isnull != attnull) + { e += 1000; + *matches = false; + } } else if (!DatumGetBool(FunctionCall2Coll(&key->sk_func, key->sk_collation, value, key->sk_argument))) { @@ -581,6 +587,8 @@ GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *k else /* Distance is zero for inequality */ e += 1000; + + *matches = false; } } @@ -617,7 +625,7 @@ GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *k * Load an element and optionally get its distance from q */ void -HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec) +HnswLoadElement(HnswElement element, float *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec) { Buffer buf; Page page; @@ -652,7 +660,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup, else value = PointerGetDatum(&etup->data); - *distance = GetDistance(itup, value, *q, qtup, keyData, index, procinfos, collations); + *distance = GetDistance(itup, value, *q, qtup, keyData, index, procinfos, collations, matches); } UnlockReleaseBuffer(buf); @@ -664,7 +672,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup, static float GetCandidateDistance(HnswCandidate * hc, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations) { - return GetDistance(hc->element->itup, hc->element->value, q, qtup, keyData, index, procinfos, collations); + return GetDistance(hc->element->itup, hc->element->value, q, qtup, keyData, index, procinfos, collations, &hc->matches); } /* @@ -679,7 +687,7 @@ HnswEntryCandidate(HnswElement entryPoint, Datum q, IndexTuple qtup, ScanKeyData if (inMemory) hc->distance = GetCandidateDistance(hc, q, qtup, keyData, index, procinfos, collations); else - HnswLoadElement(hc->element, &hc->distance, &q, qtup, keyData, index, procinfos, collations, loadVec); + HnswLoadElement(hc->element, &hc->distance, &hc->matches, &q, qtup, keyData, index, procinfos, collations, loadVec); return hc; } @@ -754,6 +762,8 @@ HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); int wlen = 0; + uint64 additional = 0; + uint64 maxAdditional = keyData ? 4 * ef : 0; HASHCTL hash_ctl; HTAB *v; @@ -782,6 +792,13 @@ HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node)); pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node)); + /* Do not count elements that do not match filter towards ef */ + if (!hc->matches) + { + if ((++additional) <= maxAdditional) + continue; + } + /* * Do not count elements being deleted towards ef when vacuuming. It * would be ideal to do this for inserts as well, but this could @@ -822,7 +839,7 @@ HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef if (inMemory) eDistance = GetCandidateDistance(e, q, qtup, keyData, index, procinfos, collations); else - HnswLoadElement(e->element, &eDistance, &q, qtup, keyData, index, procinfos, collations, loadVec); + HnswLoadElement(e->element, &eDistance, &e->matches, &q, qtup, keyData, index, procinfos, collations, loadVec); Assert(!e->element->deleted); @@ -848,6 +865,16 @@ HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef */ if (skipElement == NULL || list_length(e->element->heaptids) != 0) { + /* + * Do not count elements that do not match filter + * towards ef + */ + if (!e->matches) + { + if ((++additional) <= maxAdditional) + continue; + } + wlen++; /* No need to decrement wlen */ @@ -904,6 +931,8 @@ CompareCandidateDistances(const void *a, const void *b) static float HnswGetCachedDistance(HnswElement a, HnswElement b, int lc, Relation index, FmgrInfo **procinfos, Oid *collations) { + bool matches; + /* Look for cached distance */ if (a->neighbors != NULL) { @@ -927,7 +956,7 @@ HnswGetCachedDistance(HnswElement a, HnswElement b, int lc, Relation index, Fmgr } } - return GetDistance(a->itup, a->value, b->value, b->itup, NULL, index, procinfos, collations); + return GetDistance(a->itup, a->value, b->value, b->itup, NULL, index, procinfos, collations, &matches); } /* @@ -1120,7 +1149,7 @@ HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int HnswCandidate *hc3 = ¤tNeighbors->items[i]; if (DatumGetPointer(hc3->element->value) == NULL) - HnswLoadElement(hc3->element, &hc3->distance, &q, qtup, keyData, index, procinfos, collations, true); + HnswLoadElement(hc3->element, &hc3->distance, &hc3->matches, &q, qtup, keyData, index, procinfos, collations, true); else hc3->distance = GetCandidateDistance(hc3, q, qtup, keyData, index, procinfos, collations); diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c index ce05643..c84ab17 100644 --- a/src/hnswvacuum.c +++ b/src/hnswvacuum.c @@ -258,7 +258,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ - HnswLoadElement(highestPoint, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true); + HnswLoadElement(highestPoint, NULL, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) @@ -296,7 +296,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate) * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. */ - HnswLoadElement(entryPoint, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true); + HnswLoadElement(entryPoint, NULL, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true); if (NeedsUpdated(vacuumstate, entryPoint)) {