Added streaming option for HNSW [skip ci]

This commit is contained in:
Andrew Kane
2024-09-18 14:55:58 -07:00
parent a1b80faa67
commit af1727775d
5 changed files with 134 additions and 22 deletions

View File

@@ -17,6 +17,7 @@
#endif
/* Search-time ef: index scans pass this as the ef argument of the layer-0 HnswSearchLayer call */
int hnsw_ef_search;
/* Backing variable for the hnsw.streaming setting registered in HnswInit; */
/* when false, hnswgettuple stops once its current result list is empty */
bool hnsw_streaming;
int hnsw_lock_tranche_id;
static relopt_kind hnsw_relopt_kind;
@@ -75,6 +76,10 @@ HnswInit(void)
"Valid range is 1..1000.", &hnsw_ef_search,
HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
DefineCustomBoolVariable("hnsw.streaming", "todo",
"todo", &hnsw_streaming,
HNSW_DEFAULT_STREAMING, PGC_USERSET, 0, NULL, NULL, NULL);
MarkGUCPrefixReserved("hnsw");
}

View File

@@ -42,6 +42,7 @@
/* Default, minimum, and maximum passed to the custom variable backed by hnsw_ef_search */
#define HNSW_DEFAULT_EF_SEARCH 40
#define HNSW_MIN_EF_SEARCH 1
#define HNSW_MAX_EF_SEARCH 1000
/* Boot value of the hnsw.streaming setting (streaming is off unless enabled) */
#define HNSW_DEFAULT_STREAMING false
/* Tuple types */
#define HNSW_ELEMENT_TUPLE_TYPE 1
@@ -111,6 +112,7 @@
/* Variables */
/* ef used for the layer-0 search of an index scan */
extern int hnsw_ef_search;
/* True when scans should resume from discarded candidates after running out of results */
extern bool hnsw_streaming;
extern int hnsw_lock_tranche_id;
typedef struct HnswElementData HnswElementData;
@@ -328,11 +330,22 @@ typedef struct HnswNeighborTupleData
typedef HnswNeighborTupleData * HnswNeighborTuple;
/*
 * Handle to the visited set used by HnswSearchLayer
 *
 * Holds a pointer to one of three hash table types. The members are declared
 * through struct tags, so this definition does not need the full hash table
 * definitions. Index scans keep one of these in their scan state and reset it
 * through the tids member on rescan.
 *
 * NOTE(review): which member is active for a given search is not visible
 * here - presumably decided by InitVisited; confirm before relying on it.
 */
typedef union
{
struct pointerhash_hash *pointers;
struct offsethash_hash *offsets;
struct tidhash_hash *tids;
} visited_hash;
typedef struct HnswScanOpaqueData
{
const HnswTypeInfo *typeInfo;
bool first;
List *w;
visited_hash v;
List *discarded;
Datum q;
int m;
MemoryContext tmpCtx;
/* Support functions */
@@ -378,7 +391,7 @@ bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);
void HnswInit(void);
List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement);
List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, List **discarded, bool initVisited);
HnswElement HnswGetEntryPoint(Relation index);
void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint);
void *HnswAlloc(HnswAllocator * allocator, Size size);

View File

@@ -26,6 +26,9 @@ GetScanItems(IndexScanDesc scan, Datum q)
/* Get m and entry point */
HnswGetMetaPageInfo(index, &m, &entryPoint);
so->q = q;
so->m = m;
if (entryPoint == NULL)
return NIL;
@@ -33,11 +36,32 @@ GetScanItems(IndexScanDesc scan, Datum q)
for (int lc = entryPoint->level; lc >= 1; lc--)
{
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL);
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL, NULL, NULL, true);
ep = w;
}
return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL);
return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, &so->v, &so->discarded, true);
}
/*
 * Continue the layer 0 search of a scan, using the candidates discarded by
 * earlier searches as the new entry points
 *
 * Returns NIL when no discarded candidates are left. Otherwise the scan's
 * discarded list is emptied first, so the search can collect a fresh set of
 * discarded candidates, and the existing visited set is reused as is.
 */
static List *
ResumeScanItems(IndexScanDesc scan)
{
	HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
	List	   *entryPoints = so->discarded;

	/* Nothing left to resume from */
	if (list_length(entryPoints) == 0)
		return NIL;

	/* The old list now serves as the entry points; start a new one */
	so->discarded = NIL;

	/* No base pointer for on-disk searches; do not reinitialize the visited set */
	return HnswSearchLayer(NULL, so->q, entryPoints, hnsw_ef_search, 0,
						   scan->indexRelation, so->procinfo, so->collation,
						   so->m, false, NULL, &so->v, &so->discarded, false);
}
/*
@@ -103,7 +127,10 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no
{
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
if (!so->first)
tidhash_reset(so->v.tids);
so->first = true;
so->discarded = NIL;
MemoryContextReset(so->tmpCtx);
if (keys && scan->numberOfKeys > 0)
@@ -165,13 +192,36 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
#endif
}
while (list_length(so->w) > 0)
for (;;)
{
char *base = NULL;
HnswCandidate *hc = llast(so->w);
HnswElement element = HnswPtrAccess(base, hc->element);
HnswCandidate *hc;
HnswElement element;
ItemPointer heaptid;
if (list_length(so->w) == 0)
{
if (!hnsw_streaming)
break;
/* TODO figure out locking */
LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
so->w = ResumeScanItems(scan);
UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
#if defined(HNSW_MEMORY) && PG_VERSION_NUM >= 130000
elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(so->tmpCtx, false) / (1024 * 1024));
#endif
if (list_length(so->w) == 0)
break;
}
hc = llast(so->w);
element = HnswPtrAccess(base, hc->element);
/* Move to next element if no valid heap TIDs */
if (element->heaptidsLength == 0)
{

View File

@@ -105,13 +105,6 @@ hash_offset(Size offset)
#define SH_DEFINE
#include "lib/simplehash.h"
/*
 * Handle to the visited set used by HnswSearchLayer: a pointer to one of the
 * three hash table types generated above
 *
 * NOTE(review): judging by the member names the tables are keyed by pointer,
 * offset, and TID respectively - confirm against InitVisited/AddToVisited.
 */
typedef union
{
pointerhash_hash *pointers;
offsethash_hash *offsets;
tidhash_hash *tids;
} visited_hash;
/*
* Get the max number of connections in an upper layer for each element in the index
*/
@@ -721,18 +714,22 @@ CountElement(char *base, HnswElement skipElement, HnswCandidate * hc)
* Algorithm 2 from paper
*/
List *
HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement)
HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, List **discarded, bool initVisited)
{
List *w = NIL;
pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL);
int wlen = 0;
visited_hash v;
visited_hash v2;
ListCell *lc2;
HnswNeighborArray *neighborhoodData = NULL;
Size neighborhoodSize = 0;
InitVisited(base, &v, index, ef, m);
if (v == NULL)
v = &v2;
if (initVisited)
InitVisited(base, v, index, ef, m);
/* Create local memory for neighborhood if needed */
if (index == NULL)
@@ -747,7 +744,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
HnswCandidate *hc = (HnswCandidate *) lfirst(lc2);
bool found;
AddToVisited(base, &v, hc, index, &found);
AddToVisited(base, v, hc, index, &found);
pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node));
pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node));
@@ -793,7 +790,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
HnswCandidate *e = &neighborhood->items[i];
bool visited;
AddToVisited(base, &v, e, index, &visited);
AddToVisited(base, v, e, index, &visited);
if (!visited)
{
@@ -806,13 +803,14 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
if (index == NULL)
eDistance = GetCandidateDistance(base, e, q, procinfo, collation);
else
HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance);
HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance);
if (eDistance < f->distance || alwaysAdd)
{
HnswCandidate *ec;
Assert(!eElement->deleted);
Assert(eElement->level >= lc);
/* Make robust to issues */
if (eElement->level < lc)
@@ -837,9 +835,25 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
/* No need to decrement wlen */
if (wlen > ef)
pairingheap_remove_first(W);
{
HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner;
if (discarded != NULL)
*discarded = lappend(*discarded, hc);
}
}
}
else if (discarded != NULL)
{
HnswCandidate *ec;
/* Copy e */
ec = palloc(sizeof(HnswCandidate));
HnswPtrStore(base, ec->element, eElement);
ec->distance = eDistance;
*discarded = lappend(*discarded, ec);
}
}
}
}
@@ -1230,7 +1244,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
/* 1st phase: greedy search to insert level */
for (int lc = entryLevel; lc >= level + 1; lc--)
{
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement);
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
ep = w;
}
@@ -1248,7 +1262,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
List *neighbors;
List *lw;
w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement);
w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
/* Elements being deleted or skipped can help with search */
/* but should be removed before selecting neighbors */

View File

@@ -0,0 +1,30 @@
use strict;
use warnings FATAL => 'all';
use PostgreSQL::Test::Cluster;
use PostgreSQL::Test::Utils;
use Test::More;
# Vector dimensions, and one random() call per dimension for the INSERT below
my $dim = 3;
my $array_sql = join(",", ('random()') x $dim);

# Bring up a fresh cluster
my $pg = PostgreSQL::Test::Cluster->new('node');
$pg->init;
$pg->start;

# Install the extension, create and fill the table, then build the HNSW index
my @setup = (
	"CREATE EXTENSION vector;",
	"CREATE TABLE tst (i int4, v vector($dim));",
	"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;",
	"CREATE INDEX ON tst USING hnsw (v vector_l2_ops);",
);
foreach my $statement (@setup)
{
	$pg->safe_psql("postgres", $statement);
}

# Only 10 of the 10000 rows satisfy i % 1000 = 0, and the LIMIT asks for 11,
# so the count is 10 only if the scan returns every matching row
my $query = qq(
SET enable_seqscan = off;
SET hnsw.streaming = on;
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 1000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t;
);
my $matched = $pg->safe_psql("postgres", $query);
is($matched, 10);

done_testing();