Increase ef_search if needed to improve recall

This commit is contained in:
Andrew Kane
2023-11-18 15:44:38 -08:00
parent bacd99b37a
commit e70e582d2f
3 changed files with 41 additions and 11 deletions

View File

@@ -114,6 +114,7 @@ typedef struct HnswCandidate
{
/* Graph element this candidate refers to */
HnswElement element;
/* Distance from the query, as filled in by GetCandidateDistance or HnswLoadElement */
float distance;
/*
 * Written by GetDistance: starts out true and is cleared when a key
 * attribute fails its scan key. HnswSearchLayer does not count candidates
 * with matches == false towards ef, up to a bounded number of extras.
 */
bool matches;
/* NOTE(review): use of this field is not visible in this excerpt */
bool closer;
} HnswCandidate;
@@ -288,7 +289,7 @@ void HnswInitNeighbors(HnswElement element, int m);
bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel);
void HnswUpdateNeighborPages(Relation index, FmgrInfo **procinfos, Oid *collations, HnswElement e, int m, bool checkExisting);
void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec, Relation index);
void HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec);
/*
 * matches is an output flag that HnswLoadElement forwards to GetDistance
 * together with the distance computation.
 * NOTE(review): the vacuum callers pass NULL for both distance and matches;
 * presumably GetDistance is only reached when a distance is requested -
 * confirm, since GetDistance asserts that matches is a valid pointer.
 */
void HnswLoadElement(HnswElement element, float *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec);
void HnswSetElementTuple(HnswElementTuple etup, HnswElement element, bool useIndexTuple);
void HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo **procinfos, Oid *collations, bool inMemory);
void HnswLoadNeighbors(HnswElement element, Relation index, int m);

View File

@@ -545,10 +545,13 @@ AttributeDistance(double e)
* Get the distance
*/
static double
GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations)
GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool *matches)
{
double g = DatumGetFloat8(FunctionCall2Coll(procinfos[0], collations[0], q, vec));
Assert(PointerIsValid(matches));
*matches = true;
if (IndexRelationGetNumberOfKeyAttributes(index) > 1)
{
double w = 0.25;
@@ -570,7 +573,10 @@ GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *k
if (isnull || attnull)
{
if (isnull != attnull)
{
e += 1000;
*matches = false;
}
}
else if (!DatumGetBool(FunctionCall2Coll(&key->sk_func, key->sk_collation, value, key->sk_argument)))
{
@@ -581,6 +587,8 @@ GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *k
else
/* Distance is zero for inequality */
e += 1000;
*matches = false;
}
}
@@ -617,7 +625,7 @@ GetDistance(IndexTuple itup, Datum vec, Datum q, IndexTuple qtup, ScanKeyData *k
* Load an element and optionally get its distance from q
*/
void
HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec)
HnswLoadElement(HnswElement element, float *distance, bool *matches, Datum *q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations, bool loadVec)
{
Buffer buf;
Page page;
@@ -652,7 +660,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup,
else
value = PointerGetDatum(&etup->data);
*distance = GetDistance(itup, value, *q, qtup, keyData, index, procinfos, collations);
*distance = GetDistance(itup, value, *q, qtup, keyData, index, procinfos, collations, matches);
}
UnlockReleaseBuffer(buf);
@@ -664,7 +672,7 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, IndexTuple qtup,
static float
GetCandidateDistance(HnswCandidate * hc, Datum q, IndexTuple qtup, ScanKeyData *keyData, Relation index, FmgrInfo **procinfos, Oid *collations)
{
/*
 * Compute the distance between q and the candidate's already-loaded
 * element (its itup and value) by delegating to GetDistance. The call
 * that passes &hc->matches also lets GetDistance record in the candidate
 * whether the element satisfied the scan keys.
 *
 * NOTE(review): two return statements appear here; they differ only in
 * the trailing &hc->matches argument, and only the second one agrees with
 * the GetDistance signature that takes a bool *matches parameter.
 */
return GetDistance(hc->element->itup, hc->element->value, q, qtup, keyData, index, procinfos, collations);
return GetDistance(hc->element->itup, hc->element->value, q, qtup, keyData, index, procinfos, collations, &hc->matches);
}
/*
@@ -679,7 +687,7 @@ HnswEntryCandidate(HnswElement entryPoint, Datum q, IndexTuple qtup, ScanKeyData
if (inMemory)
hc->distance = GetCandidateDistance(hc, q, qtup, keyData, index, procinfos, collations);
else
HnswLoadElement(hc->element, &hc->distance, &q, qtup, keyData, index, procinfos, collations, loadVec);
HnswLoadElement(hc->element, &hc->distance, &hc->matches, &q, qtup, keyData, index, procinfos, collations, loadVec);
return hc;
}
@@ -754,6 +762,8 @@ HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef
pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL);
int wlen = 0;
uint64 additional = 0;
uint64 maxAdditional = keyData ? 4 * ef : 0;
HASHCTL hash_ctl;
HTAB *v;
@@ -782,6 +792,13 @@ HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef
pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node));
pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node));
/* Do not count elements that do not match filter towards ef */
if (!hc->matches)
{
if ((++additional) <= maxAdditional)
continue;
}
/*
* Do not count elements being deleted towards ef when vacuuming. It
* would be ideal to do this for inserts as well, but this could
@@ -822,7 +839,7 @@ HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef
if (inMemory)
eDistance = GetCandidateDistance(e, q, qtup, keyData, index, procinfos, collations);
else
HnswLoadElement(e->element, &eDistance, &q, qtup, keyData, index, procinfos, collations, loadVec);
HnswLoadElement(e->element, &eDistance, &e->matches, &q, qtup, keyData, index, procinfos, collations, loadVec);
Assert(!e->element->deleted);
@@ -848,6 +865,16 @@ HnswSearchLayer(Datum q, IndexTuple qtup, ScanKeyData *keyData, List *ep, int ef
*/
if (skipElement == NULL || list_length(e->element->heaptids) != 0)
{
/*
* Do not count elements that do not match filter
* towards ef
*/
if (!e->matches)
{
if ((++additional) <= maxAdditional)
continue;
}
wlen++;
/* No need to decrement wlen */
@@ -904,6 +931,8 @@ CompareCandidateDistances(const void *a, const void *b)
static float
HnswGetCachedDistance(HnswElement a, HnswElement b, int lc, Relation index, FmgrInfo **procinfos, Oid *collations)
{
bool matches;
/* Look for cached distance */
if (a->neighbors != NULL)
{
@@ -927,7 +956,7 @@ HnswGetCachedDistance(HnswElement a, HnswElement b, int lc, Relation index, Fmgr
}
}
return GetDistance(a->itup, a->value, b->value, b->itup, NULL, index, procinfos, collations);
return GetDistance(a->itup, a->value, b->value, b->itup, NULL, index, procinfos, collations, &matches);
}
/*
@@ -1120,7 +1149,7 @@ HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int
HnswCandidate *hc3 = &currentNeighbors->items[i];
if (DatumGetPointer(hc3->element->value) == NULL)
HnswLoadElement(hc3->element, &hc3->distance, &q, qtup, keyData, index, procinfos, collations, true);
HnswLoadElement(hc3->element, &hc3->distance, &hc3->matches, &q, qtup, keyData, index, procinfos, collations, true);
else
hc3->distance = GetCandidateDistance(hc3, q, qtup, keyData, index, procinfos, collations);

View File

@@ -258,7 +258,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate)
LockPage(index, HNSW_UPDATE_LOCK, ShareLock);
/* Load element */
HnswLoadElement(highestPoint, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true);
HnswLoadElement(highestPoint, NULL, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true);
/* Repair if needed */
if (NeedsUpdated(vacuumstate, highestPoint))
@@ -296,7 +296,7 @@ RepairGraphEntryPoint(HnswVacuumState * vacuumstate)
* is outdated, this can remove connections at higher levels in
* the graph until they are repaired, but this should be fine.
*/
HnswLoadElement(entryPoint, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true);
HnswLoadElement(entryPoint, NULL, NULL, NULL, NULL, NULL, index, vacuumstate->procinfos, vacuumstate->collations, true);
if (NeedsUpdated(vacuumstate, entryPoint))
{