diff --git a/src/ivfflat.c b/src/ivfflat.c index 4e9b9a4..a8efb75 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -17,6 +17,7 @@ #endif int ivfflat_probes; +bool ivfflat_streaming; static relopt_kind ivfflat_relopt_kind; /* @@ -33,6 +34,10 @@ IvfflatInit(void) "Valid range is 1..lists.", &ivfflat_probes, IVFFLAT_DEFAULT_PROBES, IVFFLAT_MIN_LISTS, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); + DefineCustomBoolVariable("ivfflat.streaming", "Use streaming mode", + NULL, &ivfflat_streaming, + IVFFLAT_DEFAULT_STREAMING, PGC_USERSET, 0, NULL, NULL, NULL); + MarkGUCPrefixReserved("ivfflat"); } diff --git a/src/ivfflat.h b/src/ivfflat.h index 10b4d91..54774aa 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -43,6 +43,7 @@ #define IVFFLAT_MIN_LISTS 1 #define IVFFLAT_MAX_LISTS 32768 #define IVFFLAT_DEFAULT_PROBES 1 +#define IVFFLAT_DEFAULT_STREAMING false /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ @@ -80,6 +81,7 @@ /* Variables */ extern int ivfflat_probes; +extern bool ivfflat_streaming; typedef struct VectorArrayData { @@ -247,8 +249,10 @@ typedef struct IvfflatScanOpaqueData { const IvfflatTypeInfo *typeInfo; int probes; + int maxProbes; /* cap on lists kept in listQueue: every list when ivfflat.streaming is on, otherwise probes */ int dimensions; bool first; + Datum value; /* query value saved by the first ivfflatgettuple call and reused for later GetScanItems batches */ /* Sorting */ Tuplesortstate *sortstate; diff --git a/src/ivfscan.c b/src/ivfscan.c index 1e95cd6..6c06989 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -30,6 +30,21 @@ CompareLists(const pairingheap_node *a, const pairingheap_node *b, void *arg) return 0; } +/* + * Compare list distances for streaming: returns -1 when a is farther than b, 1 when a is closer, 0 when the distances are equal + */ +static int +CompareListsStreaming(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const IvfflatScanList *) a)->distance > ((const IvfflatScanList *) b)->distance) + return -1; + + if (((const IvfflatScanList *) a)->distance < ((const IvfflatScanList *) b)->distance) + return 1; + + return 0; +} + /* * Get lists and sort by distance */ @@ -62,7 +77,7 @@ GetScanLists(IndexScanDesc scan, Datum value) /* Use procinfo from the 
index instead of scan key for performance */ distance = DatumGetFloat8(so->distfunc(so->procinfo, so->collation, PointerGetDatum(&list->center), value)); - if (listCount < so->probes) + if (listCount < so->maxProbes) { IvfflatScanList *scanlist; @@ -111,6 +126,7 @@ GetScanItems(IndexScanDesc scan, Datum value) TupleDesc tupdesc = RelationGetDescr(scan->indexRelation); double tuples = 0; TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsVirtual); + int batchProbes = 0; /* lists taken from the queue during this call; the loop below stops once it exceeds so->probes */ /* * Reuse same set of shared buffers for scan @@ -119,8 +135,10 @@ GetScanItems(IndexScanDesc scan, Datum value) */ BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD); + tuplesort_reset(so->sortstate); /* reset the sort state at the start of every batch; ivfflatrescan no longer does this */ + /* Search closest probes lists */ - while (!pairingheap_is_empty(so->listQueue)) + while (!pairingheap_is_empty(so->listQueue) && (++batchProbes) <= so->probes) { BlockNumber searchPage = ((IvfflatScanList *) pairingheap_remove_first(so->listQueue))->startPage; @@ -172,13 +190,17 @@ GetScanItems(IndexScanDesc scan, Datum value) FreeAccessStrategy(bas); - if (tuples < 100) + if (tuples < 100 && !ivfflat_streaming) ereport(DEBUG1, (errmsg("index scan found few tuples"), errdetail("Index may have been created with little data."), errhint("Recreate the index and possibly decrease lists."))); tuplesort_performsort(so->sortstate); + +#if defined(IVFFLAT_MEMORY) + elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024)); +#endif } /* @@ -246,6 +268,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) int lists; int dimensions; int probes = ivfflat_probes; + int maxProbes; scan = RelationGetIndexScan(index, nkeys, norderbys); @@ -255,10 +278,13 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) if (probes > lists) probes = lists; - so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList)); + maxProbes = ivfflat_streaming ? 
lists : probes; + + so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + maxProbes * sizeof(IvfflatScanList)); so->typeInfo = IvfflatGetTypeInfo(index); so->first = true; so->probes = probes; + so->maxProbes = maxProbes; so->dimensions = dimensions; /* Set support functions */ @@ -276,7 +302,11 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) so->slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsMinimalTuple); - so->listQueue = pairingheap_allocate(CompareLists, scan); + /* Order by closest list for streaming, so each GetScanItems batch takes the nearest remaining lists. NOTE(review): relies on pairingheap returning the comparator's greatest element first - confirm */ + if (ivfflat_streaming) + so->listQueue = pairingheap_allocate(CompareListsStreaming, scan); + else + so->listQueue = pairingheap_allocate(CompareLists, scan); scan->opaque = so; @@ -291,9 +321,6 @@ ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; - if (!so->first) - tuplesort_reset(so->sortstate); - so->first = true; pairingheap_reset(so->listQueue); @@ -311,6 +338,7 @@ bool ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; + ItemPointer heaptid; /* * Index can be used to scan backward, but Postgres doesn't support @@ -338,27 +366,25 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) IvfflatBench("GetScanLists", GetScanLists(scan, value)); IvfflatBench("GetScanItems", GetScanItems(scan, value)); so->first = false; + so->value = value; -#if defined(IVFFLAT_MEMORY) - elog(INFO, "memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024)); -#endif - - /* Clean up if we allocated a new value */ - if (value != scan->orderByData->sk_argument) - pfree(DatumGetPointer(value)); + /* TODO clean up if we allocated a new value: it is now kept in so->value for later GetScanItems calls, so it cannot be freed here and nothing in this patch frees it (NOTE(review): release it on rescan and at end of scan) */ } - if (tuplesort_gettupleslot(so->sortstate, true, false, so->slot, NULL)) + while (!tuplesort_gettupleslot(so->sortstate, true, false, so->slot, NULL)) { - ItemPointer heaptid = (ItemPointer) 
DatumGetPointer(slot_getattr(so->slot, 2, &so->isnull)); + if (pairingheap_is_empty(so->listQueue)) + return false; /* current batch is drained and no lists remain in the queue */ - scan->xs_heaptid = *heaptid; - scan->xs_recheck = false; - scan->xs_recheckorderby = false; - return true; + IvfflatBench("GetScanItems", GetScanItems(scan, so->value)); /* load the next batch of lists using the saved query value */ } - return false; + heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->slot, 2, &so->isnull)); + + scan->xs_heaptid = *heaptid; + scan->xs_recheck = false; + scan->xs_recheckorderby = false; + return true; } /* diff --git a/test/t/039_ivfflat_streaming.pl b/test/t/039_ivfflat_streaming.pl new file mode 100644 index 0000000..b570498 --- /dev/null +++ b/test/t/039_ivfflat_streaming.pl @@ -0,0 +1,31 @@ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $dim = 3; +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = PostgreSQL::Test::Cluster->new('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); + +my $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = 10; + SET ivfflat.streaming = on; + SELECT COUNT(*) FROM (SELECT v FROM tst WHERE i % 10000 = 0 ORDER BY v <-> (SELECT v FROM tst LIMIT 1) LIMIT 11) t; +)); +is($count, 10); + +done_testing();