mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Added streaming option for HNSW
This commit is contained in:
10
src/hnsw.c
10
src/hnsw.c
@@ -18,6 +18,7 @@
|
||||
#endif
|
||||
|
||||
int hnsw_ef_search;
|
||||
bool hnsw_streaming;
|
||||
int hnsw_lock_tranche_id;
|
||||
static relopt_kind hnsw_relopt_kind;
|
||||
|
||||
@@ -68,6 +69,13 @@ HnswInit(void)
|
||||
"Valid range is 1..1000.", &hnsw_ef_search,
|
||||
HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
|
||||
|
||||
/* TODO Figure out name */
|
||||
DefineCustomBoolVariable("hnsw.streaming", "Use streaming mode",
|
||||
NULL, &hnsw_streaming,
|
||||
HNSW_DEFAULT_STREAMING, PGC_USERSET, 0, NULL, NULL, NULL);
|
||||
|
||||
/* TODO Add option for limiting iterative search */
|
||||
|
||||
MarkGUCPrefixReserved("hnsw");
|
||||
}
|
||||
|
||||
@@ -126,6 +134,8 @@ hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count,
|
||||
/* Account for number of tuples (or entry level), m, and ef_search */
|
||||
costs.numIndexTuples = (entryLevel + 2) * m;
|
||||
|
||||
/* TODO Adjust for selectivity for iterative scans */
|
||||
|
||||
genericcostestimate(root, path, loop_count, &costs);
|
||||
|
||||
/* Use total cost since most work happens before first tuple is returned */
|
||||
|
||||
45
src/hnsw.h
45
src/hnsw.h
@@ -12,6 +12,10 @@
|
||||
#include "utils/sampling.h"
|
||||
#include "vector.h"
|
||||
|
||||
#ifdef HNSW_BENCH
|
||||
#include "portability/instr_time.h"
|
||||
#endif
|
||||
|
||||
#define HNSW_MAX_DIM 2000
|
||||
#define HNSW_MAX_NNZ 1000
|
||||
|
||||
@@ -42,6 +46,7 @@
|
||||
#define HNSW_DEFAULT_EF_SEARCH 40
|
||||
#define HNSW_MIN_EF_SEARCH 1
|
||||
#define HNSW_MAX_EF_SEARCH 1000
|
||||
#define HNSW_DEFAULT_STREAMING false
|
||||
|
||||
/* Tuple types */
|
||||
#define HNSW_ELEMENT_TUPLE_TYPE 1
|
||||
@@ -68,6 +73,21 @@
|
||||
#define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page))
|
||||
#define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page))
|
||||
|
||||
#ifdef HNSW_BENCH
|
||||
#define HnswBench(name, code) \
|
||||
do { \
|
||||
instr_time start; \
|
||||
instr_time duration; \
|
||||
INSTR_TIME_SET_CURRENT(start); \
|
||||
(code); \
|
||||
INSTR_TIME_SET_CURRENT(duration); \
|
||||
INSTR_TIME_SUBTRACT(duration, start); \
|
||||
elog(INFO, "%s: %.3f ms", name, INSTR_TIME_GET_MILLISEC(duration)); \
|
||||
} while (0)
|
||||
#else
|
||||
#define HnswBench(name, code) (code)
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
#define RandomDouble() pg_prng_double(&pg_global_prng_state)
|
||||
#define SeedRandom(seed) pg_prng_seed(&pg_global_prng_state, seed)
|
||||
@@ -106,6 +126,7 @@
|
||||
|
||||
/* Variables */
|
||||
extern int hnsw_ef_search;
|
||||
extern bool hnsw_streaming;
|
||||
extern int hnsw_lock_tranche_id;
|
||||
|
||||
typedef struct HnswElementData HnswElementData;
|
||||
@@ -129,6 +150,7 @@ struct HnswElementData
|
||||
uint8 heaptidsLength;
|
||||
uint8 level;
|
||||
uint8 deleted;
|
||||
uint8 version;
|
||||
uint32 hash;
|
||||
HnswNeighborsPtr neighbors;
|
||||
BlockNumber blkno;
|
||||
@@ -163,6 +185,9 @@ typedef struct HnswSearchCandidate
|
||||
float distance;
|
||||
} HnswSearchCandidate;
|
||||
|
||||
#define HnswGetSearchCandidate(membername, ptr) pairingheap_container(HnswSearchCandidate, membername, ptr)
|
||||
#define HnswGetSearchCandidateConst(membername, ptr) pairingheap_const_container(HnswSearchCandidate, membername, ptr)
|
||||
|
||||
/* HNSW index options */
|
||||
typedef struct HnswOptions
|
||||
{
|
||||
@@ -306,10 +331,10 @@ typedef struct HnswElementTupleData
|
||||
uint8 type;
|
||||
uint8 level;
|
||||
uint8 deleted;
|
||||
uint8 unused;
|
||||
uint8 version;
|
||||
ItemPointerData heaptids[HNSW_HEAPTIDS];
|
||||
ItemPointerData neighbortid;
|
||||
uint16 unused2;
|
||||
uint16 unused;
|
||||
Vector data;
|
||||
} HnswElementTupleData;
|
||||
|
||||
@@ -318,18 +343,30 @@ typedef HnswElementTupleData * HnswElementTuple;
|
||||
typedef struct HnswNeighborTupleData
|
||||
{
|
||||
uint8 type;
|
||||
uint8 unused;
|
||||
uint8 version;
|
||||
uint16 count;
|
||||
ItemPointerData indextids[FLEXIBLE_ARRAY_MEMBER];
|
||||
} HnswNeighborTupleData;
|
||||
|
||||
typedef HnswNeighborTupleData * HnswNeighborTuple;
|
||||
|
||||
typedef union
|
||||
{
|
||||
struct pointerhash_hash *pointers;
|
||||
struct offsethash_hash *offsets;
|
||||
struct tidhash_hash *tids;
|
||||
} visited_hash;
|
||||
|
||||
typedef struct HnswScanOpaqueData
|
||||
{
|
||||
const HnswTypeInfo *typeInfo;
|
||||
bool first;
|
||||
List *w;
|
||||
visited_hash v;
|
||||
pairingheap *discarded;
|
||||
Datum q;
|
||||
int m;
|
||||
int64 tuples;
|
||||
MemoryContext tmpCtx;
|
||||
|
||||
/* Support functions */
|
||||
@@ -375,7 +412,7 @@ bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void HnswInitPage(Buffer buf, Page page);
|
||||
void HnswInit(void);
|
||||
List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement);
|
||||
List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited);
|
||||
HnswElement HnswGetEntryPoint(Relation index);
|
||||
void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint);
|
||||
void *HnswAlloc(HnswAllocator * allocator, Size size);
|
||||
|
||||
@@ -36,7 +36,7 @@ GetInsertPage(Relation index)
|
||||
* Check for a free offset
|
||||
*/
|
||||
static bool
|
||||
HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage)
|
||||
HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage, uint8 *tupleVersion)
|
||||
{
|
||||
OffsetNumber offno;
|
||||
OffsetNumber maxoffno = PageGetMaxOffsetNumber(page);
|
||||
@@ -98,6 +98,7 @@ HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size
|
||||
{
|
||||
*freeOffno = offno;
|
||||
*freeNeighborOffno = neighborOffno;
|
||||
*tupleVersion = etup->version;
|
||||
return true;
|
||||
}
|
||||
else if (*nbuf != buf)
|
||||
@@ -153,6 +154,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B
|
||||
OffsetNumber freeOffno = InvalidOffsetNumber;
|
||||
OffsetNumber freeNeighborOffno = InvalidOffsetNumber;
|
||||
BlockNumber newInsertPage = InvalidBlockNumber;
|
||||
uint8 tupleVersion;
|
||||
char *base = NULL;
|
||||
|
||||
/* Calculate sizes */
|
||||
@@ -202,7 +204,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B
|
||||
}
|
||||
|
||||
/* Next, try space from a deleted element */
|
||||
if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage))
|
||||
if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage, &tupleVersion))
|
||||
{
|
||||
if (nbuf != buf)
|
||||
{
|
||||
@@ -212,6 +214,10 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B
|
||||
npage = GenericXLogRegisterBuffer(state, nbuf, 0);
|
||||
}
|
||||
|
||||
/* Set tuple version */
|
||||
etup->version = tupleVersion;
|
||||
ntup->version = tupleVersion;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
113
src/hnswscan.c
113
src/hnswscan.c
@@ -26,6 +26,9 @@ GetScanItems(IndexScanDesc scan, Datum q)
|
||||
/* Get m and entry point */
|
||||
HnswGetMetaPageInfo(index, &m, &entryPoint);
|
||||
|
||||
so->q = q;
|
||||
so->m = m;
|
||||
|
||||
if (entryPoint == NULL)
|
||||
return NIL;
|
||||
|
||||
@@ -33,11 +36,44 @@ GetScanItems(IndexScanDesc scan, Datum q)
|
||||
|
||||
for (int lc = entryPoint->level; lc >= 1; lc--)
|
||||
{
|
||||
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL);
|
||||
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL, NULL, NULL, true);
|
||||
ep = w;
|
||||
}
|
||||
|
||||
return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL);
|
||||
return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, &so->v, &so->discarded, true);
|
||||
}
|
||||
|
||||
/*
|
||||
* Resume scan at ground level with discarded candidates
|
||||
*/
|
||||
static List *
|
||||
ResumeScanItems(IndexScanDesc scan)
|
||||
{
|
||||
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
|
||||
Relation index = scan->indexRelation;
|
||||
FmgrInfo *procinfo = so->procinfo;
|
||||
Oid collation = so->collation;
|
||||
List *ep = NIL;
|
||||
char *base = NULL;
|
||||
int batch_size = hnsw_ef_search;
|
||||
|
||||
if (pairingheap_is_empty(so->discarded))
|
||||
return NIL;
|
||||
|
||||
/* Get next batch of candidates */
|
||||
for (int i = 0; i < batch_size; i++)
|
||||
{
|
||||
HnswSearchCandidate *hc;
|
||||
|
||||
if (pairingheap_is_empty(so->discarded))
|
||||
break;
|
||||
|
||||
hc = HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded));
|
||||
|
||||
ep = lappend(ep, hc);
|
||||
}
|
||||
|
||||
return HnswSearchLayer(base, so->q, ep, batch_size, 0, index, procinfo, collation, so->m, false, NULL, &so->v, &so->discarded, false);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -103,7 +139,13 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no
|
||||
{
|
||||
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
|
||||
|
||||
if (!so->first)
|
||||
{
|
||||
pairingheap_reset(so->discarded);
|
||||
tidhash_reset(so->v.tids);
|
||||
}
|
||||
so->first = true;
|
||||
so->tuples = 0;
|
||||
MemoryContextReset(so->tmpCtx);
|
||||
|
||||
if (keys && scan->numberOfKeys > 0)
|
||||
@@ -153,7 +195,7 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
|
||||
*/
|
||||
LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
|
||||
|
||||
so->w = GetScanItems(scan, value);
|
||||
HnswBench("scan iteration", so->w = GetScanItems(scan, value));
|
||||
|
||||
/* Release shared lock */
|
||||
UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
|
||||
@@ -165,20 +207,79 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
|
||||
#endif
|
||||
}
|
||||
|
||||
while (list_length(so->w) > 0)
|
||||
for (;;)
|
||||
{
|
||||
char *base = NULL;
|
||||
HnswSearchCandidate *hc = llast(so->w);
|
||||
HnswElement element = HnswPtrAccess(base, hc->element);
|
||||
HnswSearchCandidate *hc;
|
||||
HnswElement element;
|
||||
ItemPointer heaptid;
|
||||
|
||||
if (list_length(so->w) == 0)
|
||||
{
|
||||
if (!hnsw_streaming)
|
||||
break;
|
||||
|
||||
/* Prevent scans from consuming too much memory */
|
||||
if (MemoryContextMemAllocated(so->tmpCtx, false) > (Size) work_mem * 1024L)
|
||||
{
|
||||
if (pairingheap_is_empty(so->discarded))
|
||||
{
|
||||
ereport(NOTICE,
|
||||
(errmsg("hnsw iterative search exceeded work_mem after " INT64_FORMAT " tuples", so->tuples),
|
||||
errhint("Increase work_mem to scan more tuples.")));
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
/* Return remaining tuples */
|
||||
so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded)));
|
||||
}
|
||||
else
|
||||
{
|
||||
/*
|
||||
* Locking ensures when neighbors are read, the elements they
|
||||
* reference will not be deleted (and replaced) during the
|
||||
* iteration.
|
||||
*
|
||||
* Elements loaded into memory on previous iterations may have
|
||||
* been deleted (and replaced), so when reading neighbors, the
|
||||
* element version must be checked.
|
||||
*/
|
||||
LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
|
||||
|
||||
HnswBench("scan iteration", so->w = ResumeScanItems(scan));
|
||||
|
||||
UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
|
||||
|
||||
#if defined(HNSW_MEMORY)
|
||||
elog(INFO, "memory: %zu KB", MemoryContextMemAllocated(so->tmpCtx, false) / 1024);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (list_length(so->w) == 0)
|
||||
break;
|
||||
}
|
||||
|
||||
hc = llast(so->w);
|
||||
element = HnswPtrAccess(base, hc->element);
|
||||
|
||||
/* Move to next element if no valid heap TIDs */
|
||||
if (element->heaptidsLength == 0)
|
||||
{
|
||||
so->w = list_delete_last(so->w);
|
||||
|
||||
/* Mark memory as free for next iteration */
|
||||
if (hnsw_streaming)
|
||||
{
|
||||
pfree(element);
|
||||
pfree(hc);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
so->tuples++;
|
||||
|
||||
heaptid = &element->heaptids[--element->heaptidsLength];
|
||||
|
||||
MemoryContextSwitchTo(oldCtx);
|
||||
|
||||
102
src/hnswutils.c
102
src/hnswutils.c
@@ -100,13 +100,6 @@ hash_offset(Size offset)
|
||||
#define SH_DEFINE
|
||||
#include "lib/simplehash.h"
|
||||
|
||||
typedef union
|
||||
{
|
||||
pointerhash_hash *pointers;
|
||||
offsethash_hash *offsets;
|
||||
tidhash_hash *tids;
|
||||
} visited_hash;
|
||||
|
||||
typedef union
|
||||
{
|
||||
HnswElement element;
|
||||
@@ -253,6 +246,8 @@ HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel,
|
||||
|
||||
element->level = level;
|
||||
element->deleted = 0;
|
||||
/* Start at one to make it easier to find issues */
|
||||
element->version = 1;
|
||||
|
||||
HnswInitNeighbors(base, element, m, allocator);
|
||||
|
||||
@@ -405,6 +400,7 @@ HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element)
|
||||
etup->type = HNSW_ELEMENT_TUPLE_TYPE;
|
||||
etup->level = element->level;
|
||||
etup->deleted = 0;
|
||||
etup->version = element->version;
|
||||
for (int i = 0; i < HNSW_HEAPTIDS; i++)
|
||||
{
|
||||
if (i < element->heaptidsLength)
|
||||
@@ -447,6 +443,7 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m)
|
||||
}
|
||||
|
||||
ntup->count = idx;
|
||||
ntup->version = e->version;
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -520,6 +517,7 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe
|
||||
{
|
||||
element->level = etup->level;
|
||||
element->deleted = etup->deleted;
|
||||
element->version = etup->version;
|
||||
element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid);
|
||||
element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid);
|
||||
element->heaptidsLength = 0;
|
||||
@@ -621,9 +619,6 @@ HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index,
|
||||
return hc;
|
||||
}
|
||||
|
||||
#define HnswGetSearchCandidate(membername, ptr) pairingheap_container(HnswSearchCandidate, membername, ptr)
|
||||
#define HnswGetSearchCandidateConst(membername, ptr) pairingheap_const_container(HnswSearchCandidate, membername, ptr)
|
||||
|
||||
/*
|
||||
* Compare candidate distances
|
||||
*/
|
||||
@@ -639,6 +634,21 @@ CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, v
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare discarded candidate distances
|
||||
*/
|
||||
static int
|
||||
CompareNearestDiscardedCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg)
|
||||
{
|
||||
if (HnswGetSearchCandidateConst(w_node, a)->distance < HnswGetSearchCandidateConst(w_node, b)->distance)
|
||||
return 1;
|
||||
|
||||
if (HnswGetSearchCandidateConst(w_node, a)->distance > HnswGetSearchCandidateConst(w_node, b)->distance)
|
||||
return -1;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Compare candidate distances
|
||||
*/
|
||||
@@ -754,20 +764,30 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u
|
||||
int start;
|
||||
ItemPointerData indextids[HNSW_MAX_M * 2];
|
||||
|
||||
*unvisitedLength = 0;
|
||||
|
||||
buf = ReadBuffer(index, element->neighborPage);
|
||||
LockBuffer(buf, BUFFER_LOCK_SHARE);
|
||||
page = BufferGetPage(buf);
|
||||
|
||||
ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno));
|
||||
start = (element->level - lc) * m;
|
||||
|
||||
/*
|
||||
* Ensure the neighbor tuple has not been deleted or replaced between
|
||||
* index scan iterations
|
||||
*/
|
||||
if (ntup->version != element->version)
|
||||
{
|
||||
UnlockReleaseBuffer(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
/* Copy to minimize lock time */
|
||||
start = (element->level - lc) * m;
|
||||
memcpy(&indextids, ntup->indextids + start, lm * sizeof(ItemPointerData));
|
||||
|
||||
UnlockReleaseBuffer(buf);
|
||||
|
||||
*unvisitedLength = 0;
|
||||
|
||||
for (int i = 0; i < lm; i++)
|
||||
{
|
||||
ItemPointer indextid = &indextids[i];
|
||||
@@ -787,13 +807,13 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u
|
||||
* Algorithm 2 from paper
|
||||
*/
|
||||
List *
|
||||
HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement)
|
||||
HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited)
|
||||
{
|
||||
List *w = NIL;
|
||||
pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
|
||||
pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL);
|
||||
int wlen = 0;
|
||||
visited_hash v;
|
||||
visited_hash vh;
|
||||
ListCell *lc2;
|
||||
HnswNeighborArray *localNeighborhood = NULL;
|
||||
Size neighborhoodSize = 0;
|
||||
@@ -801,7 +821,19 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
HnswUnvisited *unvisited = palloc(lm * sizeof(HnswUnvisited));
|
||||
int unvisitedLength;
|
||||
|
||||
InitVisited(base, &v, index, ef, m);
|
||||
if (v == NULL)
|
||||
{
|
||||
v = &vh;
|
||||
initVisited = true;
|
||||
}
|
||||
|
||||
if (initVisited)
|
||||
{
|
||||
InitVisited(base, v, index, ef, m);
|
||||
|
||||
if (discarded != NULL)
|
||||
*discarded = pairingheap_allocate(CompareNearestDiscardedCandidates, NULL);
|
||||
}
|
||||
|
||||
/* Create local memory for neighborhood if needed */
|
||||
if (index == NULL)
|
||||
@@ -816,7 +848,8 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
HnswSearchCandidate *hc = (HnswSearchCandidate *) lfirst(lc2);
|
||||
bool found;
|
||||
|
||||
AddToVisited(base, &v, hc->element, index, &found);
|
||||
if (initVisited)
|
||||
AddToVisited(base, v, hc->element, index, &found);
|
||||
|
||||
pairingheap_add(C, &hc->c_node);
|
||||
pairingheap_add(W, &hc->w_node);
|
||||
@@ -842,9 +875,9 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
cElement = HnswPtrAccess(base, c->element);
|
||||
|
||||
if (index == NULL)
|
||||
HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, localNeighborhood, neighborhoodSize);
|
||||
HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, v, lc, localNeighborhood, neighborhoodSize);
|
||||
else
|
||||
HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc);
|
||||
HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc);
|
||||
|
||||
for (int i = 0; i < unvisitedLength; i++)
|
||||
{
|
||||
@@ -868,16 +901,22 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
|
||||
/* Avoid any allocations if not adding */
|
||||
eElement = NULL;
|
||||
HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance, &eElement);
|
||||
|
||||
if (eElement == NULL)
|
||||
continue;
|
||||
HnswLoadElementImpl(blkno, offno, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement);
|
||||
}
|
||||
|
||||
if (!(eDistance < f->distance || alwaysAdd))
|
||||
continue;
|
||||
if (eElement == NULL || !(eDistance < f->distance || alwaysAdd))
|
||||
{
|
||||
if (discarded != NULL)
|
||||
{
|
||||
/* Create a new candidate */
|
||||
e = palloc(sizeof(HnswSearchCandidate));
|
||||
HnswPtrStore(base, e->element, eElement);
|
||||
e->distance = eDistance;
|
||||
pairingheap_add(*discarded, &e->w_node);
|
||||
}
|
||||
|
||||
Assert(!eElement->deleted);
|
||||
continue;
|
||||
}
|
||||
|
||||
/* Make robust to issues */
|
||||
if (eElement->level < lc)
|
||||
@@ -901,7 +940,12 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
|
||||
/* No need to decrement wlen */
|
||||
if (wlen > ef)
|
||||
pairingheap_remove_first(W);
|
||||
{
|
||||
HnswSearchCandidate *d = HnswGetSearchCandidate(w_node, pairingheap_remove_first(W));
|
||||
|
||||
if (discarded != NULL)
|
||||
pairingheap_add(*discarded, &d->w_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1274,7 +1318,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
|
||||
/* 1st phase: greedy search to insert level */
|
||||
for (int lc = entryLevel; lc >= level + 1; lc--)
|
||||
{
|
||||
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement);
|
||||
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
|
||||
ep = w;
|
||||
}
|
||||
|
||||
@@ -1293,7 +1337,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
|
||||
List *lw = NIL;
|
||||
ListCell *lc2;
|
||||
|
||||
w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement);
|
||||
w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
|
||||
|
||||
/* Convert search candidates to candidates */
|
||||
foreach(lc2, w)
|
||||
|
||||
@@ -527,6 +527,14 @@ MarkDeleted(HnswVacuumState * vacuumstate)
|
||||
for (int i = 0; i < ntup->count; i++)
|
||||
ItemPointerSetInvalid(&ntup->indextids[i]);
|
||||
|
||||
/* Increment version */
|
||||
/* This is used to avoid incorrect reads for iterative scans */
|
||||
/* Reserve some bits for future use */
|
||||
etup->version++;
|
||||
if (etup->version > 15)
|
||||
etup->version = 1;
|
||||
ntup->version = etup->version;
|
||||
|
||||
/*
|
||||
* We modified the tuples in place, no need to call
|
||||
* PageIndexTupleOverwrite
|
||||
|
||||
43
test/t/039_hnsw_streaming.pl
Normal file
43
test/t/039_hnsw_streaming.pl
Normal file
@@ -0,0 +1,43 @@
|
||||
use strict;
|
||||
use warnings FATAL => 'all';
|
||||
use PostgreSQL::Test::Cluster;
|
||||
use PostgreSQL::Test::Utils;
|
||||
use Test::More;
|
||||
|
||||
my $dim = 3;
|
||||
my $array_sql = join(",", ('random()') x $dim);
|
||||
|
||||
# Initialize node
|
||||
my $node = PostgreSQL::Test::Cluster->new('node');
|
||||
$node->init;
|
||||
$node->start;
|
||||
|
||||
# Create table
|
||||
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
|
||||
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
|
||||
$node->safe_psql("postgres",
|
||||
"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;"
|
||||
);
|
||||
$node->safe_psql("postgres", qq(
|
||||
SET maintenance_work_mem = '128MB';
|
||||
SET max_parallel_maintenance_workers = 2;
|
||||
CREATE INDEX ON tst USING hnsw (v vector_l2_ops)
|
||||
));
|
||||
|
||||
my $count = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET hnsw.streaming = on;
|
||||
SET work_mem = '8MB';
|
||||
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t;
|
||||
));
|
||||
is($count, 10);
|
||||
|
||||
my ($ret, $stdout, $stderr) = $node->psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET hnsw.streaming = on;
|
||||
SET work_mem = '2MB';
|
||||
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t;
|
||||
));
|
||||
like($stderr, qr/iterative search exceeded work_mem after \d+ tuples/);
|
||||
|
||||
done_testing();
|
||||
131
test/t/040_hnsw_streaming_recall.pl
Normal file
131
test/t/040_hnsw_streaming_recall.pl
Normal file
@@ -0,0 +1,131 @@
|
||||
use strict;
|
||||
use warnings FATAL => 'all';
|
||||
use PostgreSQL::Test::Cluster;
|
||||
use PostgreSQL::Test::Utils;
|
||||
use Test::More;
|
||||
|
||||
my $node;
|
||||
my @queries = ();
|
||||
my @expected;
|
||||
my $limit = 20;
|
||||
my $dim = 3;
|
||||
my $array_sql = join(",", ('random()') x $dim);
|
||||
my @cs = (100, 1000);
|
||||
|
||||
sub test_recall
|
||||
{
|
||||
my ($c, $ef_search, $min, $operator) = @_;
|
||||
my $correct = 0;
|
||||
my $total = 0;
|
||||
|
||||
my $explain = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET hnsw.ef_search = $ef_search;
|
||||
SET hnsw.streaming = on;
|
||||
EXPLAIN ANALYZE SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[0]' LIMIT $limit;
|
||||
));
|
||||
like($explain, qr/Index Scan using idx on tst/);
|
||||
|
||||
for my $i (0 .. $#queries)
|
||||
{
|
||||
my $actual = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET hnsw.ef_search = $ef_search;
|
||||
SET hnsw.streaming = on;
|
||||
SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[$i]' LIMIT $limit;
|
||||
));
|
||||
my @actual_ids = split("\n", $actual);
|
||||
|
||||
my @expected_ids = split("\n", $expected[$i]);
|
||||
my %expected_set = map { $_ => 1 } @expected_ids;
|
||||
|
||||
foreach (@actual_ids)
|
||||
{
|
||||
if (exists($expected_set{$_}))
|
||||
{
|
||||
$correct++;
|
||||
}
|
||||
}
|
||||
|
||||
$total += $limit;
|
||||
}
|
||||
|
||||
cmp_ok($correct / $total, ">=", $min, $operator);
|
||||
}
|
||||
|
||||
# Initialize node
|
||||
$node = PostgreSQL::Test::Cluster->new('node');
|
||||
$node->init;
|
||||
$node->start;
|
||||
|
||||
# Create table
|
||||
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
|
||||
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
|
||||
$node->safe_psql("postgres",
|
||||
"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;"
|
||||
);
|
||||
|
||||
# Generate queries
|
||||
for (1 .. 20)
|
||||
{
|
||||
my @r = ();
|
||||
for (1 .. $dim)
|
||||
{
|
||||
push(@r, rand());
|
||||
}
|
||||
push(@queries, "[" . join(",", @r) . "]");
|
||||
}
|
||||
|
||||
# Check each index type
|
||||
my @operators = ("<->", "<=>");
|
||||
my @opclasses = ("vector_l2_ops", "vector_cosine_ops");
|
||||
|
||||
for my $i (0 .. $#operators)
|
||||
{
|
||||
my $operator = $operators[$i];
|
||||
my $opclass = $opclasses[$i];
|
||||
|
||||
$node->safe_psql("postgres", qq(
|
||||
SET maintenance_work_mem = '128MB';
|
||||
CREATE INDEX idx ON tst USING hnsw (v $opclass);
|
||||
));
|
||||
|
||||
foreach (@cs)
|
||||
{
|
||||
my $c = $_;
|
||||
|
||||
# Get exact results
|
||||
@expected = ();
|
||||
foreach (@queries)
|
||||
{
|
||||
my $res = $node->safe_psql("postgres", qq(
|
||||
SET enable_indexscan = off;
|
||||
WITH top AS (
|
||||
SELECT v $operator '$_' AS distance FROM tst WHERE i % $c = 0 ORDER BY distance LIMIT $limit
|
||||
)
|
||||
SELECT i FROM tst WHERE (v $operator '$_') <= (SELECT MAX(distance) FROM top)
|
||||
));
|
||||
push(@expected, $res);
|
||||
}
|
||||
|
||||
if ($c == 100)
|
||||
{
|
||||
test_recall($c, 40, 0.99, $operator);
|
||||
}
|
||||
else
|
||||
{
|
||||
if ($operator eq "<->")
|
||||
{
|
||||
test_recall($c, 40, 0.99, $operator);
|
||||
}
|
||||
else
|
||||
{
|
||||
test_recall($c, 40, 0.99, $operator);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$node->safe_psql("postgres", "DROP INDEX idx;");
|
||||
}
|
||||
|
||||
done_testing();
|
||||
Reference in New Issue
Block a user