/*
 * Patch overview (reviewer preamble placed before the first diff header; not part of any hunk).
 *
 * Adds iterative index scans to ivfflat:
 *   - IvfflatInit registers two settings: ivfflat.iterative_search (enum off/on, default off)
 *     and ivfflat.iterative_search_max_probes (integer, default 0).
 *   - ivfflatbeginscan computes maxProbes: equal to probes when iterative search is off;
 *     otherwise lists when the max-probes setting is zero, else the smaller of that setting and lists.
 *   - GetScanLists keeps up to maxProbes lists and copies their start pages into so->listPages.
 *   - GetScanItems resets the sort state and searches at most so->probes further lists per call;
 *     ivfflatgettuple calls it again whenever the sorted batch runs out, until so->listIndex
 *     reaches so->maxProbes.
 *   - ivfflatbulkdelete only has its local array renamed from startPages to listPages.
 *   - Two TAP tests are added (041 and 042).
 *
 * NOTE(review): in this copy of the patch the original newlines appear to have been collapsed
 * into spaces; the hunk text below is otherwise left as found.
 */
diff --git a/src/ivfflat.c b/src/ivfflat.c index 395040d..0b24875 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -17,8 +17,16 @@ #endif int ivfflat_probes; +int ivfflat_iterative_search; /* value of ivfflat.iterative_search, one of the IvfflatIterativeSearchType constants */ +int ivfflat_iterative_search_max_probes; /* value of ivfflat.iterative_search_max_probes; zero means use the number of lists */ static relopt_kind ivfflat_relopt_kind; +static const struct config_enum_entry ivfflat_iterative_search_options[] = { /* accepted spellings for ivfflat.iterative_search */ + {"off", IVFFLAT_ITERATIVE_SEARCH_OFF, false}, + {"on", IVFFLAT_ITERATIVE_SEARCH_RELAXED, false}, /* on maps to the relaxed mode, the only non-off mode declared */ + {NULL, 0, false} +}; + /* * Initialize index options and variables */ @@ -33,6 +41,14 @@ IvfflatInit(void) "Valid range is 1..lists.", &ivfflat_probes, IVFFLAT_DEFAULT_PROBES, IVFFLAT_MIN_LISTS, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); + DefineCustomEnumVariable("ivfflat.iterative_search", "Sets whether to use iterative search", + NULL, &ivfflat_iterative_search, + IVFFLAT_ITERATIVE_SEARCH_OFF, ivfflat_iterative_search_options, PGC_USERSET, 0, NULL, NULL, NULL); /* boot value is off */ + + DefineCustomIntVariable("ivfflat.iterative_search_max_probes", "Sets the max number of probes for iterative search", + "Zero sets to the number of lists", &ivfflat_iterative_search_max_probes, + 0, 0, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); /* boot value 0, allowed range 0..IVFFLAT_MAX_LISTS */ + MarkGUCPrefixReserved("ivfflat"); } diff --git a/src/ivfflat.h b/src/ivfflat.h index abf71fe..c2b4121 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -80,6 +80,14 @@ /* Variables */ extern int ivfflat_probes; +extern int ivfflat_iterative_search; +extern int ivfflat_iterative_search_max_probes; + +typedef enum IvfflatIterativeSearchType /* modes selectable through ivfflat.iterative_search */ +{ + IVFFLAT_ITERATIVE_SEARCH_OFF, /* default; ivfflatbeginscan then limits the scan to probes lists */ + IVFFLAT_ITERATIVE_SEARCH_RELAXED +} IvfflatIterativeSearchType; typedef struct VectorArrayData { @@ -248,8 +256,10 @@ typedef struct IvfflatScanOpaqueData { const IvfflatTypeInfo *typeInfo; int probes; + int maxProbes; /* most lists this scan may search in total; computed in ivfflatbeginscan */ int dimensions; bool first; + Datum value; /* query value stored by ivfflatgettuple for later GetScanItems calls */ /* Sorting */ Tuplesortstate *sortstate; @@ -266,6 +276,8 @@ typedef struct IvfflatScanOpaqueData /* Lists */ pairingheap *listQueue; + BlockNumber *listPages; /* start page of each selected list, filled by GetScanLists */ + int 
listIndex; /* index into listPages of the next list to search */ IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */ } IvfflatScanOpaqueData; diff --git a/src/ivfscan.c b/src/ivfscan.c index 74e3675..578f6aa 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -65,7 +65,7 @@ GetScanLists(IndexScanDesc scan, Datum value) /* Use procinfo from the index instead of scan key for performance */ distance = DatumGetFloat8(so->distfunc(so->procinfo, so->collation, PointerGetDatum(&list->center), value)); - if (listCount < so->probes) + if (listCount < so->maxProbes) /* collect up to maxProbes lists instead of probes */ { IvfflatScanList *scanlist; @@ -78,7 +78,7 @@ GetScanLists(IndexScanDesc scan, Datum value) pairingheap_add(so->listQueue, &scanlist->ph_node); /* Calculate max distance */ - if (listCount == so->probes) + if (listCount == so->maxProbes) maxDistance = GetScanList(pairingheap_first(so->listQueue))->distance; } else if (distance < maxDistance) @@ -102,6 +102,11 @@ GetScanLists(IndexScanDesc scan, Datum value) UnlockReleaseBuffer(cbuf); } + + for (int i = listCount - 1; i >= 0; i--) /* drain the queue into listPages, filling from the highest index down */ + so->listPages[i] = GetScanList(pairingheap_remove_first(so->listQueue))->startPage; /* NOTE(review): maxDistance is read from pairingheap_first above, so the queue presumably yields the farthest list first, which would leave listPages ordered nearest-first; CompareLists is not shown here, confirm */ + + Assert(pairingheap_is_empty(so->listQueue)); } /* @@ -114,11 +119,14 @@ GetScanItems(IndexScanDesc scan, Datum value) TupleDesc tupdesc = RelationGetDescr(scan->indexRelation); double tuples = 0; TupleTableSlot *slot = so->vslot; + int batchProbes = 0; + + tuplesort_reset(so->sortstate); /* Search closest probes lists */ - while (!pairingheap_is_empty(so->listQueue)) + while (so->listIndex < so->maxProbes && (++batchProbes) <= so->probes) /* at most probes lists per call, and never past maxProbes overall */ { - BlockNumber searchPage = GetScanList(pairingheap_remove_first(so->listQueue))->startPage; + BlockNumber searchPage = so->listPages[so->listIndex++]; /* take the next unsearched list and advance the cursor */ /* Search all entry pages for list */ while (BlockNumberIsValid(searchPage)) @@ -166,13 +174,17 @@ GetScanItems(IndexScanDesc scan, Datum value) } } - if (tuples < 100) + if (tuples < 100 && ivfflat_iterative_search == IVFFLAT_ITERATIVE_SEARCH_OFF) /* the few-tuples message is only reported when iterative search is off */ ereport(DEBUG1, (errmsg("index scan found few tuples"), 
errdetail("Index may have been created with little data."), errhint("Recreate the index and possibly decrease lists."))); tuplesort_performsort(so->sortstate); + +#if defined(IVFFLAT_MEMORY) + elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024)); +#endif } /* @@ -240,6 +252,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) int lists; int dimensions; int probes = ivfflat_probes; + int maxProbes; scan = RelationGetIndexScan(index, nkeys, norderbys); @@ -249,10 +262,21 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) if (probes > lists) probes = lists; - so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList)); + if (ivfflat_iterative_search != IVFFLAT_ITERATIVE_SEARCH_OFF) + { + if (ivfflat_iterative_search_max_probes == 0) + maxProbes = lists; + else + maxProbes = Min(ivfflat_iterative_search_max_probes, lists); + } + else + maxProbes = probes; + + so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + maxProbes * sizeof(IvfflatScanList)); so->typeInfo = IvfflatGetTypeInfo(index); so->first = true; so->probes = probes; + so->maxProbes = maxProbes; so->dimensions = dimensions; /* Set support functions */ @@ -280,6 +304,8 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) so->bas = GetAccessStrategy(BAS_BULKREAD); so->listQueue = pairingheap_allocate(CompareLists, scan); + so->listPages = palloc(maxProbes * sizeof(BlockNumber)); /* one slot per list this scan may search */ + so->listIndex = 0; scan->opaque = so; @@ -294,11 +320,9 @@ ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; - if (!so->first) - tuplesort_reset(so->sortstate); - so->first = true; pairingheap_reset(so->listQueue); + so->listIndex = 0; /* a rescan starts again at listPages[0]; the sort reset formerly done here now happens in GetScanItems */ if (keys && scan->numberOfKeys > 0) memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData)); @@ -314,6 +338,8 @@ bool ivfflatgettuple(IndexScanDesc scan, 
ScanDirection dir) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; + ItemPointer heaptid; + bool isnull; /* * Index can be used to scan backward, but Postgres doesn't support @@ -341,28 +367,25 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) IvfflatBench("GetScanLists", GetScanLists(scan, value)); IvfflatBench("GetScanItems", GetScanItems(scan, value)); so->first = false; + so->value = value; -#if defined(IVFFLAT_MEMORY) - elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024)); -#endif - - /* Clean up if we allocated a new value */ - if (value != scan->orderByData->sk_argument) - pfree(DatumGetPointer(value)); + /* TODO clean up if we allocated a new value */ /* NOTE(review): the pfree removed above is not replaced in any hunk shown in this patch, while so->value stays referenced for later batches */ } - if (tuplesort_gettupleslot(so->sortstate, true, false, so->mslot, NULL)) + while (!tuplesort_gettupleslot(so->sortstate, true, false, so->mslot, NULL)) /* loop while the current sorted batch has no more tuples */ { - bool isnull; - ItemPointer heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->mslot, 2, &isnull)); + if (so->listIndex == so->maxProbes) + return false; /* all maxProbes lists have already been searched, so the scan is finished */ - scan->xs_heaptid = *heaptid; - scan->xs_recheck = false; - scan->xs_recheckorderby = false; - return true; + IvfflatBench("GetScanItems", GetScanItems(scan, so->value)); /* load and sort the next batch of lists using the saved query value */ } - return false; + heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->mslot, 2, &isnull)); /* attribute 2 of the sorted slot is read as the heap TID */ + + scan->xs_heaptid = *heaptid; + scan->xs_recheck = false; + scan->xs_recheckorderby = false; + return true; } /* @@ -374,6 +397,7 @@ ivfflatendscan(IndexScanDesc scan) IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; pairingheap_free(so->listQueue); + pfree(so->listPages); tuplesort_end(so->sortstate); FreeAccessStrategy(so->bas); FreeTupleDesc(so->tupdesc); diff --git a/src/ivfvacuum.c b/src/ivfvacuum.c index 57815af..1272da8 100644 --- a/src/ivfvacuum.c +++ b/src/ivfvacuum.c @@ -26,7 +26,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, Page cpage; OffsetNumber coffno; OffsetNumber cmaxoffno; - BlockNumber 
startPages[MaxOffsetNumber]; + BlockNumber listPages[MaxOffsetNumber]; ListInfo listInfo; cbuf = ReadBuffer(index, blkno); @@ -40,7 +40,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, { IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno)); - startPages[coffno - FirstOffsetNumber] = list->startPage; + listPages[coffno - FirstOffsetNumber] = list->startPage; } listInfo.blkno = blkno; @@ -50,7 +50,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno)) { - BlockNumber searchPage = startPages[coffno - FirstOffsetNumber]; + BlockNumber searchPage = listPages[coffno - FirstOffsetNumber]; BlockNumber insertPage = InvalidBlockNumber; /* Iterate over entry pages */ diff --git a/test/t/041_ivfflat_iterative_search.pl b/test/t/041_ivfflat_iterative_search.pl new file mode 100644 index 0000000..231c49e --- /dev/null +++ b/test/t/041_ivfflat_iterative_search.pl @@ -0,0 +1,54 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4 PRIMARY KEY, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); + +my $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = 10; + SET ivfflat.iterative_search = on; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +is($count, 10); + +foreach ((30, 
50, 70)) +{ + my $max_probes = $_; + my $expected = $max_probes / 10; + my $sum = 0; + + for my $i (1 .. 20) + { + $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = 10; + SET ivfflat.iterative_search = on; + SET ivfflat.iterative_search_max_probes = $max_probes; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst WHERE i = $i) LIMIT 11) t; + )); + $sum += $count; + } + + my $avg = $sum / 20; + cmp_ok($avg, '>', $expected - 2); + cmp_ok($avg, '<', $expected + 2); +} + +done_testing(); diff --git a/test/t/042_ivfflat_iterative_search_recall.pl b/test/t/042_ivfflat_iterative_search_recall.pl new file mode 100644 index 0000000..6bdddd0 --- /dev/null +++ b/test/t/042_ivfflat_iterative_search_recall.pl @@ -0,0 +1,125 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; +my @cs = (100, 1000); + +sub test_recall +{ + my ($c, $probes, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SET ivfflat.iterative_search = on; + EXPLAIN ANALYZE SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. 
$#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SET ivfflat.iterative_search = on; + SELECT i FROM tst WHERE i % $c = 0 ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + + my @expected_ids = split("\n", $expected[$i]); + my %expected_set = map { $_ => 1 } @expected_ids; + + foreach (@actual_ids) + { + if (exists($expected_set{$_})) + { + $correct++; + } + } + + $total += $limit; + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_cosine_ops"); + +for my $i (0 .. 
$#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING ivfflat (v $opclass);"); + + foreach (@cs) + { + my $c = $_; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + WITH top AS ( + SELECT v $operator '$_' AS distance FROM tst WHERE i % $c = 0 ORDER BY distance LIMIT $limit + ) + SELECT i FROM tst WHERE (v $operator '$_') <= (SELECT MAX(distance) FROM top) + )); + push(@expected, $res); + } + + if ($c == 100) + { + test_recall($c, 1, 0.58, $operator); + test_recall($c, 10, 0.98, $operator); + } + else + { + if ($operator eq "<->") + { + test_recall($c, 1, 0.80, $operator); + } + else + { + test_recall($c, 1, 0.88, $operator); + } + } + } + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing();