Merge branch 'master' into minibatch

This commit is contained in:
Andrew Kane
2022-02-15 19:14:41 -08:00
7 changed files with 239 additions and 82 deletions

View File

@@ -1,6 +1,7 @@
## 0.2.6 (unreleased)
- Switched to mini-batch k-means
- Improved performance of index creation for Postgres < 12
## 0.2.5 (2022-02-11)

View File

@@ -36,16 +36,11 @@
#define CALLBACK_ITEM_POINTER HeapTuple hup
#endif
/*
* Update build phase progress
*/
static inline void
UpdateProgress(int index, int64 val)
{
#if PG_VERSION_NUM >= 120000
pgstat_progress_update_param(index, val);
#define UpdateProgress(index, val) pgstat_progress_update_param(index, val)
#else
#define UpdateProgress(index, val) ((void)val)
#endif
}
/*
* Callback for table_index_build_scan
@@ -91,18 +86,18 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
#ifdef IVFFLAT_KMEANS_DEBUG
buildstate->inertia += minDistance;
buildstate->listSums[closestCenter] += minDistance;
buildstate->listCounts[closestCenter]++;
#endif
/* Create a virtual tuple */
ExecClearTuple(slot);
slot->tts_values[0] = Int32GetDatum(closestCenter);
slot->tts_isnull[0] = false;
slot->tts_values[1] = Int32GetDatum(ItemPointerGetBlockNumberNoCheck(tid));
slot->tts_values[1] = PointerGetDatum(tid);
slot->tts_isnull[1] = false;
slot->tts_values[2] = Int32GetDatum(ItemPointerGetOffsetNumberNoCheck(tid));
slot->tts_values[2] = value;
slot->tts_isnull[2] = false;
slot->tts_values[3] = value;
slot->tts_isnull[3] = false;
ExecStoreVirtualTuple(slot);
/*
@@ -124,8 +119,6 @@ GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot,
{
Datum value;
bool isnull;
int tupblk;
int tupoff;
#if PG_VERSION_NUM >= 100000
if (tuplesort_gettupleslot(sortstate, true, false, slot, NULL))
@@ -134,13 +127,11 @@ GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot,
#endif
{
*list = DatumGetInt32(slot_getattr(slot, 1, &isnull));
tupblk = DatumGetInt32(slot_getattr(slot, 2, &isnull));
tupoff = DatumGetInt32(slot_getattr(slot, 3, &isnull));
value = slot_getattr(slot, 4, &isnull);
value = slot_getattr(slot, 3, &isnull);
/* Form the index tuple */
*itup = index_form_tuple(tupdesc, &value, &isnull);
ItemPointerSet(&(*itup)->t_tid, tupblk, tupoff);
(*itup)->t_tid = *((ItemPointer) DatumGetPointer(slot_getattr(slot, 2, &isnull)));
}
else
*list = -1;
@@ -249,17 +240,16 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
/* Create tuple description for sorting */
#if PG_VERSION_NUM >= 120000
buildstate->tupdesc = CreateTemplateTupleDesc(4);
buildstate->tupdesc = CreateTemplateTupleDesc(3);
#else
buildstate->tupdesc = CreateTemplateTupleDesc(4, false);
buildstate->tupdesc = CreateTemplateTupleDesc(3, false);
#endif
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 1, "list", INT4OID, -1, 0);
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "tid", TIDOID, -1, 0);
#if PG_VERSION_NUM >= 110000
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0);
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0);
#else
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 4, "vector", RelationGetDescr(index)->attrs[0]->atttypid, -1, 0);
TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "vector", RelationGetDescr(index)->attrs[0]->atttypid, -1, 0);
#endif
#if PG_VERSION_NUM >= 120000
@@ -276,6 +266,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
#ifdef IVFFLAT_KMEANS_DEBUG
buildstate->inertia = 0;
buildstate->listSums = palloc0(sizeof(double) * buildstate->lists);
buildstate->listCounts = palloc0(sizeof(int) * buildstate->lists);
#endif
}
@@ -288,6 +280,11 @@ FreeBuildState(IvfflatBuildState * buildstate)
pfree(buildstate->centers);
pfree(buildstate->listInfo);
pfree(buildstate->normvec);
#ifdef IVFFLAT_KMEANS_DEBUG
pfree(buildstate->listSums);
pfree(buildstate->listCounts);
#endif
}
/*
@@ -363,6 +360,51 @@ CreateListPages(Relation index, VectorArray centers, int dimensions,
pfree(list);
}
/*
* Print k-means metrics
*/
#ifdef IVFFLAT_KMEANS_DEBUG
static void
PrintKmeansMetrics(IvfflatBuildState * buildstate)
{
elog(INFO, "inertia: %.3e", buildstate->inertia);
/* Calculate Davies-Bouldin index */
if (buildstate->lists > 1)
{
double db = 0.0;
/* Calculate average distance */
for (int i = 0; i < buildstate->lists; i++)
{
if (buildstate->listCounts[i] > 0)
buildstate->listSums[i] /= buildstate->listCounts[i];
}
for (int i = 0; i < buildstate->lists; i++)
{
double max = 0.0;
double distance;
for (int j = 0; j < buildstate->lists; j++)
{
if (j == i)
continue;
distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, PointerGetDatum(VectorArrayGet(buildstate->centers, i)), PointerGetDatum(VectorArrayGet(buildstate->centers, j))));
distance = (buildstate->listSums[i] + buildstate->listSums[j]) / distance;
if (distance > max)
max = distance;
}
db += max;
}
db /= buildstate->lists;
elog(INFO, "davies-bouldin: %.3f", db);
}
}
#endif
/*
* Create entry pages
*/
@@ -401,7 +443,7 @@ CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum)
tuplesort_performsort(buildstate->sortstate);
#ifdef IVFFLAT_KMEANS_DEBUG
elog(INFO, "inertia: %.3e", buildstate->inertia);
PrintKmeansMetrics(buildstate);
#endif
/* Insert */

View File

@@ -61,11 +61,6 @@
#define IvfflatBench(name, code) (code)
#endif
#if PG_VERSION_NUM < 100000
#define ItemPointerGetBlockNumberNoCheck ItemPointerGetBlockNumber
#define ItemPointerGetOffsetNumberNoCheck ItemPointerGetOffsetNumber
#endif
/* Variables */
extern int ivfflat_probes;
@@ -121,6 +116,8 @@ typedef struct IvfflatBuildState
#ifdef IVFFLAT_KMEANS_DEBUG
double inertia;
double *listSums;
int *listCounts;
#endif
/* Sampling */
@@ -164,6 +161,7 @@ typedef IvfflatListData * IvfflatList;
typedef struct IvfflatScanList
{
pairingheap_node ph_node;
BlockNumber startPage;
double distance;
} IvfflatScanList;
@@ -185,6 +183,8 @@ typedef struct IvfflatScanOpaqueData
FmgrInfo *normprocinfo;
Oid collation;
/* Lists */
pairingheap *listQueue;
IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */
} IvfflatScanOpaqueData;

View File

@@ -1,5 +1,7 @@
#include "postgres.h"
#include <float.h>
#include "access/relscan.h"
#include "ivfflat.h"
#include "miscadmin.h"
@@ -17,14 +19,12 @@
* Compare list distances
*/
static int
CompareLists(const void *a, const void *b)
CompareLists(const pairingheap_node *a, const pairingheap_node *b, void *arg)
{
double diff = (((IvfflatScanList *) a)->distance - ((IvfflatScanList *) b)->distance);
if (diff > 0)
if (((const IvfflatScanList *) a)->distance > ((const IvfflatScanList *) b)->distance)
return 1;
if (diff < 0)
if (((const IvfflatScanList *) a)->distance < ((const IvfflatScanList *) b)->distance)
return -1;
return 0;
@@ -45,6 +45,8 @@ GetScanLists(IndexScanDesc scan, Datum value)
int listCount = 0;
IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque;
double distance;
IvfflatScanList *scanlist;
double maxDistance = DBL_MAX;
/* Search all list pages */
while (BlockNumberIsValid(nextblkno))
@@ -62,22 +64,39 @@ GetScanLists(IndexScanDesc scan, Datum value)
/* Use procinfo from the index instead of scan key for performance */
distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value));
so->lists[listCount].startPage = list->startPage;
so->lists[listCount].distance = distance;
listCount++;
if (listCount < so->probes)
{
scanlist = &so->lists[listCount];
scanlist->startPage = list->startPage;
scanlist->distance = distance;
listCount++;
/* Add to heap */
pairingheap_add(so->listQueue, &scanlist->ph_node);
/* Calculate max distance */
if (listCount == so->probes)
maxDistance = ((IvfflatScanList *) pairingheap_first(so->listQueue))->distance;
}
else if (distance < maxDistance)
{
/* Remove */
scanlist = (IvfflatScanList *) pairingheap_remove_first(so->listQueue);
/* Reuse */
scanlist->startPage = list->startPage;
scanlist->distance = distance;
pairingheap_add(so->listQueue, &scanlist->ph_node);
/* Update max distance */
maxDistance = ((IvfflatScanList *) pairingheap_first(so->listQueue))->distance;
}
}
nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno;
UnlockReleaseBuffer(cbuf);
}
/* Sort by distance */
/* TODO Use heap for performance */
qsort(so->lists, listCount, sizeof(IvfflatScanList), CompareLists);
if (so->probes > listCount)
so->probes = listCount;
}
/*
@@ -95,7 +114,6 @@ GetScanItems(IndexScanDesc scan, Datum value)
OffsetNumber maxoffno;
Datum datum;
bool isnull;
int i;
TupleDesc tupdesc = RelationGetDescr(scan->indexRelation);
#if PG_VERSION_NUM >= 120000
@@ -112,9 +130,9 @@ GetScanItems(IndexScanDesc scan, Datum value)
BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD);
/* Search closest probes lists */
for (i = 0; i < so->probes; i++)
while (!pairingheap_is_empty(so->listQueue))
{
searchPage = so->lists[i].startPage;
searchPage = ((IvfflatScanList *) pairingheap_remove_first(so->listQueue))->startPage;
/* Search all entry pages for list */
while (BlockNumberIsValid(searchPage))
@@ -138,12 +156,10 @@ GetScanItems(IndexScanDesc scan, Datum value)
ExecClearTuple(slot);
slot->tts_values[0] = FunctionCall2Coll(so->procinfo, so->collation, datum, value);
slot->tts_isnull[0] = false;
slot->tts_values[1] = Int32GetDatum((int) ItemPointerGetBlockNumberNoCheck(&itup->t_tid));
slot->tts_values[1] = PointerGetDatum(&itup->t_tid);
slot->tts_isnull[1] = false;
slot->tts_values[2] = Int32GetDatum((int) ItemPointerGetOffsetNumberNoCheck(&itup->t_tid));
slot->tts_values[2] = Int32GetDatum((int) searchPage);
slot->tts_isnull[2] = false;
slot->tts_values[3] = Int32GetDatum((int) searchPage);
slot->tts_isnull[3] = false;
ExecStoreVirtualTuple(slot);
tuplesort_puttupleslot(so->sortstate, slot);
@@ -171,13 +187,18 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
Oid sortOperators[] = {Float8LessOperator};
Oid sortCollations[] = {InvalidOid};
bool nullsFirstFlags[] = {false};
int probes = ivfflat_probes;
scan = RelationGetIndexScan(index, nkeys, norderbys);
lists = IvfflatGetLists(scan->indexRelation);
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + lists * sizeof(IvfflatScanList));
if (probes > lists)
probes = lists;
so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList));
so->buf = InvalidBuffer;
so->first = true;
so->probes = probes;
/* Set support functions */
so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC);
@@ -186,14 +207,13 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
/* Create tuple description for sorting */
#if PG_VERSION_NUM >= 120000
so->tupdesc = CreateTemplateTupleDesc(4);
so->tupdesc = CreateTemplateTupleDesc(3);
#else
so->tupdesc = CreateTemplateTupleDesc(4, false);
so->tupdesc = CreateTemplateTupleDesc(3, false);
#endif
TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0);
TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "blkno", INT4OID, -1, 0);
TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "offset", INT4OID, -1, 0);
TupleDescInitEntry(so->tupdesc, (AttrNumber) 4, "indexblkno", INT4OID, -1, 0);
TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "tid", TIDOID, -1, 0);
TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "indexblkno", INT4OID, -1, 0);
/* Prep sort */
#if PG_VERSION_NUM >= 110000
@@ -208,6 +228,8 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys)
so->slot = MakeSingleTupleTableSlot(so->tupdesc);
#endif
so->listQueue = pairingheap_allocate(CompareLists, scan);
scan->opaque = so;
return scan;
@@ -227,7 +249,7 @@ ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int
#endif
so->first = true;
so->probes = ivfflat_probes;
pairingheap_reset(so->listQueue);
if (keys && scan->numberOfKeys > 0)
memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData));
@@ -286,14 +308,13 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir)
if (tuplesort_gettupleslot(so->sortstate, true, so->slot, NULL))
#endif
{
BlockNumber blkno = DatumGetInt32(slot_getattr(so->slot, 2, &so->isnull));
OffsetNumber offset = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull));
BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 4, &so->isnull));
ItemPointer tid = (ItemPointer) DatumGetPointer(slot_getattr(so->slot, 2, &so->isnull));
BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull));
#if PG_VERSION_NUM >= 120000
ItemPointerSet(&scan->xs_heaptid, blkno, offset);
scan->xs_heaptid = *tid;
#else
ItemPointerSet(&scan->xs_ctup.t_self, blkno, offset);
scan->xs_ctup.t_self = *tid;
#endif
if (BufferIsValid(so->buf))
@@ -326,6 +347,7 @@ ivfflatendscan(IndexScanDesc scan)
if (BufferIsValid(so->buf))
ReleaseBuffer(so->buf);
pairingheap_free(so->listQueue);
tuplesort_end(so->sortstate);
pfree(so);

View File

@@ -2,15 +2,16 @@ use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More tests => 2;
use Test::More tests => 9;
my $node;
my @queries = ();
my @expected = ();
my @expected;
my $limit = 20;
sub test_recall
{
my ($probes, $min) = @_;
my ($probes, $min, $operator) = @_;
my $correct = 0;
my $total = 0;
@@ -18,7 +19,7 @@ sub test_recall
my $actual = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SET ivfflat.probes = $probes;
SELECT i FROM tst ORDER BY v <-> '$queries[$i]' LIMIT 10;
SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit;
));
my @actual_ids = split("\n", $actual);
my %actual_set = map { $_ => 1 } @actual_ids;
@@ -33,7 +34,7 @@ sub test_recall
}
}
cmp_ok($correct / $total, ">=", $min);
cmp_ok($correct / $total, ">=", $min, $operator);
}
# Initialize node
@@ -56,17 +57,32 @@ for (1..20) {
push(@queries, "[$r1,$r2,$r3]");
}
# Get exact results
foreach (@queries) {
my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v <-> '$_' LIMIT 10;");
push(@expected, $res);
# Check each index type
my @operators = ("<->", "<#>", "<=>");
foreach (@operators) {
my $operator = $_;
# Get exact results
@expected = ();
foreach (@queries) {
my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;");
push(@expected, $res);
}
# Add index
my $opclass;
if ($operator == "<->") {
$opclass = "vector_l2_ops";
} elsif ($operator == "<#>") {
$opclass = "vector_ip_ops";
} else {
$opclass = "vector_cosine_ops";
}
$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v $opclass);");
# Test approximate results
test_recall(1, 0.75, $operator);
test_recall(10, 0.95, $operator);
test_recall(100, 1.0, $operator);
}
# Add index
$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v);");
# Test approximate results
test_recall(1, 0.8);
# Test probes
test_recall(100, 1.0);

View File

@@ -0,0 +1,45 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More tests => 60;
# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4 primary key, v vector(3));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;"
);
# Check each index type
my @operators = ("<->", "<#>", "<=>");
foreach (@operators) {
my $operator = $_;
# Add index
my $opclass;
if ($operator == "<->") {
$opclass = "vector_l2_ops";
} elsif ($operator == "<#>") {
$opclass = "vector_ip_ops";
} else {
$opclass = "vector_cosine_ops";
}
$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v $opclass);");
# Test 100% recall
for (1..20) {
my $i = int(rand() * 100000);
my $query = $node->safe_psql("postgres", "SELECT v FROM tst WHERE i = $i;");
my $res = $node->safe_psql("postgres", qq(
SET enable_seqscan = off;
SELECT v FROM tst ORDER BY v <-> '$query' LIMIT 1;
));
is($res, $query);
}
}

31
test/t/006_lists.pl Normal file
View File

@@ -0,0 +1,31 @@
use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More tests => 3;
# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;
# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (v vector(3));");
$node->safe_psql("postgres",
"INSERT INTO tst SELECT ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;"
);
$node->safe_psql("postgres", "CREATE INDEX lists50 ON tst USING ivfflat (v) WITH (lists = 50);");
$node->safe_psql("postgres", "CREATE INDEX lists100 ON tst USING ivfflat (v) WITH (lists = 100);");
# Test prefers more lists
my $res = $node->safe_psql("postgres", "EXPLAIN SELECT v FROM tst ORDER BY v <-> '[0.5,0.5,0.5]' LIMIT 10;");
like($res, qr/lists100/);
unlike($res, qr/lists50/);
# Test errors with too much memory
my ($ret, $stdout, $stderr) = $node->psql("postgres",
"CREATE INDEX lists10000 ON tst USING ivfflat (v) WITH (lists = 10000);"
);
like($stderr, qr/memory required is/);