Added support for on-disk parallel index builds for HNSW

This commit is contained in:
Andrew Kane
2023-11-11 19:29:45 -08:00
parent 69a2ce0d43
commit dfee5d4045
5 changed files with 479 additions and 7 deletions

View File

@@ -1,3 +1,7 @@
## 0.5.2 (unreleased)
- Added support for on-disk parallel index builds for HNSW
## 0.5.1 (2023-10-10)
- Improved performance of HNSW index builds

View File

@@ -14,6 +14,7 @@
#endif
int hnsw_ef_search;
bool hnsw_enable_parallel_build;
static relopt_kind hnsw_relopt_kind;
/*
@@ -39,6 +40,11 @@ HnswInit(void)
DefineCustomIntVariable("hnsw.ef_search", "Sets the size of the dynamic candidate list for search",
"Valid range is 1..1000.", &hnsw_ef_search,
HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL);
/* Behind a variable for now since can be slower than building in memory */
DefineCustomBoolVariable("hnsw.enable_parallel_build", "Enables or disables building indexes in parallel",
NULL, &hnsw_enable_parallel_build,
false, PGC_USERSET, 0, NULL, NULL, NULL);
}
/*

View File

@@ -4,6 +4,7 @@
#include "postgres.h"
#include "access/generic_xlog.h"
#include "access/parallel.h"
#include "access/reloptions.h"
#include "nodes/execnodes.h"
#include "port.h" /* for random() */
@@ -14,6 +15,10 @@
#error "Requires PostgreSQL 11+"
#endif
#if PG_VERSION_NUM < 120000
#include "access/relscan.h"
#endif
#define HNSW_MAX_DIM 2000
/* Support functions */
@@ -90,6 +95,7 @@
/* Variables */
extern int hnsw_ef_search;
extern bool hnsw_enable_parallel_build;
typedef struct HnswNeighborArray HnswNeighborArray;
@@ -136,6 +142,49 @@ typedef struct HnswOptions
int efConstruction; /* size of dynamic candidate list */
} HnswOptions;
typedef struct HnswSpool
{
Relation heap;
Relation index;
} HnswSpool;
typedef struct HnswShared
{
/* Immutable state */
Oid heaprelid;
Oid indexrelid;
bool isconcurrent;
int scantuplesortstates;
/* Worker progress */
ConditionVariable workersdonecv;
/* Mutex for mutable state */
slock_t mutex;
/* Mutable state */
int nparticipantsdone;
double reltuples;
double indtuples;
#if PG_VERSION_NUM < 120000
ParallelHeapScanDescData heapdesc; /* must come last */
#endif
} HnswShared;
#if PG_VERSION_NUM >= 120000
#define ParallelTableScanFromHnswShared(shared) \
(ParallelTableScanDesc) ((char *) (shared) + BUFFERALIGN(sizeof(HnswShared)))
#endif
typedef struct HnswLeader
{
ParallelContext *pcxt;
int nparticipanttuplesorts;
HnswShared *hnswshared;
Snapshot snapshot;
} HnswLeader;
typedef struct HnswBuildState
{
/* Info */
@@ -169,6 +218,10 @@ typedef struct HnswBuildState
/* Memory */
MemoryContext tmpCtx;
/* Parallel builds */
HnswLeader *hnswleader;
HnswShared *hnswshared;
} HnswBuildState;
typedef struct HnswMetaPageData
@@ -289,6 +342,7 @@ void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation i
void HnswSetElementTuple(HnswElementTuple etup, HnswElement element);
void HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation);
void HnswLoadNeighbors(HnswElement element, Relation index, int m);
PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc);
/* Index access methods */
IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo);

View File

@@ -2,12 +2,15 @@
#include <math.h>
#include "access/parallel.h"
#include "access/xact.h"
#include "catalog/index.h"
#include "hnsw.h"
#include "miscadmin.h"
#include "lib/pairingheap.h"
#include "nodes/pg_list.h"
#include "storage/bufmgr.h"
#include "tcop/tcopprot.h"
#include "utils/datum.h"
#include "utils/memutils.h"
@@ -36,6 +39,23 @@
#define UpdateProgress(index, val) ((void)val)
#endif
#if PG_VERSION_NUM >= 140000
#include "utils/backend_status.h"
#include "utils/wait_event.h"
#endif
#if PG_VERSION_NUM >= 120000
#include "access/table.h"
#include "optimizer/optimizer.h"
#else
#include "access/heapam.h"
#include "optimizer/planner.h"
#include "pgstat.h"
#endif
#define PARALLEL_KEY_HNSW_SHARED UINT64CONST(0xA000000000000001)
#define PARALLEL_KEY_QUERY_TEXT UINT64CONST(0xA000000000000002)
/*
* Create the metapage
*/
@@ -376,7 +396,18 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);
if (HnswInsertTuple(buildstate->index, values, isnull, tid, buildstate->heap))
UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples);
{
if (buildstate->hnswshared)
{
HnswShared *hnswshared = buildstate->hnswshared;
SpinLockAcquire(&hnswshared->mutex);
UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++hnswshared->indtuples);
SpinLockRelease(&hnswshared->mutex);
}
else
UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples);
}
/* Reset memory context */
MemoryContextSwitchTo(oldCtx);
@@ -461,6 +492,9 @@ InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, Index
buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext,
"Hnsw build temporary context",
ALLOCSET_DEFAULT_SIZES);
buildstate->hnswleader = NULL;
buildstate->hnswshared = NULL;
}
/*
@@ -473,21 +507,374 @@ FreeBuildState(HnswBuildState * buildstate)
MemoryContextDelete(buildstate->tmpCtx);
}
/*
* Within leader, wait for end of heap scan
*/
static double
ParallelHeapScan(HnswBuildState * buildstate)
{
HnswShared *hnswshared = buildstate->hnswleader->hnswshared;
int nparticipanttuplesorts;
double reltuples;
nparticipanttuplesorts = buildstate->hnswleader->nparticipanttuplesorts;
for (;;)
{
SpinLockAcquire(&hnswshared->mutex);
if (hnswshared->nparticipantsdone == nparticipanttuplesorts)
{
buildstate->indtuples = hnswshared->indtuples;
reltuples = hnswshared->reltuples;
SpinLockRelease(&hnswshared->mutex);
break;
}
SpinLockRelease(&hnswshared->mutex);
ConditionVariableSleep(&hnswshared->workersdonecv,
WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN);
}
ConditionVariableCancelSleep();
return reltuples;
}
/*
* Perform a worker's portion of a parallel insert
*/
static void
HnswParallelScanAndInsert(HnswSpool * hnswspool, HnswShared * hnswshared, bool progress)
{
HnswBuildState buildstate;
#if PG_VERSION_NUM >= 120000
TableScanDesc scan;
#else
HeapScanDesc scan;
#endif
double reltuples;
IndexInfo *indexInfo;
/* Join parallel scan */
indexInfo = BuildIndexInfo(hnswspool->index);
indexInfo->ii_Concurrent = hnswshared->isconcurrent;
InitBuildState(&buildstate, hnswspool->heap, hnswspool->index, indexInfo, MAIN_FORKNUM);
/* TODO Support in-memory builds */
buildstate.memoryLeft = 0;
buildstate.flushed = true;
buildstate.hnswshared = hnswshared;
#if PG_VERSION_NUM >= 120000
scan = table_beginscan_parallel(hnswspool->heap,
ParallelTableScanFromHnswShared(hnswshared));
reltuples = table_index_build_scan(hnswspool->heap, hnswspool->index, indexInfo,
true, progress, BuildCallback,
(void *) &buildstate, scan);
#else
scan = heap_beginscan_parallel(hnswspool->heap, &hnswshared->heapdesc);
reltuples = IndexBuildHeapScan(hnswspool->heap, hnswspool->index, indexInfo,
true, BuildCallback,
(void *) &buildstate, scan);
#endif
/* Record statistics */
SpinLockAcquire(&hnswshared->mutex);
hnswshared->nparticipantsdone++;
hnswshared->reltuples += reltuples;
SpinLockRelease(&hnswshared->mutex);
/* Log statistics */
if (progress)
ereport(DEBUG1, (errmsg("leader processed " INT64_FORMAT " tuples", (int64) reltuples)));
else
ereport(DEBUG1, (errmsg("worker processed " INT64_FORMAT " tuples", (int64) reltuples)));
/* Notify leader */
ConditionVariableSignal(&hnswshared->workersdonecv);
FreeBuildState(&buildstate);
}
/*
* Perform work within a launched parallel process
*/
void
HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc)
{
char *sharedquery;
HnswSpool *hnswspool;
HnswShared *hnswshared;
Relation heapRel;
Relation indexRel;
LOCKMODE heapLockmode;
LOCKMODE indexLockmode;
/* Set debug_query_string for individual workers first */
sharedquery = shm_toc_lookup(toc, PARALLEL_KEY_QUERY_TEXT, true);
debug_query_string = sharedquery;
/* Report the query string from leader */
pgstat_report_activity(STATE_RUNNING, debug_query_string);
/* Look up shared state */
hnswshared = shm_toc_lookup(toc, PARALLEL_KEY_HNSW_SHARED, false);
/* Open relations using lock modes known to be obtained by index.c */
if (!hnswshared->isconcurrent)
{
heapLockmode = ShareLock;
indexLockmode = AccessExclusiveLock;
}
else
{
heapLockmode = ShareUpdateExclusiveLock;
indexLockmode = RowExclusiveLock;
}
/* Open relations within worker */
#if PG_VERSION_NUM >= 120000
heapRel = table_open(hnswshared->heaprelid, heapLockmode);
#else
heapRel = heap_open(hnswshared->heaprelid, heapLockmode);
#endif
indexRel = index_open(hnswshared->indexrelid, indexLockmode);
/* Initialize worker's own spool */
hnswspool = (HnswSpool *) palloc0(sizeof(HnswSpool));
hnswspool->heap = heapRel;
hnswspool->index = indexRel;
/* Perform inserts */
HnswParallelScanAndInsert(hnswspool, hnswshared, false);
/* Close relations within worker */
index_close(indexRel, indexLockmode);
#if PG_VERSION_NUM >= 120000
table_close(heapRel, heapLockmode);
#else
heap_close(heapRel, heapLockmode);
#endif
}
/*
* End parallel build
*/
static void
HnswEndParallel(HnswLeader * hnswleader)
{
/* Shutdown worker processes */
WaitForParallelWorkersToFinish(hnswleader->pcxt);
/* Free last reference to MVCC snapshot, if one was used */
if (IsMVCCSnapshot(hnswleader->snapshot))
UnregisterSnapshot(hnswleader->snapshot);
DestroyParallelContext(hnswleader->pcxt);
ExitParallelMode();
}
/*
* Return size of shared memory required for parallel index build
*/
static Size
ParallelEstimateShared(Relation heap, Snapshot snapshot)
{
#if PG_VERSION_NUM >= 120000
return add_size(BUFFERALIGN(sizeof(HnswShared)), table_parallelscan_estimate(heap, snapshot));
#else
if (!IsMVCCSnapshot(snapshot))
{
Assert(snapshot == SnapshotAny);
return sizeof(HnswShared);
}
return add_size(offsetof(HnswShared, heapdesc) +
offsetof(ParallelHeapScanDescData, phs_snapshot_data),
EstimateSnapshotSpace(snapshot));
#endif
}
/*
* Within leader, participate as a parallel worker
*/
static void
HnswLeaderParticipateAsWorker(HnswBuildState * buildstate)
{
HnswLeader *hnswleader = buildstate->hnswleader;
HnswSpool *leaderworker;
/* Allocate memory and initialize private spool */
leaderworker = (HnswSpool *) palloc0(sizeof(HnswSpool));
leaderworker->heap = buildstate->heap;
leaderworker->index = buildstate->index;
/* Perform work common to all participants */
HnswParallelScanAndInsert(leaderworker, hnswleader->hnswshared, true);
}
/*
* Begin parallel build
*/
static void
HnswBeginParallel(HnswBuildState * buildstate, bool isconcurrent, int request)
{
ParallelContext *pcxt;
int scantuplesortstates;
Snapshot snapshot;
Size esthnswshared;
HnswShared *hnswshared;
HnswLeader *hnswleader = (HnswLeader *) palloc0(sizeof(HnswLeader));
bool leaderparticipates = true;
int querylen;
#ifdef DISABLE_LEADER_PARTICIPATION
leaderparticipates = false;
#endif
/* Enter parallel mode and create context */
EnterParallelMode();
Assert(request > 0);
#if PG_VERSION_NUM >= 120000
pcxt = CreateParallelContext("vector", "HnswParallelBuildMain", request);
#else
pcxt = CreateParallelContext("vector", "HnswParallelBuildMain", request, true);
#endif
scantuplesortstates = leaderparticipates ? request + 1 : request;
/* Get snapshot for table scan */
if (!isconcurrent)
snapshot = SnapshotAny;
else
snapshot = RegisterSnapshot(GetTransactionSnapshot());
/* Estimate size of workspaces */
esthnswshared = ParallelEstimateShared(buildstate->heap, snapshot);
shm_toc_estimate_chunk(&pcxt->estimator, esthnswshared);
shm_toc_estimate_keys(&pcxt->estimator, 1);
/* Finally, estimate PARALLEL_KEY_QUERY_TEXT space */
if (debug_query_string)
{
querylen = strlen(debug_query_string);
shm_toc_estimate_chunk(&pcxt->estimator, querylen + 1);
shm_toc_estimate_keys(&pcxt->estimator, 1);
}
else
querylen = 0; /* keep compiler quiet */
/* Everyone's had a chance to ask for space, so now create the DSM */
InitializeParallelDSM(pcxt);
/* If no DSM segment was available, back out (do serial build) */
if (pcxt->seg == NULL)
{
if (IsMVCCSnapshot(snapshot))
UnregisterSnapshot(snapshot);
DestroyParallelContext(pcxt);
ExitParallelMode();
return;
}
/* Store shared build state, for which we reserved space */
hnswshared = (HnswShared *) shm_toc_allocate(pcxt->toc, esthnswshared);
/* Initialize immutable state */
hnswshared->heaprelid = RelationGetRelid(buildstate->heap);
hnswshared->indexrelid = RelationGetRelid(buildstate->index);
hnswshared->isconcurrent = isconcurrent;
hnswshared->scantuplesortstates = scantuplesortstates;
ConditionVariableInit(&hnswshared->workersdonecv);
SpinLockInit(&hnswshared->mutex);
/* Initialize mutable state */
hnswshared->nparticipantsdone = 0;
hnswshared->reltuples = 0;
hnswshared->indtuples = 0;
#if PG_VERSION_NUM >= 120000
table_parallelscan_initialize(buildstate->heap,
ParallelTableScanFromHnswShared(hnswshared),
snapshot);
#else
heap_parallelscan_initialize(&hnswshared->heapdesc, buildstate->heap, snapshot);
#endif
shm_toc_insert(pcxt->toc, PARALLEL_KEY_HNSW_SHARED, hnswshared);
/* Store query string for workers */
if (debug_query_string)
{
char *sharedquery;
sharedquery = (char *) shm_toc_allocate(pcxt->toc, querylen + 1);
memcpy(sharedquery, debug_query_string, querylen + 1);
shm_toc_insert(pcxt->toc, PARALLEL_KEY_QUERY_TEXT, sharedquery);
}
/* Launch workers, saving status for leader/caller */
LaunchParallelWorkers(pcxt);
hnswleader->pcxt = pcxt;
hnswleader->nparticipanttuplesorts = pcxt->nworkers_launched;
if (leaderparticipates)
hnswleader->nparticipanttuplesorts++;
hnswleader->hnswshared = hnswshared;
hnswleader->snapshot = snapshot;
/* If no workers were successfully launched, back out (do serial build) */
if (pcxt->nworkers_launched == 0)
{
HnswEndParallel(hnswleader);
return;
}
/* Log participants */
ereport(DEBUG1, (errmsg("using %d parallel workers", pcxt->nworkers_launched)));
/* Save leader state now that it's clear build will be parallel */
buildstate->hnswleader = hnswleader;
/* Join heap scan ourselves */
if (leaderparticipates)
HnswLeaderParticipateAsWorker(buildstate);
/* Wait for all launched workers */
WaitForParallelWorkersToAttach(pcxt);
}
/*
* Build graph
*/
static void
BuildGraph(HnswBuildState * buildstate, ForkNumber forkNum)
{
int parallel_workers = 0;
UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_HNSW_PHASE_LOAD);
/* Calculate parallel workers */
if (hnsw_enable_parallel_build)
parallel_workers = plan_create_index_workers(RelationGetRelid(buildstate->heap), RelationGetRelid(buildstate->index));
/* Attempt to launch parallel worker scan when required */
if (parallel_workers > 0)
{
/* TODO Support in-memory builds */
FlushPages(buildstate);
HnswBeginParallel(buildstate, buildstate->indexInfo->ii_Concurrent, parallel_workers);
}
/* Add tuples to sort */
if (buildstate->hnswleader)
buildstate->reltuples = ParallelHeapScan(buildstate);
else
{
#if PG_VERSION_NUM >= 120000
buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, BuildCallback, (void *) buildstate, NULL);
buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, true, BuildCallback, (void *) buildstate, NULL);
#else
buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, BuildCallback, (void *) buildstate, NULL);
buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo,
true, BuildCallback, (void *) buildstate, NULL);
#endif
}
/* End parallel build */
if (buildstate->hnswleader)
HnswEndParallel(buildstate->hnswleader);
}
/*

View File

@@ -83,11 +83,32 @@ for my $i (0 .. $#operators)
push(@expected, $res);
}
# Add index
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass);");
# Build index serially
$node->safe_psql("postgres", qq(
SET max_parallel_maintenance_workers = 0;
CREATE INDEX idx ON tst USING hnsw (v $opclass);
));
# Test approximate results
my $min = $operator eq "<#>" ? 0.80 : 0.99;
test_recall($min, $operator);
$node->safe_psql("postgres", "DROP INDEX idx;");
# Build index in parallel
my ($ret, $stdout, $stderr) = $node->psql("postgres", qq(
SET client_min_messages = DEBUG;
SET min_parallel_table_scan_size = 1;
SET hnsw.enable_parallel_build = on;
CREATE INDEX idx ON tst USING hnsw (v $opclass);
));
is($ret, 0, $stderr);
like($stderr, qr/using \d+ parallel workers/);
# Test approximate results
test_recall($min, $operator);
$node->safe_psql("postgres", "DROP INDEX idx;");
}
done_testing();