mirror of
https://github.com/pgvector/pgvector.git
synced 2026-06-06 05:51:21 +08:00
Added iterative search for IVFFlat [skip ci]
This commit is contained in:
@@ -17,8 +17,16 @@
|
||||
#endif
|
||||
|
||||
int ivfflat_probes;
|
||||
int ivfflat_iterative_search;
|
||||
int ivfflat_iterative_search_max_probes;
|
||||
static relopt_kind ivfflat_relopt_kind;
|
||||
|
||||
/*
 * Accepted values for the ivfflat.iterative_search setting (this table is
 * the one handed to DefineCustomEnumVariable for that setting).
 *
 * "off" selects IVFFLAT_ITERATIVE_SEARCH_OFF and "on" selects
 * IVFFLAT_ITERATIVE_SEARCH_RELAXED. The last entry, with a NULL name, ends
 * the list.
 */
static const struct config_enum_entry ivfflat_iterative_search_options[] = {
|
||||
{"off", IVFFLAT_ITERATIVE_SEARCH_OFF, false},
|
||||
{"on", IVFFLAT_ITERATIVE_SEARCH_RELAXED, false},
|
||||
{NULL, 0, false}
|
||||
};
|
||||
|
||||
/*
|
||||
* Initialize index options and variables
|
||||
*/
|
||||
@@ -33,6 +41,14 @@ IvfflatInit(void)
|
||||
"Valid range is 1..lists.", &ivfflat_probes,
|
||||
IVFFLAT_DEFAULT_PROBES, IVFFLAT_MIN_LISTS, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL);
|
||||
|
||||
DefineCustomEnumVariable("ivfflat.iterative_search", "Sets whether to use iterative search",
|
||||
NULL, &ivfflat_iterative_search,
|
||||
IVFFLAT_ITERATIVE_SEARCH_OFF, ivfflat_iterative_search_options, PGC_USERSET, 0, NULL, NULL, NULL);
|
||||
|
||||
DefineCustomIntVariable("ivfflat.iterative_search_max_probes", "Sets the max number of probes for iterative search",
|
||||
"Zero sets to the number of lists", &ivfflat_iterative_search_max_probes,
|
||||
0, 0, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL);
|
||||
|
||||
MarkGUCPrefixReserved("ivfflat");
|
||||
}
|
||||
|
||||
|
||||
@@ -80,6 +80,14 @@
|
||||
|
||||
/* Variables */
|
||||
extern int ivfflat_probes;
|
||||
extern int ivfflat_iterative_search;
|
||||
extern int ivfflat_iterative_search_max_probes;
|
||||
|
||||
/*
 * Values stored in ivfflat_iterative_search.
 *
 * IVFFLAT_ITERATIVE_SEARCH_OFF is the value the setting is registered with
 * as its default. Scan code compares against OFF to decide whether iterative
 * search is active; RELAXED is currently the only "active" mode.
 */
typedef enum IvfflatIterativeSearchType
|
||||
{
|
||||
IVFFLAT_ITERATIVE_SEARCH_OFF,
|
||||
IVFFLAT_ITERATIVE_SEARCH_RELAXED
|
||||
} IvfflatIterativeSearchType;
|
||||
|
||||
typedef struct VectorArrayData
|
||||
{
|
||||
@@ -248,8 +256,10 @@ typedef struct IvfflatScanOpaqueData
|
||||
{
|
||||
const IvfflatTypeInfo *typeInfo;
|
||||
int probes;
|
||||
int maxProbes;
|
||||
int dimensions;
|
||||
bool first;
|
||||
Datum value;
|
||||
|
||||
/* Sorting */
|
||||
Tuplesortstate *sortstate;
|
||||
@@ -266,6 +276,8 @@ typedef struct IvfflatScanOpaqueData
|
||||
|
||||
/* Lists */
|
||||
pairingheap *listQueue;
|
||||
BlockNumber *listPages;
|
||||
int listIndex;
|
||||
IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */
|
||||
} IvfflatScanOpaqueData;
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ GetScanLists(IndexScanDesc scan, Datum value)
|
||||
/* Use procinfo from the index instead of scan key for performance */
|
||||
distance = DatumGetFloat8(so->distfunc(so->procinfo, so->collation, PointerGetDatum(&list->center), value));
|
||||
|
||||
if (listCount < so->probes)
|
||||
if (listCount < so->maxProbes)
|
||||
{
|
||||
IvfflatScanList *scanlist;
|
||||
|
||||
@@ -78,7 +78,7 @@ GetScanLists(IndexScanDesc scan, Datum value)
|
||||
pairingheap_add(so->listQueue, &scanlist->ph_node);
|
||||
|
||||
/* Calculate max distance */
|
||||
if (listCount == so->probes)
|
||||
if (listCount == so->maxProbes)
|
||||
maxDistance = GetScanList(pairingheap_first(so->listQueue))->distance;
|
||||
}
|
||||
else if (distance < maxDistance)
|
||||
@@ -102,6 +102,11 @@ GetScanLists(IndexScanDesc scan, Datum value)
|
||||
|
||||
UnlockReleaseBuffer(cbuf);
|
||||
}
|
||||
|
||||
for (int i = listCount - 1; i >= 0; i--)
|
||||
so->listPages[i] = GetScanList(pairingheap_remove_first(so->listQueue))->startPage;
|
||||
|
||||
Assert(pairingheap_is_empty(so->listQueue));
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -114,11 +119,14 @@ GetScanItems(IndexScanDesc scan, Datum value)
|
||||
TupleDesc tupdesc = RelationGetDescr(scan->indexRelation);
|
||||
double tuples = 0;
|
||||
TupleTableSlot *slot = so->vslot;
|
||||
int batchProbes = 0;
|
||||
|
||||
tuplesort_reset(so->sortstate);
|
||||
|
||||
/* Search closest probes lists */
|
||||
while (!pairingheap_is_empty(so->listQueue))
|
||||
while (so->listIndex < so->maxProbes && (++batchProbes) <= so->probes)
|
||||
{
|
||||
BlockNumber searchPage = GetScanList(pairingheap_remove_first(so->listQueue))->startPage;
|
||||
BlockNumber searchPage = so->listPages[so->listIndex++];
|
||||
|
||||
/* Search all entry pages for list */
|
||||
while (BlockNumberIsValid(searchPage))
|
||||
@@ -166,13 +174,17 @@ GetScanItems(IndexScanDesc scan, Datum value)
|
||||
}
|
||||
}
|
||||
|
||||
if (tuples < 100)
|
||||
if (tuples < 100 && ivfflat_iterative_search == IVFFLAT_ITERATIVE_SEARCH_OFF)
|
||||
ereport(DEBUG1,
|
||||
(errmsg("index scan found few tuples"),
|
||||
errdetail("Index may have been created with little data."),
|
||||
errhint("Recreate the index and possibly decrease lists.")));
|
||||
|
||||
tuplesort_performsort(so->sortstate);
|
||||
|
||||
#if defined(IVFFLAT_MEMORY)
|
||||
elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024));
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -240,6 +252,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
|
||||
int lists;
|
||||
int dimensions;
|
||||
int probes = ivfflat_probes;
|
||||
int maxProbes;
|
||||
|
||||
scan = RelationGetIndexScan(index, nkeys, norderbys);
|
||||
|
||||
@@ -249,10 +262,21 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
|
||||
if (probes > lists)
|
||||
probes = lists;
|
||||
|
||||
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList));
|
||||
if (ivfflat_iterative_search != IVFFLAT_ITERATIVE_SEARCH_OFF)
|
||||
{
|
||||
if (ivfflat_iterative_search_max_probes == 0)
|
||||
maxProbes = lists;
|
||||
else
|
||||
maxProbes = Min(ivfflat_iterative_search_max_probes, lists);
|
||||
}
|
||||
else
|
||||
maxProbes = probes;
|
||||
|
||||
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + maxProbes * sizeof(IvfflatScanList));
|
||||
so->typeInfo = IvfflatGetTypeInfo(index);
|
||||
so->first = true;
|
||||
so->probes = probes;
|
||||
so->maxProbes = maxProbes;
|
||||
so->dimensions = dimensions;
|
||||
|
||||
/* Set support functions */
|
||||
@@ -280,6 +304,8 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
|
||||
so->bas = GetAccessStrategy(BAS_BULKREAD);
|
||||
|
||||
so->listQueue = pairingheap_allocate(CompareLists, scan);
|
||||
so->listPages = palloc(maxProbes * sizeof(BlockNumber));
|
||||
so->listIndex = 0;
|
||||
|
||||
scan->opaque = so;
|
||||
|
||||
@@ -294,11 +320,9 @@ ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int
|
||||
{
|
||||
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
|
||||
|
||||
if (!so->first)
|
||||
tuplesort_reset(so->sortstate);
|
||||
|
||||
so->first = true;
|
||||
pairingheap_reset(so->listQueue);
|
||||
so->listIndex = 0;
|
||||
|
||||
if (keys && scan->numberOfKeys > 0)
|
||||
memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData));
|
||||
@@ -314,6 +338,8 @@ bool
|
||||
ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
|
||||
{
|
||||
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
|
||||
ItemPointer heaptid;
|
||||
bool isnull;
|
||||
|
||||
/*
|
||||
* Index can be used to scan backward, but Postgres doesn't support
|
||||
@@ -341,28 +367,25 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
|
||||
IvfflatBench("GetScanLists", GetScanLists(scan, value));
|
||||
IvfflatBench("GetScanItems", GetScanItems(scan, value));
|
||||
so->first = false;
|
||||
so->value = value;
|
||||
|
||||
#if defined(IVFFLAT_MEMORY)
|
||||
elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024));
|
||||
#endif
|
||||
|
||||
/* Clean up if we allocated a new value */
|
||||
if (value != scan->orderByData->sk_argument)
|
||||
pfree(DatumGetPointer(value));
|
||||
/* TODO clean up if we allocated a new value */
|
||||
}
|
||||
|
||||
if (tuplesort_gettupleslot(so->sortstate, true, false, so->mslot, NULL))
|
||||
while (!tuplesort_gettupleslot(so->sortstate, true, false, so->mslot, NULL))
|
||||
{
|
||||
bool isnull;
|
||||
ItemPointer heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->mslot, 2, &isnull));
|
||||
if (so->listIndex == so->maxProbes)
|
||||
return false;
|
||||
|
||||
scan->xs_heaptid = *heaptid;
|
||||
scan->xs_recheck = false;
|
||||
scan->xs_recheckorderby = false;
|
||||
return true;
|
||||
IvfflatBench("GetScanItems", GetScanItems(scan, so->value));
|
||||
}
|
||||
|
||||
return false;
|
||||
heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->mslot, 2, &isnull));
|
||||
|
||||
scan->xs_heaptid = *heaptid;
|
||||
scan->xs_recheck = false;
|
||||
scan->xs_recheckorderby = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -374,6 +397,7 @@ ivfflatendscan(IndexScanDesc scan)
|
||||
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
|
||||
|
||||
pairingheap_free(so->listQueue);
|
||||
pfree(so->listPages);
|
||||
tuplesort_end(so->sortstate);
|
||||
FreeAccessStrategy(so->bas);
|
||||
FreeTupleDesc(so->tupdesc);
|
||||
|
||||
@@ -26,7 +26,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats,
|
||||
Page cpage;
|
||||
OffsetNumber coffno;
|
||||
OffsetNumber cmaxoffno;
|
||||
BlockNumber startPages[MaxOffsetNumber];
|
||||
BlockNumber listPages[MaxOffsetNumber];
|
||||
ListInfo listInfo;
|
||||
|
||||
cbuf = ReadBuffer(index, blkno);
|
||||
@@ -40,7 +40,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats,
|
||||
{
|
||||
IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno));
|
||||
|
||||
startPages[coffno - FirstOffsetNumber] = list->startPage;
|
||||
listPages[coffno - FirstOffsetNumber] = list->startPage;
|
||||
}
|
||||
|
||||
listInfo.blkno = blkno;
|
||||
@@ -50,7 +50,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats,
|
||||
|
||||
for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno))
|
||||
{
|
||||
BlockNumber searchPage = startPages[coffno - FirstOffsetNumber];
|
||||
BlockNumber searchPage = listPages[coffno - FirstOffsetNumber];
|
||||
BlockNumber insertPage = InvalidBlockNumber;
|
||||
|
||||
/* Iterate over entry pages */
|
||||
|
||||
54
test/t/041_ivfflat_iterative_search.pl
Normal file
54
test/t/041_ivfflat_iterative_search.pl
Normal file
@@ -0,0 +1,54 @@
|
||||
use strict;
|
||||
use warnings FATAL => 'all';
|
||||
use PostgreSQL::Test::Cluster;
|
||||
use PostgreSQL::Test::Utils;
|
||||
use Test::More;
|
||||
|
||||
# Vector dimensionality. $array_sql expands to that many comma-separated
# random() calls and is spliced into the INSERT below.
my $dim = 3;
|
||||
my $array_sql = join(",", ('random()') x $dim);
|
||||
|
||||
# Initialize node
|
||||
my $node = PostgreSQL::Test::Cluster->new('node');
|
||||
$node->init;
|
||||
$node->start;
|
||||
|
||||
# Create table
|
||||
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
|
||||
$node->safe_psql("postgres", "CREATE TABLE tst (i int4 PRIMARY KEY, v vector($dim));");
|
||||
$node->safe_psql("postgres",
|
||||
"INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;"
|
||||
);
|
||||
$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);");
|
||||
|
||||
# Only 10 of the 100000 rows satisfy i % 10000 = 0. With iterative search on
# and no explicit iterative_search_max_probes, the LIMIT 11 query is expected
# to return all 10 of them.
my $count = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET ivfflat.probes = 10;
|
||||
SET ivfflat.iterative_search = on;
|
||||
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t;
|
||||
));
|
||||
is($count, 10);
|
||||
|
||||
# With max probes capped, the average number of matching rows found over 20
# query vectors (rows i = 1 .. 20) must land within +/- 2 of max_probes / 10.
# NOTE(review): the / 10 presumably reflects 10 matching rows spread over 100
# lists in the index -- confirm against the index's default lists value.
foreach ((30, 50, 70))
|
||||
{
|
||||
my $max_probes = $_;
|
||||
my $expected = $max_probes / 10;
|
||||
my $sum = 0;
|
||||
|
||||
for my $i (1 .. 20)
|
||||
{
|
||||
$count = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET ivfflat.probes = 10;
|
||||
SET ivfflat.iterative_search = on;
|
||||
SET ivfflat.iterative_search_max_probes = $max_probes;
|
||||
SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst WHERE i = $i) LIMIT 11) t;
|
||||
));
|
||||
$sum += $count;
|
||||
}
|
||||
|
||||
my $avg = $sum / 20;
|
||||
cmp_ok($avg, '>', $expected - 2);
|
||||
cmp_ok($avg, '<', $expected + 2);
|
||||
}
|
||||
|
||||
done_testing();
|
||||
125
test/t/042_ivfflat_iterative_search_recall.pl
Normal file
125
test/t/042_ivfflat_iterative_search_recall.pl
Normal file
@@ -0,0 +1,125 @@
|
||||
use strict;
|
||||
use warnings FATAL => 'all';
|
||||
use PostgreSQL::Test::Cluster;
|
||||
use PostgreSQL::Test::Utils;
|
||||
use Test::More;
|
||||
|
||||
# Test cluster handle; assigned later, before the table is created.
my $node;
|
||||
# Query vectors as text literals of the form "[x,y,z]".
my @queries = ();
|
||||
# Exact-search results, one newline-separated id list per entry of @queries.
my @expected;
|
||||
# LIMIT applied to both the exact and the index queries.
my $limit = 20;
|
||||
# Filter moduli: each run keeps only rows where i % c = 0.
my @cs = (100, 1000);
|
||||
|
||||
# Measures recall of filtered index queries with iterative search enabled.
#
# Arguments:
#   $c        - filter modulus; queries keep rows where i % $c = 0
#   $probes   - value used for ivfflat.probes
#   $min      - minimum acceptable recall (fraction between 0 and 1)
#   $operator - distance operator used in ORDER BY; also used as the test name
#
# Reads the file-level @queries, @expected, $limit and $node. Emits two test
# results: one that the plan uses the index, one that recall is >= $min.
sub test_recall
|
||||
{
|
||||
my ($c, $probes, $min, $operator) = @_;
|
||||
my $correct = 0;
|
||||
my $total = 0;
|
||||
|
||||
# Check that the planner really picks the index for the filtered query.
my $explain = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET ivfflat.probes = $probes;
|
||||
SET ivfflat.iterative_search = on;
|
||||
EXPLAIN ANALYZE SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[0]' LIMIT $limit;
|
||||
));
|
||||
like($explain, qr/Index Scan using idx on tst/);
|
||||
|
||||
# Run each query through the index and count ids that are also in the
# exact result for the same query.
for my $i (0 .. $#queries)
|
||||
{
|
||||
my $actual = $node->safe_psql("postgres", qq(
|
||||
SET enable_seqscan = off;
|
||||
SET ivfflat.probes = $probes;
|
||||
SET ivfflat.iterative_search = on;
|
||||
SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[$i]' LIMIT $limit;
|
||||
));
|
||||
my @actual_ids = split("\n", $actual);
|
||||
|
||||
my @expected_ids = split("\n", $expected[$i]);
|
||||
my %expected_set = map { $_ => 1 } @expected_ids;
|
||||
|
||||
foreach (@actual_ids)
|
||||
{
|
||||
if (exists($expected_set{$_}))
|
||||
{
|
||||
$correct++;
|
||||
}
|
||||
}
|
||||
|
||||
# The denominator grows by $limit per query, even if the index query
# returned fewer rows than that.
$total += $limit;
|
||||
}
|
||||
|
||||
cmp_ok($correct / $total, ">=", $min, $operator);
|
||||
}
|
||||
|
||||
# Initialize node
|
||||
$node = PostgreSQL::Test::Cluster->new('node');
|
||||
$node->init;
|
||||
$node->start;
|
||||
|
||||
# Create table
|
||||
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
|
||||
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));");
|
||||
$node->safe_psql("postgres",
|
||||
"INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;"
|
||||
);
|
||||
|
||||
# Generate queries
# (20 random 3-dimensional vectors, stored as "[x,y,z]" literals)
|
||||
for (1 .. 20)
|
||||
{
|
||||
my $r1 = rand();
|
||||
my $r2 = rand();
|
||||
my $r3 = rand();
|
||||
push(@queries, "[$r1,$r2,$r3]");
|
||||
}
|
||||
|
||||
# Check each index type
# (@operators and @opclasses are parallel arrays: entry $i of one pairs with
# entry $i of the other)
|
||||
my @operators = ("<->", "<=>");
|
||||
my @opclasses = ("vector_l2_ops", "vector_cosine_ops");
|
||||
|
||||
for my $i (0 .. $#operators)
|
||||
{
|
||||
my $operator = $operators[$i];
|
||||
my $opclass = $opclasses[$i];
|
||||
|
||||
$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING ivfflat (v $opclass);");
|
||||
|
||||
foreach (@cs)
|
||||
{
|
||||
my $c = $_;
|
||||
|
||||
# Get exact results
# The baseline is computed without the index. It takes the distance of the
# $limit-th nearest row that passes the i % $c filter, then returns every
# row at or within that distance. The outer SELECT is not filtered, so the
# expected set can also contain rows that fail the filter; test_recall only
# checks membership of the ids the index query returns.
|
||||
@expected = ();
|
||||
foreach (@queries)
|
||||
{
|
||||
my $res = $node->safe_psql("postgres", qq(
|
||||
SET enable_indexscan = off;
|
||||
WITH top AS (
|
||||
SELECT v $operator '$_' AS distance FROM tst WHERE i % $c = 0 ORDER BY distance LIMIT $limit
|
||||
)
|
||||
SELECT i FROM tst WHERE (v $operator '$_') <= (SELECT MAX(distance) FROM top)
|
||||
));
|
||||
push(@expected, $res);
|
||||
}
|
||||
|
||||
# Minimum recall thresholds depend on the filter modulus, the probes
# setting and, for the sparser filter, the operator.
if ($c == 100)
|
||||
{
|
||||
test_recall($c, 1, 0.58, $operator);
|
||||
test_recall($c, 10, 0.98, $operator);
|
||||
}
|
||||
else
|
||||
{
|
||||
if ($operator eq "<->")
|
||||
{
|
||||
test_recall($c, 1, 0.80, $operator);
|
||||
}
|
||||
else
|
||||
{
|
||||
test_recall($c, 1, 0.88, $operator);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Drop the index so the next opclass can reuse the name idx.
$node->safe_psql("postgres", "DROP INDEX idx;");
|
||||
}
|
||||
|
||||
done_testing();
|
||||
Reference in New Issue
Block a user