mirror of
https://github.com/pgvector/pgvector.git
synced 2026-07-01 10:11:20 +08:00
Added streaming option for HNSW [skip ci]
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
#endif
|
||||
|
||||
int hnsw_ef_search;
|
||||
bool hnsw_streaming;
|
||||
int hnsw_lock_tranche_id;
|
||||
static relopt_kind hnsw_relopt_kind;
|
||||
|
||||
@@ -75,6 +76,10 @@ HnswInit(void)
|
||||
"Valid range is 1..1000.", &hnsw_ef_search,
|
||||
HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
|
||||
|
||||
DefineCustomBoolVariable("hnsw.streaming", "todo",
|
||||
"todo", &hnsw_streaming,
|
||||
HNSW_DEFAULT_STREAMING, PGC_USERSET, 0, NULL, NULL, NULL);
|
||||
|
||||
MarkGUCPrefixReserved("hnsw");
|
||||
}
|
||||
|
||||
|
||||
15
src/hnsw.h
15
src/hnsw.h
@@ -42,6 +42,7 @@
|
||||
#define HNSW_DEFAULT_EF_SEARCH 40
|
||||
#define HNSW_MIN_EF_SEARCH 1
|
||||
#define HNSW_MAX_EF_SEARCH 1000
|
||||
#define HNSW_DEFAULT_STREAMING false
|
||||
|
||||
/* Tuple types */
|
||||
#define HNSW_ELEMENT_TUPLE_TYPE 1
|
||||
@@ -111,6 +112,7 @@
|
||||
|
||||
/* Variables */
|
||||
extern int hnsw_ef_search;
|
||||
extern bool hnsw_streaming;
|
||||
extern int hnsw_lock_tranche_id;
|
||||
|
||||
typedef struct HnswElementData HnswElementData;
|
||||
@@ -328,11 +330,22 @@ typedef struct HnswNeighborTupleData
|
||||
|
||||
typedef HnswNeighborTupleData * HnswNeighborTuple;
|
||||
|
||||
typedef union
|
||||
{
|
||||
struct pointerhash_hash *pointers;
|
||||
struct offsethash_hash *offsets;
|
||||
struct tidhash_hash *tids;
|
||||
} visited_hash;
|
||||
|
||||
typedef struct HnswScanOpaqueData
|
||||
{
|
||||
const HnswTypeInfo *typeInfo;
|
||||
bool first;
|
||||
List *w;
|
||||
visited_hash v;
|
||||
List *discarded;
|
||||
Datum q;
|
||||
int m;
|
||||
MemoryContext tmpCtx;
|
||||
|
||||
/* Support functions */
|
||||
@@ -378,7 +391,7 @@ bool HnswCheckNorm(FmgrInfo *procinfo, Oid collation, Datum value);
|
||||
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
|
||||
void HnswInitPage(Buffer buf, Page page);
|
||||
void HnswInit(void);
|
||||
List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement);
|
||||
List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, List **discarded, bool initVisited);
|
||||
HnswElement HnswGetEntryPoint(Relation index);
|
||||
void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint);
|
||||
void *HnswAlloc(HnswAllocator * allocator, Size size);
|
||||
|
||||
@@ -26,6 +26,9 @@ GetScanItems(IndexScanDesc scan, Datum q)
|
||||
/* Get m and entry point */
|
||||
HnswGetMetaPageInfo(index, &m, &entryPoint);
|
||||
|
||||
so->q = q;
|
||||
so->m = m;
|
||||
|
||||
if (entryPoint == NULL)
|
||||
return NIL;
|
||||
|
||||
@@ -33,11 +36,32 @@ GetScanItems(IndexScanDesc scan, Datum q)
|
||||
|
||||
for (int lc = entryPoint->level; lc >= 1; lc--)
|
||||
{
|
||||
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL);
|
||||
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL, NULL, NULL, true);
|
||||
ep = w;
|
||||
}
|
||||
|
||||
return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL);
|
||||
return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL, &so->v, &so->discarded, true);
|
||||
}
|
||||
|
||||
/*
|
||||
* Resume scan at ground level with discarded candidates
|
||||
*/
|
||||
static List *
|
||||
ResumeScanItems(IndexScanDesc scan)
|
||||
{
|
||||
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
|
||||
Relation index = scan->indexRelation;
|
||||
FmgrInfo *procinfo = so->procinfo;
|
||||
Oid collation = so->collation;
|
||||
List *ep;
|
||||
char *base = NULL;
|
||||
|
||||
if (list_length(so->discarded) == 0)
|
||||
return NIL;
|
||||
|
||||
ep = so->discarded;
|
||||
so->discarded = NIL;
|
||||
return HnswSearchLayer(base, so->q, ep, hnsw_ef_search, 0, index, procinfo, collation, so->m, false, NULL, &so->v, &so->discarded, false);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -103,7 +127,10 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no
|
||||
{
|
||||
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
|
||||
|
||||
if (!so->first)
|
||||
tidhash_reset(so->v.tids);
|
||||
so->first = true;
|
||||
so->discarded = NIL;
|
||||
MemoryContextReset(so->tmpCtx);
|
||||
|
||||
if (keys && scan->numberOfKeys > 0)
|
||||
@@ -165,13 +192,36 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
|
||||
#endif
|
||||
}
|
||||
|
||||
while (list_length(so->w) > 0)
|
||||
for (;;)
|
||||
{
|
||||
char *base = NULL;
|
||||
HnswCandidate *hc = llast(so->w);
|
||||
HnswElement element = HnswPtrAccess(base, hc->element);
|
||||
HnswCandidate *hc;
|
||||
HnswElement element;
|
||||
ItemPointer heaptid;
|
||||
|
||||
if (list_length(so->w) == 0)
|
||||
{
|
||||
if (!hnsw_streaming)
|
||||
break;
|
||||
|
||||
/* TODO figure out locking */
|
||||
LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
|
||||
|
||||
so->w = ResumeScanItems(scan);
|
||||
|
||||
UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
|
||||
|
||||
#if defined(HNSW_MEMORY) && PG_VERSION_NUM >= 130000
|
||||
elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(so->tmpCtx, false) / (1024 * 1024));
|
||||
#endif
|
||||
|
||||
if (list_length(so->w) == 0)
|
||||
break;
|
||||
}
|
||||
|
||||
hc = llast(so->w);
|
||||
element = HnswPtrAccess(base, hc->element);
|
||||
|
||||
/* Move to next element if no valid heap TIDs */
|
||||
if (element->heaptidsLength == 0)
|
||||
{
|
||||
|
||||
@@ -105,13 +105,6 @@ hash_offset(Size offset)
|
||||
#define SH_DEFINE
|
||||
#include "lib/simplehash.h"
|
||||
|
||||
typedef union
|
||||
{
|
||||
pointerhash_hash *pointers;
|
||||
offsethash_hash *offsets;
|
||||
tidhash_hash *tids;
|
||||
} visited_hash;
|
||||
|
||||
/*
|
||||
* Get the max number of connections in an upper layer for each element in the index
|
||||
*/
|
||||
@@ -721,18 +714,22 @@ CountElement(char *base, HnswElement skipElement, HnswCandidate * hc)
|
||||
* Algorithm 2 from paper
|
||||
*/
|
||||
List *
|
||||
HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement)
|
||||
HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement, visited_hash * v, List **discarded, bool initVisited)
|
||||
{
|
||||
List *w = NIL;
|
||||
pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
|
||||
pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL);
|
||||
int wlen = 0;
|
||||
visited_hash v;
|
||||
visited_hash v2;
|
||||
ListCell *lc2;
|
||||
HnswNeighborArray *neighborhoodData = NULL;
|
||||
Size neighborhoodSize = 0;
|
||||
|
||||
InitVisited(base, &v, index, ef, m);
|
||||
if (v == NULL)
|
||||
v = &v2;
|
||||
|
||||
if (initVisited)
|
||||
InitVisited(base, v, index, ef, m);
|
||||
|
||||
/* Create local memory for neighborhood if needed */
|
||||
if (index == NULL)
|
||||
@@ -747,7 +744,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
HnswCandidate *hc = (HnswCandidate *) lfirst(lc2);
|
||||
bool found;
|
||||
|
||||
AddToVisited(base, &v, hc, index, &found);
|
||||
AddToVisited(base, v, hc, index, &found);
|
||||
|
||||
pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node));
|
||||
pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node));
|
||||
@@ -793,7 +790,7 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
HnswCandidate *e = &neighborhood->items[i];
|
||||
bool visited;
|
||||
|
||||
AddToVisited(base, &v, e, index, &visited);
|
||||
AddToVisited(base, v, e, index, &visited);
|
||||
|
||||
if (!visited)
|
||||
{
|
||||
@@ -806,13 +803,14 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
if (index == NULL)
|
||||
eDistance = GetCandidateDistance(base, e, q, procinfo, collation);
|
||||
else
|
||||
HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd ? NULL : &f->distance);
|
||||
HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance);
|
||||
|
||||
if (eDistance < f->distance || alwaysAdd)
|
||||
{
|
||||
HnswCandidate *ec;
|
||||
|
||||
Assert(!eElement->deleted);
|
||||
Assert(eElement->level >= lc);
|
||||
|
||||
/* Make robust to issues */
|
||||
if (eElement->level < lc)
|
||||
@@ -837,9 +835,25 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
|
||||
|
||||
/* No need to decrement wlen */
|
||||
if (wlen > ef)
|
||||
pairingheap_remove_first(W);
|
||||
{
|
||||
HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner;
|
||||
|
||||
if (discarded != NULL)
|
||||
*discarded = lappend(*discarded, hc);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (discarded != NULL)
|
||||
{
|
||||
HnswCandidate *ec;
|
||||
|
||||
/* Copy e */
|
||||
ec = palloc(sizeof(HnswCandidate));
|
||||
HnswPtrStore(base, ec->element, eElement);
|
||||
ec->distance = eDistance;
|
||||
|
||||
*discarded = lappend(*discarded, ec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1230,7 +1244,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
|
||||
/* 1st phase: greedy search to insert level */
|
||||
for (int lc = entryLevel; lc >= level + 1; lc--)
|
||||
{
|
||||
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement);
|
||||
w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
|
||||
ep = w;
|
||||
}
|
||||
|
||||
@@ -1248,7 +1262,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
|
||||
List *neighbors;
|
||||
List *lw;
|
||||
|
||||
w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement);
|
||||
w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement, NULL, NULL, true);
|
||||
|
||||
/* Elements being deleted or skipped can help with search */
|
||||
/* but should be removed before selecting neighbors */
|
||||
|
||||
30
test/t/039_hnsw_streaming.pl
Normal file
30
test/t/039_hnsw_streaming.pl
Normal file
@@ -0,0 +1,30 @@
|
||||
use strict;
|
||||
use warnings FATAL => 'all';
|
||||
use PostgreSQL::Test::Cluster;
|
||||
use PostgreSQL::Test::Utils;
|
||||
use Test::More;
|
||||
|
||||
my $dim = 3;
|
||||
my $array_sql = join(",", ('random()') x $dim);
|
||||
|
||||
# Initialize node
|
||||
my $node = PostgreSQL::Test::Cluster->new('node');
|
||||
$node->init;
|
||||
$node->start;
|
||||
|
||||
# Create table
|
||||
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
|
||||
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
|
||||
$node->safe_psql("postgres",
|
||||
"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;"
|
||||
);
|
||||
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");
|
||||
|
||||
my $count = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET hnsw.streaming = on;
|
||||
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 1000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t;
|
||||
));
|
||||
is($count, 10);
|
||||
|
||||
done_testing();
|
||||
Reference in New Issue
Block a user