diff --git a/src/hnsw.c b/src/hnsw.c index dd0a558..93c8cc7 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -155,6 +155,15 @@ hnswvalidate(Oid opclassoid) return true; } +/* + * Checks if index-only scan is supported + */ +static bool +hnswcanreturn(Relation indexRelation, int attno) +{ + return attno == 1 && HnswOptionalProcInfo(indexRelation, HNSW_NORM_PROC) == NULL; +} + /* * Define index handler * @@ -196,7 +205,7 @@ hnswhandler(PG_FUNCTION_ARGS) amroutine->aminsert = hnswinsert; amroutine->ambulkdelete = hnswbulkdelete; amroutine->amvacuumcleanup = hnswvacuumcleanup; - amroutine->amcanreturn = NULL; + amroutine->amcanreturn = hnswcanreturn; amroutine->amcostestimate = hnswcostestimate; amroutine->amoptions = hnswoptions; amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */ diff --git a/src/hnsw.h b/src/hnsw.h index 3e8bdc2..599c59a 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -266,7 +266,7 @@ Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); void HnswInit(void); -List *HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, HnswElement skipElement); +List *HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, HnswElement skipElement); HnswElement HnswGetEntryPoint(Relation index); HnswElement HnswInitElement(ItemPointer tid, int m, double ml, int maxLevel); void HnswFreeElement(HnswElement element); diff --git a/src/hnswscan.c b/src/hnswscan.c index 6190328..fd8faca 100644 --- a/src/hnswscan.c +++ b/src/hnswscan.c @@ -17,6 +17,7 @@ GetScanItems(IndexScanDesc scan, Datum q) Relation index = scan->indexRelation; FmgrInfo *procinfo = so->procinfo; Oid collation = so->collation; + bool loadVec = scan->xs_want_itup; List *ep; List *w; HnswElement entryPoint = HnswGetEntryPoint(index); @@ -24,15 +25,15 @@ GetScanItems(IndexScanDesc scan, Datum q) if (entryPoint == NULL) return NIL; - ep = list_make1(HnswEntryCandidate(entryPoint, q, index, procinfo, collation, false)); + ep = list_make1(HnswEntryCandidate(entryPoint, q, index, procinfo, collation, loadVec)); for (int lc = entryPoint->level; lc >= 1; lc--) { - w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, false, NULL); + w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, loadVec, NULL); ep = w; } - return HnswSearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, false, NULL); + return HnswSearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, loadVec, NULL); } /* @@ -83,6 +84,8 @@ hnswbeginscan(Relation index, int nkeys, int norderbys) scan->opaque = so; + scan->xs_itupdesc = RelationGetDescr(index); + return scan; } @@ -177,6 +180,15 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir) hc->element->heaptids = list_delete_last(hc->element->heaptids); + if (scan->xs_want_itup) + { + Datum value = PointerGetDatum(hc->element->vec); + bool isnull = false; + + scan->xs_itup = index_form_tuple(scan->xs_itupdesc, &value, &isnull); + scan->xs_itup->t_tid = *tid; + } + MemoryContextSwitchTo(oldCtx); #if PG_VERSION_NUM >= 120000 diff --git a/src/hnswutils.c b/src/hnswutils.c index 8e6f2a9..86c4bec 100644 --- a/src/hnswutils.c +++ b/src/hnswutils.c @@ -543,7 +543,7 @@ AddToVisited(HTAB *v, HnswCandidate * hc, Relation index, bool *found) * Algorithm 2 from paper */ List * -HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, HnswElement skipElement) +HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec, HnswElement skipElement) { ListCell *lc2; @@ -619,7 +619,7 @@ HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *pro if (index == NULL) eDistance = GetCandidateDistance(e, q, procinfo, collation); else - HnswLoadElement(e->element, &eDistance, &q, index, procinfo, collation, inserting); + HnswLoadElement(e->element, &eDistance, &q, index, procinfo, collation, loadVec); Assert(!e->element->deleted);