Added iterative search for HNSW [skip ci]

This commit is contained in:
Andrew Kane
2024-10-10 18:14:39 -07:00
parent c91ed7b2c3
commit 961cb17d80
9 changed files with 457 additions and 30 deletions

View File

@@ -1,5 +1,6 @@
## 0.8.0 (unreleased)
- Added support for iterative index scans
- Added casts for arrays to `sparsevec`
- Improved cost estimation
- Improved performance of HNSW inserts and on-disk index builds

View File

@@ -18,7 +18,16 @@
#define MarkGUCPrefixReserved(x) EmitWarningsOnPlaceholders(x)
#endif
static const struct config_enum_entry hnsw_iterative_search_options[] = {
{"off", HNSW_ITERATIVE_SEARCH_OFF, false},
{"on", HNSW_ITERATIVE_SEARCH_RELAXED, false},
{"strict", HNSW_ITERATIVE_SEARCH_STRICT, false},
{NULL, 0, false}
};
int hnsw_ef_search;
int hnsw_iterative_search_max_tuples;
int hnsw_iterative_search;
int hnsw_lock_tranche_id;
static relopt_kind hnsw_relopt_kind;
@@ -69,6 +78,15 @@ HnswInit(void)
"Valid range is 1..1000.", &hnsw_ef_search,
HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
DefineCustomEnumVariable("hnsw.iterative_search", "Sets iterative search",
NULL, &hnsw_iterative_search,
HNSW_ITERATIVE_SEARCH_OFF, hnsw_iterative_search_options, PGC_USERSET, 0, NULL, NULL, NULL);
/* TODO Ensure ivfflat.max_probes uses same value for "all" */
DefineCustomIntVariable("hnsw.iterative_search_max_tuples", "Sets the max number of candidates to visit for iterative search",
"-1 means all", &hnsw_iterative_search_max_tuples,
-1, -1, INT_MAX, PGC_USERSET, 0, NULL, NULL, NULL);
MarkGUCPrefixReserved("hnsw");
}

View File

@@ -109,8 +109,17 @@
/* Variables */
extern int hnsw_ef_search;
extern int hnsw_iterative_search;
extern int hnsw_iterative_search_max_tuples;
extern int hnsw_lock_tranche_id;
typedef enum HnswIterativeSearchType
{
HNSW_ITERATIVE_SEARCH_OFF,
HNSW_ITERATIVE_SEARCH_RELAXED,
HNSW_ITERATIVE_SEARCH_STRICT
} HnswIterativeSearchType;
typedef struct HnswElementData HnswElementData;
typedef struct HnswNeighborArray HnswNeighborArray;
@@ -132,6 +141,7 @@ struct HnswElementData
uint8 heaptidsLength;
uint8 level;
uint8 deleted;
uint8 version;
uint32 hash;
HnswNeighborsPtr neighbors;
BlockNumber blkno;
@@ -319,10 +329,10 @@ typedef struct HnswElementTupleData
uint8 type;
uint8 level;
uint8 deleted;
uint8 unused;
uint8 version;
ItemPointerData heaptids[HNSW_HEAPTIDS];
ItemPointerData neighbortid;
uint16 unused2;
uint16 unused;
Vector data;
} HnswElementTupleData;
@@ -331,7 +341,7 @@ typedef HnswElementTupleData * HnswElementTuple;
typedef struct HnswNeighborTupleData
{
uint8 type;
uint8 unused;
uint8 version;
uint16 count;
ItemPointerData indextids[FLEXIBLE_ARRAY_MEMBER];
} HnswNeighborTupleData;
@@ -356,6 +366,12 @@ typedef struct HnswScanOpaqueData
const HnswTypeInfo *typeInfo;
bool first;
List *w;
visited_hash v;
pairingheap *discarded;
HnswQuery q;
int m;
int64 tuples;
double previousDistance;
MemoryContext tmpCtx;
/* Support functions */
@@ -399,7 +415,7 @@ bool HnswCheckNorm(HnswSupport * support, Datum value);
Buffer HnswNewBuffer(Relation index, ForkNumber forkNum);
void HnswInitPage(Buffer buf, Page page);
void HnswInit(void);
List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement);
List *HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples);
HnswElement HnswGetEntryPoint(Relation index);
void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint);
void *HnswAlloc(HnswAllocator * allocator, Size size);

View File

@@ -36,7 +36,7 @@ GetInsertPage(Relation index)
* Check for a free offset
*/
static bool
HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage)
HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size etupSize, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage, uint8 *tupleVersion)
{
OffsetNumber offno;
OffsetNumber maxoffno = PageGetMaxOffsetNumber(page);
@@ -98,6 +98,7 @@ HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size
{
*freeOffno = offno;
*freeNeighborOffno = neighborOffno;
*tupleVersion = etup->version;
return true;
}
else if (*nbuf != buf)
@@ -153,6 +154,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B
OffsetNumber freeOffno = InvalidOffsetNumber;
OffsetNumber freeNeighborOffno = InvalidOffsetNumber;
BlockNumber newInsertPage = InvalidBlockNumber;
uint8 tupleVersion;
char *base = NULL;
/* Calculate sizes */
@@ -202,7 +204,7 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B
}
/* Next, try space from a deleted element */
if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage))
if (HnswFreeOffset(index, buf, page, e, etupSize, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage, &tupleVersion))
{
if (nbuf != buf)
{
@@ -212,6 +214,10 @@ AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, B
npage = GenericXLogRegisterBuffer(state, nbuf, 0);
}
/* Set tuple version */
etup->version = tupleVersion;
ntup->version = tupleVersion;
break;
}

View File

@@ -1,5 +1,7 @@
#include "postgres.h"
#include <math.h>
#include "access/relscan.h"
#include "hnsw.h"
#include "pgstat.h"
@@ -21,25 +23,57 @@ GetScanItems(IndexScanDesc scan, Datum value)
int m;
HnswElement entryPoint;
char *base = NULL;
HnswQuery q;
q.value = value;
HnswQuery *q = &so->q;
/* Get m and entry point */
HnswGetMetaPageInfo(index, &m, &entryPoint);
q->value = value;
so->m = m;
if (entryPoint == NULL)
return NIL;
ep = list_make1(HnswEntryCandidate(base, entryPoint, &q, index, support, false));
ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, support, false));
for (int lc = entryPoint->level; lc >= 1; lc--)
{
w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, false, NULL);
w = HnswSearchLayer(base, q, ep, 1, lc, index, support, m, false, NULL, NULL, NULL, true, NULL);
ep = w;
}
return HnswSearchLayer(base, &q, ep, hnsw_ef_search, 0, index, support, m, false, NULL);
return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, support, m, false, NULL, &so->v, hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF ? &so->discarded : NULL, true, &so->tuples);
}
/*
* Resume scan at ground level with discarded candidates
*/
static List *
ResumeScanItems(IndexScanDesc scan)
{
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
Relation index = scan->indexRelation;
List *ep = NIL;
char *base = NULL;
int batch_size = hnsw_ef_search;
if (pairingheap_is_empty(so->discarded))
return NIL;
/* Get next batch of candidates */
for (int i = 0; i < batch_size; i++)
{
HnswSearchCandidate *sc;
if (pairingheap_is_empty(so->discarded))
break;
sc = HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded));
ep = lappend(ep, sc);
}
return HnswSearchLayer(base, &so->q, ep, batch_size, 0, index, &so->support, so->m, false, NULL, &so->v, &so->discarded, false, &so->tuples);
}
/*
@@ -83,6 +117,8 @@ hnswbeginscan(Relation index, int nkeys, int norderbys)
so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData));
so->typeInfo = HnswGetTypeInfo(index);
so->first = true;
so->v.tids = NULL;
so->discarded = NULL;
so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
"Hnsw scan temporary context",
ALLOCSET_DEFAULT_SIZES);
@@ -103,7 +139,15 @@ hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int no
{
HnswScanOpaque so = (HnswScanOpaque) scan->opaque;
if (so->v.tids != NULL)
tidhash_reset(so->v.tids);
if (so->discarded != NULL)
pairingheap_reset(so->discarded);
so->first = true;
so->tuples = 0;
so->previousDistance = -INFINITY;
MemoryContextReset(so->tmpCtx);
if (keys && scan->numberOfKeys > 0)
@@ -165,22 +209,100 @@ hnswgettuple(IndexScanDesc scan, ScanDirection dir)
#endif
}
while (list_length(so->w) > 0)
for (;;)
{
char *base = NULL;
HnswSearchCandidate *sc = llast(so->w);
HnswElement element = HnswPtrAccess(base, sc->element);
HnswSearchCandidate *sc;
HnswElement element;
ItemPointer heaptid;
if (list_length(so->w) == 0)
{
if (hnsw_iterative_search == HNSW_ITERATIVE_SEARCH_OFF)
break;
/* Empty index */
if (so->discarded == NULL)
break;
/* Reached max number of additional tuples */
if (hnsw_iterative_search_max_tuples != -1 && so->tuples >= hnsw_iterative_search_max_tuples)
{
if (pairingheap_is_empty(so->discarded))
break;
/* Return remaining tuples */
so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded)));
}
/* Prevent scans from consuming too much memory */
else if (MemoryContextMemAllocated(so->tmpCtx, false) > (Size) work_mem * 1024L)
{
if (pairingheap_is_empty(so->discarded))
{
ereport(DEBUG1,
(errmsg("hnsw index scan exceeded work_mem after " INT64_FORMAT " tuples", so->tuples),
errhint("Increase work_mem to scan more tuples.")));
break;
}
/* Return remaining tuples */
so->w = lappend(so->w, HnswGetSearchCandidate(w_node, pairingheap_remove_first(so->discarded)));
}
else
{
/*
* Locking ensures when neighbors are read, the elements they
* reference will not be deleted (and replaced) during the
* iteration.
*
* Elements loaded into memory on previous iterations may have
* been deleted (and replaced), so when reading neighbors, the
* element version must be checked.
*/
LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
so->w = ResumeScanItems(scan);
UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock);
#if defined(HNSW_MEMORY)
elog(INFO, "memory: %zu KB", MemoryContextMemAllocated(so->tmpCtx, false) / 1024);
#endif
}
if (list_length(so->w) == 0)
break;
}
sc = llast(so->w);
element = HnswPtrAccess(base, sc->element);
/* Move to next element if no valid heap TIDs */
if (element->heaptidsLength == 0)
{
so->w = list_delete_last(so->w);
/* Mark memory as free for next iteration */
if (hnsw_iterative_search != HNSW_ITERATIVE_SEARCH_OFF)
{
pfree(element);
pfree(sc);
}
continue;
}
heaptid = &element->heaptids[--element->heaptidsLength];
if (hnsw_iterative_search == HNSW_ITERATIVE_SEARCH_STRICT)
{
if (sc->distance < so->previousDistance)
continue;
so->previousDistance = sc->distance;
}
MemoryContextSwitchTo(oldCtx);
scan->xs_heaptid = *heaptid;

View File

@@ -251,6 +251,8 @@ HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel,
element->level = level;
element->deleted = 0;
/* Start at one to make it easier to find issues */
element->version = 1;
HnswInitNeighbors(base, element, m, allocator);
@@ -430,6 +432,7 @@ HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element)
etup->type = HNSW_ELEMENT_TUPLE_TYPE;
etup->level = element->level;
etup->deleted = 0;
etup->version = element->version;
for (int i = 0; i < HNSW_HEAPTIDS; i++)
{
if (i < element->heaptidsLength)
@@ -472,6 +475,7 @@ HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m)
}
ntup->count = idx;
ntup->version = e->version;
}
/*
@@ -482,6 +486,7 @@ HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHe
{
element->level = etup->level;
element->deleted = etup->deleted;
element->version = etup->version;
element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid);
element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid);
element->heaptidsLength = 0;
@@ -608,6 +613,21 @@ CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, v
return 0;
}
/*
* Compare discarded candidate distances
*/
static int
CompareNearestDiscardedCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg)
{
if (HnswGetSearchCandidateConst(w_node, a)->distance < HnswGetSearchCandidateConst(w_node, b)->distance)
return 1;
if (HnswGetSearchCandidateConst(w_node, a)->distance > HnswGetSearchCandidateConst(w_node, b)->distance)
return -1;
return 0;
}
/*
* Compare candidate distances
*/
@@ -728,8 +748,11 @@ HnswLoadNeighborTids(HnswElement element, ItemPointerData *indextids, Relation i
ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno));
/* Ensure expected neighbors */
if (ntup->count != (element->level + 2) * m)
/*
* Ensure the neighbor tuple has not been deleted or replaced between
* index scan iterations
*/
if (ntup->version != element->version || ntup->count != (element->level + 2) * m)
{
UnlockReleaseBuffer(buf);
return false;
@@ -775,13 +798,13 @@ HnswLoadUnvisitedFromDisk(HnswElement element, HnswUnvisited * unvisited, int *u
* Algorithm 2 from paper
*/
List *
HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement)
HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation index, HnswSupport * support, int m, bool inserting, HnswElement skipElement, visited_hash * v, pairingheap **discarded, bool initVisited, int64 *tuples)
{
List *w = NIL;
pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL);
pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL);
int wlen = 0;
visited_hash v;
visited_hash vh;
ListCell *lc2;
HnswNeighborArray *localNeighborhood = NULL;
Size neighborhoodSize = 0;
@@ -790,7 +813,19 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in
int unvisitedLength;
bool inMemory = index == NULL;
InitVisited(base, &v, inMemory, ef, m);
if (v == NULL)
{
v = &vh;
initVisited = true;
}
if (initVisited)
{
InitVisited(base, v, inMemory, ef, m);
if (discarded != NULL)
*discarded = pairingheap_allocate(CompareNearestDiscardedCandidates, NULL);
}
/* Create local memory for neighborhood if needed */
if (inMemory)
@@ -805,7 +840,13 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in
HnswSearchCandidate *sc = (HnswSearchCandidate *) lfirst(lc2);
bool found;
AddToVisited(base, &v, sc->element, inMemory, &found);
if (initVisited)
{
AddToVisited(base, v, sc->element, inMemory, &found);
if (tuples != NULL)
(*tuples)++;
}
pairingheap_add(C, &sc->c_node);
pairingheap_add(W, &sc->w_node);
@@ -831,9 +872,12 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in
cElement = HnswPtrAccess(base, c->element);
if (inMemory)
HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, &v, lc, localNeighborhood, neighborhoodSize);
HnswLoadUnvisitedFromMemory(base, cElement, unvisited, &unvisitedLength, v, lc, localNeighborhood, neighborhoodSize);
else
HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, &v, index, m, lm, lc);
HnswLoadUnvisitedFromDisk(cElement, unvisited, &unvisitedLength, v, index, m, lm, lc);
if (tuples != NULL)
(*tuples) += unvisitedLength;
for (int i = 0; i < unvisitedLength; i++)
{
@@ -857,16 +901,25 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in
/* Avoid any allocations if not adding */
eElement = NULL;
HnswLoadElementImpl(blkno, offno, &eDistance, q, index, support, inserting, alwaysAdd ? NULL : &f->distance, &eElement);
HnswLoadElementImpl(blkno, offno, &eDistance, q, index, support, inserting, alwaysAdd || discarded != NULL ? NULL : &f->distance, &eElement);
if (eElement == NULL)
continue;
}
if (!(eDistance < f->distance || alwaysAdd))
continue;
if (eElement == NULL || !(eDistance < f->distance || alwaysAdd))
{
if (discarded != NULL)
{
/* Create a new candidate */
e = palloc(sizeof(HnswSearchCandidate));
HnswPtrStore(base, e->element, eElement);
e->distance = eDistance;
pairingheap_add(*discarded, &e->w_node);
}
Assert(!eElement->deleted);
continue;
}
/* Make robust to issues */
if (eElement->level < lc)
@@ -890,7 +943,12 @@ HnswSearchLayer(char *base, HnswQuery * q, List *ep, int ef, int lc, Relation in
/* No need to decrement wlen */
if (wlen > ef)
pairingheap_remove_first(W);
{
HnswSearchCandidate *d = HnswGetSearchCandidate(w_node, pairingheap_remove_first(W));
if (discarded != NULL)
pairingheap_add(*discarded, &d->w_node);
}
}
}
}
@@ -1225,7 +1283,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
/* 1st phase: greedy search to insert level */
for (int lc = entryLevel; lc >= level + 1; lc--)
{
w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement);
w = HnswSearchLayer(base, &q, ep, 1, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL);
ep = w;
}
@@ -1244,7 +1302,7 @@ HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint
List *lw = NIL;
ListCell *lc2;
w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement);
w = HnswSearchLayer(base, &q, ep, efConstruction, lc, index, support, m, true, skipElement, NULL, NULL, true, NULL);
/* Convert search candidates to candidates */
foreach(lc2, w)

View File

@@ -527,6 +527,14 @@ MarkDeleted(HnswVacuumState * vacuumstate)
for (int i = 0; i < ntup->count; i++)
ItemPointerSetInvalid(&ntup->indextids[i]);
/* Increment version */
/* This is used to avoid incorrect reads for iterative scans */
/* Reserve some bits for future use */
etup->version++;
if (etup->version > 15)
etup->version = 1;
ntup->version = etup->version;
/*
* We modified the tuples in place, no need to call
* PageIndexTupleOverwrite

View File

@@ -0,0 +1,67 @@
use strict;
use warnings FATAL => 'all';
use PostgreSQL::Test::Cluster;
use PostgreSQL::Test::Utils;
use Test::More;
my $dim = 3;
my $array_sql = join(",", ('random()') x $dim);
# Initialize node
my $node = PostgreSQL::Test::Cluster->new('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4 PRIMARY KEY, v vector($dim));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;"
);
$node->safe_psql("postgres", qq(
SET maintenance_work_mem = '128MB';
SET max_parallel_maintenance_workers = 2;
CREATE INDEX ON tst USING hnsw (v vector_l2_ops)
));
my $count = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.iterative_search = on;
SET work_mem = '8MB';
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t;
));
is($count, 10);
foreach ((30000, 50000, 70000))
{
my $max_tuples = $_;
my $expected = $max_tuples / 10000;
my $sum = 0;
for my $i (1 .. 20)
{
$count = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.iterative_search = on;
SET hnsw.iterative_search_max_tuples = $max_tuples;
SET work_mem = '8MB';
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst WHERE i = $i) LIMIT 11) t;
));
$sum += $count;
}
my $avg = $sum / 20;
cmp_ok($avg, '>', $expected - 2);
cmp_ok($avg, '<', $expected + 2);
}
my ($ret, $stdout, $stderr) = $node->psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.iterative_search = on;
SET client_min_messages = debug1;
SET work_mem = '2MB';
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t;
));
like($stderr, qr/hnsw index scan exceeded work_mem after \d+ tuples/);
done_testing();

View File

@@ -0,0 +1,131 @@
use strict;
use warnings FATAL => 'all';
use PostgreSQL::Test::Cluster;
use PostgreSQL::Test::Utils;
use Test::More;
my $node;
my @queries = ();
my @expected;
my $limit = 20;
my $dim = 3;
my $array_sql = join(",", ('random()') x $dim);
my @cs = (100, 1000);
sub test_recall
{
my ($c, $ef_search, $min, $operator) = @_;
my $correct = 0;
my $total = 0;
my $explain = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.ef_search = $ef_search;
SET hnsw.iterative_search = on;
EXPLAIN ANALYZE SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[0]' LIMIT $limit;
));
like($explain, qr/Index Scan using idx on tst/);
for my $i (0 .. $#queries)
{
my $actual = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET hnsw.ef_search = $ef_search;
SET hnsw.iterative_search = on;
SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[$i]' LIMIT $limit;
));
my @actual_ids = split("\n", $actual);
my @expected_ids = split("\n", $expected[$i]);
my %expected_set = map { $_ => 1 } @expected_ids;
foreach (@actual_ids)
{
if (exists($expected_set{$_}))
{
$correct++;
}
}
$total += $limit;
}
cmp_ok($correct / $total, ">=", $min, $operator);
}
# Initialize node
$node = PostgreSQL::Test::Cluster->new('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;"
);
# Generate queries
for (1 .. 20)
{
my @r = ();
for (1 .. $dim)
{
push(@r, rand());
}
push(@queries, "[" . join(",", @r) . "]");
}
# Check each index type
my @operators = ("<->", "<=>");
my @opclasses = ("vector_l2_ops", "vector_cosine_ops");
for my $i (0 .. $#operators)
{
my $operator = $operators[$i];
my $opclass = $opclasses[$i];
$node->safe_psql("postgres", qq(
SET maintenance_work_mem = '128MB';
CREATE INDEX idx ON tst USING hnsw (v $opclass);
));
foreach (@cs)
{
my $c = $_;
# Get exact results
@expected = ();
foreach (@queries)
{
my $res = $node->safe_psql("postgres", qq(
SET enable_indexscan = off;
WITH top AS (
SELECT v $operator '$_' AS distance FROM tst WHERE i % $c = 0 ORDER BY distance LIMIT $limit
)
SELECT i FROM tst WHERE (v $operator '$_') <= (SELECT MAX(distance) FROM top)
));
push(@expected, $res);
}
if ($c == 100)
{
test_recall($c, 40, 0.99, $operator);
}
else
{
if ($operator eq "<->")
{
test_recall($c, 40, 0.99, $operator);
}
else
{
test_recall($c, 40, 0.99, $operator);
}
}
}
$node->safe_psql("postgres", "DROP INDEX idx;");
}
done_testing();