diff --git a/CHANGELOG.md b/CHANGELOG.md index 29896b1..fc1abdf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.2.6 (unreleased) - Switched to mini-batch k-means +- Improved performance of index creation for Postgres < 12 ## 0.2.5 (2022-02-11) diff --git a/src/ivfbuild.c b/src/ivfbuild.c index 15538c6..69806bb 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -36,16 +36,11 @@ #define CALLBACK_ITEM_POINTER HeapTuple hup #endif -/* - * Update build phase progress - */ -static inline void -UpdateProgress(int index, int64 val) -{ #if PG_VERSION_NUM >= 120000 - pgstat_progress_update_param(index, val); +#define UpdateProgress(index, val) pgstat_progress_update_param(index, val) +#else +#define UpdateProgress(index, val) ((void)val) #endif -} /* * Callback for table_index_build_scan @@ -91,18 +86,18 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, #ifdef IVFFLAT_KMEANS_DEBUG buildstate->inertia += minDistance; + buildstate->listSums[closestCenter] += minDistance; + buildstate->listCounts[closestCenter]++; #endif /* Create a virtual tuple */ ExecClearTuple(slot); slot->tts_values[0] = Int32GetDatum(closestCenter); slot->tts_isnull[0] = false; - slot->tts_values[1] = Int32GetDatum(ItemPointerGetBlockNumberNoCheck(tid)); + slot->tts_values[1] = PointerGetDatum(tid); slot->tts_isnull[1] = false; - slot->tts_values[2] = Int32GetDatum(ItemPointerGetOffsetNumberNoCheck(tid)); + slot->tts_values[2] = value; slot->tts_isnull[2] = false; - slot->tts_values[3] = value; - slot->tts_isnull[3] = false; ExecStoreVirtualTuple(slot); /* @@ -124,8 +119,6 @@ GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, { Datum value; bool isnull; - int tupblk; - int tupoff; #if PG_VERSION_NUM >= 100000 if (tuplesort_gettupleslot(sortstate, true, false, slot, NULL)) @@ -134,13 +127,11 @@ GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, #endif { *list = DatumGetInt32(slot_getattr(slot, 1, &isnull)); - tupblk = DatumGetInt32(slot_getattr(slot, 2, &isnull)); - tupoff = DatumGetInt32(slot_getattr(slot, 3, &isnull)); - value = slot_getattr(slot, 4, &isnull); + value = slot_getattr(slot, 3, &isnull); /* Form the index tuple */ *itup = index_form_tuple(tupdesc, &value, &isnull); - ItemPointerSet(&(*itup)->t_tid, tupblk, tupoff); + (*itup)->t_tid = *((ItemPointer) DatumGetPointer(slot_getattr(slot, 2, &isnull))); } else *list = -1; @@ -249,17 +240,16 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In /* Create tuple description for sorting */ #if PG_VERSION_NUM >= 120000 - buildstate->tupdesc = CreateTemplateTupleDesc(4); + buildstate->tupdesc = CreateTemplateTupleDesc(3); #else - buildstate->tupdesc = CreateTemplateTupleDesc(4, false); + buildstate->tupdesc = CreateTemplateTupleDesc(3, false); #endif TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 1, "list", INT4OID, -1, 0); - TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0); - TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0); + TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "tid", TIDOID, -1, 0); #if PG_VERSION_NUM >= 110000 - TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0); + TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0); #else - TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0]->atttypid, -1, 0); + TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "vector", RelationGetDescr(index)->attrs[0]->atttypid, -1, 0); #endif #if PG_VERSION_NUM >= 120000 @@ -276,6 +266,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In #ifdef IVFFLAT_KMEANS_DEBUG buildstate->inertia = 0; + buildstate->listSums = palloc0(sizeof(double) * buildstate->lists); + buildstate->listCounts = palloc0(sizeof(int) * buildstate->lists); #endif } @@ -288,6 +280,11 @@ FreeBuildState(IvfflatBuildState * buildstate) pfree(buildstate->centers); pfree(buildstate->listInfo); pfree(buildstate->normvec); + +#ifdef IVFFLAT_KMEANS_DEBUG + pfree(buildstate->listSums); + pfree(buildstate->listCounts); +#endif } /* @@ -363,6 +360,51 @@ CreateListPages(Relation index, VectorArray centers, int dimensions, pfree(list); } +/* + * Print k-means metrics + */ +#ifdef IVFFLAT_KMEANS_DEBUG +static void +PrintKmeansMetrics(IvfflatBuildState * buildstate) +{ + elog(INFO, "inertia: %.3e", buildstate->inertia); + + /* Calculate Davies-Bouldin index */ + if (buildstate->lists > 1) + { + double db = 0.0; + + /* Calculate average distance */ + for (int i = 0; i < buildstate->lists; i++) + { + if (buildstate->listCounts[i] > 0) + buildstate->listSums[i] /= buildstate->listCounts[i]; + } + + for (int i = 0; i < buildstate->lists; i++) + { + double max = 0.0; + double distance; + + for (int j = 0; j < buildstate->lists; j++) + { + if (j == i) + continue; + + distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, PointerGetDatum(VectorArrayGet(buildstate->centers, i)), PointerGetDatum(VectorArrayGet(buildstate->centers, j)))); + distance = (buildstate->listSums[i] + buildstate->listSums[j]) / distance; + + if (distance > max) + max = distance; + } + db += max; + } + db /= buildstate->lists; + elog(INFO, "davies-bouldin: %.3f", db); + } +} +#endif + /* * Create entry pages */ @@ -401,7 +443,7 @@ CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum) tuplesort_performsort(buildstate->sortstate); #ifdef IVFFLAT_KMEANS_DEBUG - elog(INFO, "inertia: %.3e", buildstate->inertia); + PrintKmeansMetrics(buildstate); #endif /* Insert */ diff --git a/src/ivfflat.h b/src/ivfflat.h index 7760784..be72fe5 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -61,11 +61,6 @@ #define IvfflatBench(name, code) (code) #endif -#if PG_VERSION_NUM < 100000 -#define ItemPointerGetBlockNumberNoCheck ItemPointerGetBlockNumber -#define ItemPointerGetOffsetNumberNoCheck ItemPointerGetOffsetNumber -#endif - /* Variables */ extern int ivfflat_probes; @@ -121,6 +116,8 @@ typedef struct IvfflatBuildState #ifdef IVFFLAT_KMEANS_DEBUG double inertia; + double *listSums; + int *listCounts; #endif /* Sampling */ @@ -164,6 +161,7 @@ typedef IvfflatListData * IvfflatList; typedef struct IvfflatScanList { + pairingheap_node ph_node; BlockNumber startPage; double distance; } IvfflatScanList; @@ -185,6 +183,8 @@ typedef struct IvfflatScanOpaqueData FmgrInfo *normprocinfo; Oid collation; + /* Lists */ + pairingheap *listQueue; IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */ } IvfflatScanOpaqueData; diff --git a/src/ivfscan.c b/src/ivfscan.c index dfd52af..e2171b0 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -1,5 +1,7 @@ #include "postgres.h" +#include + #include "access/relscan.h" #include "ivfflat.h" #include "miscadmin.h" @@ -17,14 +19,12 @@ * Compare list distances */ static int -CompareLists(const void *a, const void *b) +CompareLists(const pairingheap_node *a, const pairingheap_node *b, void *arg) { - double diff = (((IvfflatScanList *) a)->distance - ((IvfflatScanList *) b)->distance); - - if (diff > 0) + if (((const IvfflatScanList *) a)->distance > ((const IvfflatScanList *) b)->distance) return 1; - if (diff < 0) + if (((const IvfflatScanList *) a)->distance < ((const IvfflatScanList *) b)->distance) return -1; return 0; @@ -45,6 +45,8 @@ GetScanLists(IndexScanDesc scan, Datum value) int listCount = 0; IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; double distance; + IvfflatScanList *scanlist; + double maxDistance = DBL_MAX; /* Search all list pages */ while (BlockNumberIsValid(nextblkno)) @@ -62,22 +64,39 @@ GetScanLists(IndexScanDesc scan, Datum value) /* Use procinfo from the index instead of scan key for performance */ distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value)); - so->lists[listCount].startPage = list->startPage; - so->lists[listCount].distance = distance; - listCount++; + if (listCount < so->probes) + { + scanlist = &so->lists[listCount]; + scanlist->startPage = list->startPage; + scanlist->distance = distance; + listCount++; + + /* Add to heap */ + pairingheap_add(so->listQueue, &scanlist->ph_node); + + /* Calculate max distance */ + if (listCount == so->probes) + maxDistance = ((IvfflatScanList *) pairingheap_first(so->listQueue))->distance; + } + else if (distance < maxDistance) + { + /* Remove */ + scanlist = (IvfflatScanList *) pairingheap_remove_first(so->listQueue); + + /* Reuse */ + scanlist->startPage = list->startPage; + scanlist->distance = distance; + pairingheap_add(so->listQueue, &scanlist->ph_node); + + /* Update max distance */ + maxDistance = ((IvfflatScanList *) pairingheap_first(so->listQueue))->distance; + } } nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno; UnlockReleaseBuffer(cbuf); } - - /* Sort by distance */ - /* TODO Use heap for performance */ - qsort(so->lists, listCount, sizeof(IvfflatScanList), CompareLists); - - if (so->probes > listCount) - so->probes = listCount; } /* @@ -95,7 +114,6 @@ GetScanItems(IndexScanDesc scan, Datum value) OffsetNumber maxoffno; Datum datum; bool isnull; - int i; TupleDesc tupdesc = RelationGetDescr(scan->indexRelation); #if PG_VERSION_NUM >= 120000 @@ -112,9 +130,9 @@ GetScanItems(IndexScanDesc scan, Datum value) BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD); /* Search closest probes lists */ - for (i = 0; i < so->probes; i++) + while (!pairingheap_is_empty(so->listQueue)) { - searchPage = so->lists[i].startPage; + searchPage = ((IvfflatScanList *) pairingheap_remove_first(so->listQueue))->startPage; /* Search all entry pages for list */ while (BlockNumberIsValid(searchPage)) @@ -138,12 +156,10 @@ GetScanItems(IndexScanDesc scan, Datum value) ExecClearTuple(slot); slot->tts_values[0] = FunctionCall2Coll(so->procinfo, so->collation, datum, value); slot->tts_isnull[0] = false; - slot->tts_values[1] = Int32GetDatum((int) ItemPointerGetBlockNumberNoCheck(&itup->t_tid)); + slot->tts_values[1] = PointerGetDatum(&itup->t_tid); slot->tts_isnull[1] = false; - slot->tts_values[2] = Int32GetDatum((int) ItemPointerGetOffsetNumberNoCheck(&itup->t_tid)); + slot->tts_values[2] = Int32GetDatum((int) searchPage); slot->tts_isnull[2] = false; - slot->tts_values[3] = Int32GetDatum((int) searchPage); - slot->tts_isnull[3] = false; ExecStoreVirtualTuple(slot); tuplesort_puttupleslot(so->sortstate, slot); @@ -171,13 +187,18 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) Oid sortOperators[] = {Float8LessOperator}; Oid sortCollations[] = {InvalidOid}; bool nullsFirstFlags[] = {false}; + int probes = ivfflat_probes; scan = RelationGetIndexScan(index, nkeys, norderbys); lists = IvfflatGetLists(scan->indexRelation); - so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + lists * sizeof(IvfflatScanList)); + if (probes > lists) + probes = lists; + + so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList)); so->buf = InvalidBuffer; so->first = true; + so->probes = probes; /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); @@ -186,14 +207,13 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) /* Create tuple description for sorting */ #if PG_VERSION_NUM >= 120000 - so->tupdesc = CreateTemplateTupleDesc(4); + so->tupdesc = CreateTemplateTupleDesc(3); #else - so->tupdesc = CreateTemplateTupleDesc(4, false); + so->tupdesc = CreateTemplateTupleDesc(3, false); #endif TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0); - TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0); - TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0); - TupleDescInitEntry(so->tupdesc, (AttrNumber) 4, "indexblkno", INT4OID, -1, 0); + TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "tid", TIDOID, -1, 0); + TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "indexblkno", INT4OID, -1, 0); /* Prep sort */ #if PG_VERSION_NUM >= 110000 @@ -208,6 +228,8 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) so->slot = MakeSingleTupleTableSlot(so->tupdesc); #endif + so->listQueue = pairingheap_allocate(CompareLists, scan); + scan->opaque = so; return scan; @@ -227,7 +249,7 @@ ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int #endif so->first = true; - so->probes = ivfflat_probes; + pairingheap_reset(so->listQueue); if (keys && scan->numberOfKeys > 0) memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData)); @@ -286,14 +308,13 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) if (tuplesort_gettupleslot(so->sortstate, true, so->slot, NULL)) #endif { - BlockNumber blkno = DatumGetInt32(slot_getattr(so->slot, 2, &so->isnull)); - OffsetNumber offset = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull)); - BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 4, &so->isnull)); + ItemPointer tid = (ItemPointer) DatumGetPointer(slot_getattr(so->slot, 2, &so->isnull)); + BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull)); #if PG_VERSION_NUM >= 120000 - ItemPointerSet(&scan->xs_heaptid, blkno, offset); + scan->xs_heaptid = *tid; #else - ItemPointerSet(&scan->xs_ctup.t_self, blkno, offset); + scan->xs_ctup.t_self = *tid; #endif if (BufferIsValid(so->buf)) @@ -326,6 +347,7 @@ ivfflatendscan(IndexScanDesc scan) if (BufferIsValid(so->buf)) ReleaseBuffer(so->buf); + pairingheap_free(so->listQueue); tuplesort_end(so->sortstate); pfree(so); diff --git a/test/t/003_recall.pl b/test/t/003_recall.pl index 4992b0a..dddc4d5 100644 --- a/test/t/003_recall.pl +++ b/test/t/003_recall.pl @@ -2,15 +2,16 @@ use strict; use warnings; use PostgresNode; use TestLib; -use Test::More tests => 2; +use Test::More tests => 9; my $node; my @queries = (); -my @expected = (); +my @expected; +my $limit = 20; sub test_recall { - my ($probes, $min) = @_; + my ($probes, $min, $operator) = @_; my $correct = 0; my $total = 0; @@ -18,7 +19,7 @@ sub test_recall my $actual = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET ivfflat.probes = $probes; - SELECT i FROM tst ORDER BY v <-> '$queries[$i]' LIMIT 10; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; )); my @actual_ids = split("\n", $actual); my %actual_set = map { $_ => 1 } @actual_ids; @@ -33,7 +34,7 @@ sub test_recall } } - cmp_ok($correct / $total, ">=", $min); + cmp_ok($correct / $total, ">=", $min, $operator); } # Initialize node @@ -56,17 +57,32 @@ for (1..20) { push(@queries, "[$r1,$r2,$r3]"); } -# Get exact results -foreach (@queries) { - my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v <-> '$_' LIMIT 10;"); - push(@expected, $res); +# Check each index type +my @operators = ("<->", "<#>", "<=>"); + +foreach (@operators) { + my $operator = $_; + + # Get exact results + @expected = (); + foreach (@queries) { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Add index + my $opclass; + if ($operator == "<->") { + $opclass = "vector_l2_ops"; + } elsif ($operator == "<#>") { + $opclass = "vector_ip_ops"; + } else { + $opclass = "vector_cosine_ops"; + } + $node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v $opclass);"); + + # Test approximate results + test_recall(1, 0.75, $operator); + test_recall(10, 0.95, $operator); + test_recall(100, 1.0, $operator); } - -# Add index -$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v);"); - -# Test approximate results -test_recall(1, 0.8); - -# Test probes -test_recall(100, 1.0); diff --git a/test/t/005_query_recall.pl b/test/t/005_query_recall.pl new file mode 100644 index 0000000..0e58135 --- /dev/null +++ b/test/t/005_query_recall.pl @@ -0,0 +1,45 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More tests => 60; + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4 primary key, v vector(3));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" +); + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); +foreach (@operators) { + my $operator = $_; + + # Add index + my $opclass; + if ($operator == "<->") { + $opclass = "vector_l2_ops"; + } elsif ($operator == "<#>") { + $opclass = "vector_ip_ops"; + } else { + $opclass = "vector_cosine_ops"; + } + $node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v $opclass);"); + + # Test 100% recall + for (1..20) { + my $i = int(rand() * 100000); + my $query = $node->safe_psql("postgres", "SELECT v FROM tst WHERE i = $i;"); + my $res = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT v FROM tst ORDER BY v <-> '$query' LIMIT 1; + )); + is($res, $query); + } +} diff --git a/test/t/006_lists.pl b/test/t/006_lists.pl new file mode 100644 index 0000000..eeb11aa --- /dev/null +++ b/test/t/006_lists.pl @@ -0,0 +1,31 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More tests => 3; + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (v vector(3));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" +); + +$node->safe_psql("postgres", "CREATE INDEX lists50 ON tst USING ivfflat (v) WITH (lists = 50);"); +$node->safe_psql("postgres", "CREATE INDEX lists100 ON tst USING ivfflat (v) WITH (lists = 100);"); + +# Test prefers more lists +my $res = $node->safe_psql("postgres", "EXPLAIN SELECT v FROM tst ORDER BY v <-> '[0.5,0.5,0.5]' LIMIT 10;"); +like($res, qr/lists100/); +unlike($res, qr/lists50/); + +# Test errors with too much memory +my ($ret, $stdout, $stderr) = $node->psql("postgres", + "CREATE INDEX lists10000 ON tst USING ivfflat (v) WITH (lists = 10000);" +); +like($stderr, qr/memory required is/);