diff --git a/src/hnsw.h b/src/hnsw.h index 3e8bdc2..af39852 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -4,6 +4,7 @@ #include "postgres.h" #include "access/generic_xlog.h" +#include "access/parallel.h" #include "access/reloptions.h" #include "nodes/execnodes.h" #include "port.h" /* for random() */ @@ -132,6 +133,49 @@ typedef struct HnswOptions int efConstruction; /* size of dynamic candidate list */ } HnswOptions; +typedef struct HnswSpool +{ + Relation heap; + Relation index; +} HnswSpool; + +typedef struct HnswShared +{ + /* Immutable state */ + Oid heaprelid; + Oid indexrelid; + bool isconcurrent; + int scantuplesortstates; + + /* Worker progress */ + ConditionVariable workersdonecv; + + /* Mutex for mutable state */ + slock_t mutex; + + /* Mutable state */ + int nparticipantsdone; + double reltuples; + double indtuples; + +#if PG_VERSION_NUM < 120000 + ParallelHeapScanDescData heapdesc; /* must come last */ +#endif +} HnswShared; + +#if PG_VERSION_NUM >= 120000 +#define ParallelTableScanFromHnswShared(shared) \ + (ParallelTableScanDesc) ((char *) (shared) + BUFFERALIGN(sizeof(HnswShared))) +#endif + +typedef struct HnswLeader +{ + ParallelContext *pcxt; + int nparticipanttuplesorts; + HnswShared *hnswshared; + Snapshot snapshot; +} HnswLeader; + typedef struct HnswBuildState { /* Info */ @@ -165,6 +209,9 @@ typedef struct HnswBuildState /* Memory */ MemoryContext tmpCtx; + + /* Parallel builds */ + HnswLeader *hnswleader; } HnswBuildState; typedef struct HnswMetaPageData @@ -285,6 +332,7 @@ void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation i void HnswSetElementTuple(HnswElementTuple etup, HnswElement element); void HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); void HnswLoadNeighbors(HnswElement element, Relation index); +PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc); /* Index access methods */ IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo); diff --git a/src/hnswbuild.c b/src/hnswbuild.c index 8cf5b75..f298611 100644 --- a/src/hnswbuild.c +++ b/src/hnswbuild.c @@ -2,12 +2,14 @@ #include +#include "access/parallel.h" #include "catalog/index.h" #include "hnsw.h" #include "miscadmin.h" #include "lib/pairingheap.h" #include "nodes/pg_list.h" #include "storage/bufmgr.h" +#include "tcop/tcopprot.h" #include "utils/memutils.h" #if PG_VERSION_NUM >= 140000 @@ -35,6 +37,23 @@ #define UpdateProgress(index, val) ((void)val) #endif +#if PG_VERSION_NUM >= 140000 +#include "utils/backend_status.h" +#include "utils/wait_event.h" +#endif + +#if PG_VERSION_NUM >= 120000 +#include "access/table.h" +#include "optimizer/optimizer.h" +#else +#include "access/heapam.h" +#include "optimizer/planner.h" +#include "pgstat.h" +#endif + +#define PARALLEL_KEY_HNSW_SHARED UINT64CONST(0xA000000000000001) +#define PARALLEL_KEY_QUERY_TEXT UINT64CONST(0xA000000000000002) + /* * Create the metapage */ @@ -351,6 +370,7 @@ BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); + /* TODO Fix progress for parallel builds */ if (HnswInsertTuple(buildstate->index, values, isnull, tid, buildstate->heap)) UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples); @@ -460,21 +480,373 @@ FreeBuildState(HnswBuildState * buildstate) MemoryContextDelete(buildstate->tmpCtx); } +/* + * Within leader, wait for end of heap scan + */ +static double +ParallelHeapScan(HnswBuildState * buildstate) +{ + HnswShared *hnswshared = buildstate->hnswleader->hnswshared; + int nparticipanttuplesorts; + double reltuples; + + nparticipanttuplesorts = buildstate->hnswleader->nparticipanttuplesorts; + for (;;) + { + SpinLockAcquire(&hnswshared->mutex); + if (hnswshared->nparticipantsdone == nparticipanttuplesorts) + { + buildstate->indtuples = hnswshared->indtuples; + reltuples = hnswshared->reltuples; + SpinLockRelease(&hnswshared->mutex); + break; + } + SpinLockRelease(&hnswshared->mutex); + + ConditionVariableSleep(&hnswshared->workersdonecv, + WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN); + } + + ConditionVariableCancelSleep(); + + return reltuples; +} + +/* + * Perform a worker's portion of a parallel insert + */ +static void +HnswParallelScanAndInsert(HnswSpool * hnswspool, HnswShared * hnswshared, bool progress) +{ + HnswBuildState buildstate; +#if PG_VERSION_NUM >= 120000 + TableScanDesc scan; +#else + HeapScanDesc scan; +#endif + double reltuples; + IndexInfo *indexInfo; + + /* Join parallel scan */ + indexInfo = BuildIndexInfo(hnswspool->index); + indexInfo->ii_Concurrent = hnswshared->isconcurrent; + InitBuildState(&buildstate, hnswspool->heap, hnswspool->index, indexInfo, MAIN_FORKNUM); + /* TODO Support in-memory builds */ + buildstate.maxInMemoryElements = 0; + buildstate.flushed = true; +#if PG_VERSION_NUM >= 120000 + scan = table_beginscan_parallel(hnswspool->heap, + ParallelTableScanFromHnswShared(hnswshared)); + reltuples = table_index_build_scan(hnswspool->heap, hnswspool->index, indexInfo, + true, progress, BuildCallback, + (void *) &buildstate, scan); +#else + scan = heap_beginscan_parallel(hnswspool->heap, &hnswshared->heapdesc); + reltuples = IndexBuildHeapScan(hnswspool->heap, hnswspool->index, indexInfo, + true, BuildCallback, + (void *) &buildstate, scan); +#endif + + /* Record statistics */ + SpinLockAcquire(&hnswshared->mutex); + hnswshared->nparticipantsdone++; + hnswshared->reltuples += reltuples; + hnswshared->indtuples += buildstate.indtuples; + SpinLockRelease(&hnswshared->mutex); + + /* Log statistics */ + if (progress) + ereport(DEBUG1, (errmsg("leader processed " INT64_FORMAT " tuples", (int64) reltuples))); + else + ereport(DEBUG1, (errmsg("worker processed " INT64_FORMAT " tuples", (int64) reltuples))); + + /* Notify leader */ + ConditionVariableSignal(&hnswshared->workersdonecv); + + FreeBuildState(&buildstate); +} + +/* + * Perform work within a launched parallel process + */ +void +HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc) +{ + char *sharedquery; + HnswSpool *hnswspool; + HnswShared *hnswshared; + Relation heapRel; + Relation indexRel; + LOCKMODE heapLockmode; + LOCKMODE indexLockmode; + + /* Set debug_query_string for individual workers first */ + sharedquery = shm_toc_lookup(toc, PARALLEL_KEY_QUERY_TEXT, true); + debug_query_string = sharedquery; + + /* Report the query string from leader */ + pgstat_report_activity(STATE_RUNNING, debug_query_string); + + /* Look up shared state */ + hnswshared = shm_toc_lookup(toc, PARALLEL_KEY_HNSW_SHARED, false); + + /* Open relations using lock modes known to be obtained by index.c */ + if (!hnswshared->isconcurrent) + { + heapLockmode = ShareLock; + indexLockmode = AccessExclusiveLock; + } + else + { + heapLockmode = ShareUpdateExclusiveLock; + indexLockmode = RowExclusiveLock; + } + + /* Open relations within worker */ +#if PG_VERSION_NUM >= 120000 + heapRel = table_open(hnswshared->heaprelid, heapLockmode); +#else + heapRel = heap_open(hnswshared->heaprelid, heapLockmode); +#endif + indexRel = index_open(hnswshared->indexrelid, indexLockmode); + + /* Initialize worker's own spool */ + hnswspool = (HnswSpool *) palloc0(sizeof(HnswSpool)); + hnswspool->heap = heapRel; + hnswspool->index = indexRel; + + /* Perform inserts */ + HnswParallelScanAndInsert(hnswspool, hnswshared, false); + + /* Close relations within worker */ + index_close(indexRel, indexLockmode); +#if PG_VERSION_NUM >= 120000 + table_close(heapRel, heapLockmode); +#else + heap_close(heapRel, heapLockmode); +#endif +} + +/* + * End parallel build + */ +static void +HnswEndParallel(HnswLeader * hnswleader) +{ + /* Shutdown worker processes */ + WaitForParallelWorkersToFinish(hnswleader->pcxt); + + /* Free last reference to MVCC snapshot, if one was used */ + if (IsMVCCSnapshot(hnswleader->snapshot)) + UnregisterSnapshot(hnswleader->snapshot); + DestroyParallelContext(hnswleader->pcxt); + ExitParallelMode(); +} + +/* + * Return size of shared memory required for parallel index build + */ +static Size +ParallelEstimateShared(Relation heap, Snapshot snapshot) +{ +#if PG_VERSION_NUM >= 120000 + return add_size(BUFFERALIGN(sizeof(HnswShared)), table_parallelscan_estimate(heap, snapshot)); +#else + if (!IsMVCCSnapshot(snapshot)) + { + Assert(snapshot == SnapshotAny); + return sizeof(HnswShared); + } + + return add_size(offsetof(HnswShared, heapdesc) + + offsetof(ParallelHeapScanDescData, phs_snapshot_data), + EstimateSnapshotSpace(snapshot)); +#endif +} + +/* + * Within leader, participate as a parallel worker + */ +static void +HnswLeaderParticipateAsWorker(HnswBuildState * buildstate) +{ + HnswLeader *hnswleader = buildstate->hnswleader; + HnswSpool *leaderworker; + + /* Allocate memory and initialize private spool */ + leaderworker = (HnswSpool *) palloc0(sizeof(HnswSpool)); + leaderworker->heap = buildstate->heap; + leaderworker->index = buildstate->index; + + /* Perform work common to all participants */ + HnswParallelScanAndInsert(leaderworker, hnswleader->hnswshared, true); +} + +/* + * Begin parallel build + */ +static void +HnswBeginParallel(HnswBuildState * buildstate, bool isconcurrent, int request) +{ + ParallelContext *pcxt; + int scantuplesortstates; + Snapshot snapshot; + Size esthnswshared; + HnswShared *hnswshared; + HnswLeader *hnswleader = (HnswLeader *) palloc0(sizeof(HnswLeader)); + bool leaderparticipates = true; + int querylen; + +#ifdef DISABLE_LEADER_PARTICIPATION + leaderparticipates = false; +#endif + + /* Enter parallel mode and create context */ + EnterParallelMode(); + Assert(request > 0); +#if PG_VERSION_NUM >= 120000 + pcxt = CreateParallelContext("vector", "HnswParallelBuildMain", request); +#else + pcxt = CreateParallelContext("vector", "HnswParallelBuildMain", request, true); +#endif + + scantuplesortstates = leaderparticipates ? request + 1 : request; + + /* Get snapshot for table scan */ + if (!isconcurrent) + snapshot = SnapshotAny; + else + snapshot = RegisterSnapshot(GetTransactionSnapshot()); + + /* Estimate size of workspaces */ + esthnswshared = ParallelEstimateShared(buildstate->heap, snapshot); + shm_toc_estimate_chunk(&pcxt->estimator, esthnswshared); + shm_toc_estimate_keys(&pcxt->estimator, 1); + + /* Finally, estimate PARALLEL_KEY_QUERY_TEXT space */ + if (debug_query_string) + { + querylen = strlen(debug_query_string); + shm_toc_estimate_chunk(&pcxt->estimator, querylen + 1); + shm_toc_estimate_keys(&pcxt->estimator, 1); + } + else + querylen = 0; /* keep compiler quiet */ + + /* Everyone's had a chance to ask for space, so now create the DSM */ + InitializeParallelDSM(pcxt); + + /* If no DSM segment was available, back out (do serial build) */ + if (pcxt->seg == NULL) + { + if (IsMVCCSnapshot(snapshot)) + UnregisterSnapshot(snapshot); + DestroyParallelContext(pcxt); + ExitParallelMode(); + return; + } + + /* Store shared build state, for which we reserved space */ + hnswshared = (HnswShared *) shm_toc_allocate(pcxt->toc, esthnswshared); + /* Initialize immutable state */ + hnswshared->heaprelid = RelationGetRelid(buildstate->heap); + hnswshared->indexrelid = RelationGetRelid(buildstate->index); + hnswshared->isconcurrent = isconcurrent; + hnswshared->scantuplesortstates = scantuplesortstates; + ConditionVariableInit(&hnswshared->workersdonecv); + SpinLockInit(&hnswshared->mutex); + /* Initialize mutable state */ + hnswshared->nparticipantsdone = 0; + hnswshared->reltuples = 0; + hnswshared->indtuples = 0; +#if PG_VERSION_NUM >= 120000 + table_parallelscan_initialize(buildstate->heap, + ParallelTableScanFromHnswShared(hnswshared), + snapshot); +#else + heap_parallelscan_initialize(&hnswshared->heapdesc, buildstate->heap, snapshot); +#endif + + shm_toc_insert(pcxt->toc, PARALLEL_KEY_HNSW_SHARED, hnswshared); + + /* Store query string for workers */ + if (debug_query_string) + { + char *sharedquery; + + sharedquery = (char *) shm_toc_allocate(pcxt->toc, querylen + 1); + memcpy(sharedquery, debug_query_string, querylen + 1); + shm_toc_insert(pcxt->toc, PARALLEL_KEY_QUERY_TEXT, sharedquery); + } + + /* Launch workers, saving status for leader/caller */ + LaunchParallelWorkers(pcxt); + hnswleader->pcxt = pcxt; + hnswleader->nparticipanttuplesorts = pcxt->nworkers_launched; + if (leaderparticipates) + hnswleader->nparticipanttuplesorts++; + hnswleader->hnswshared = hnswshared; + hnswleader->snapshot = snapshot; + + /* If no workers were successfully launched, back out (do serial build) */ + if (pcxt->nworkers_launched == 0) + { + HnswEndParallel(hnswleader); + return; + } + + /* Log participants */ + ereport(DEBUG1, (errmsg("using %d parallel workers", pcxt->nworkers_launched))); + + /* Save leader state now that it's clear build will be parallel */ + buildstate->hnswleader = hnswleader; + + /* Join heap scan ourselves */ + if (leaderparticipates) + HnswLeaderParticipateAsWorker(buildstate); + + /* Wait for all launched workers */ + WaitForParallelWorkersToAttach(pcxt); +} + /* * Build graph */ static void BuildGraph(HnswBuildState * buildstate, ForkNumber forkNum) { + int parallel_workers = 0; + UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_HNSW_PHASE_LOAD); + /* Calculate parallel workers */ + parallel_workers = plan_create_index_workers(RelationGetRelid(buildstate->heap), RelationGetRelid(buildstate->index)); + + /* Attempt to launch parallel worker scan when required */ + if (parallel_workers > 0) + { + /* TODO Support in-memory builds */ + FlushPages(buildstate); + HnswBeginParallel(buildstate, buildstate->indexInfo->ii_Concurrent, parallel_workers); + } + + /* Add tuples to sort */ + if (buildstate->hnswleader) + buildstate->reltuples = ParallelHeapScan(buildstate); + else + { #if PG_VERSION_NUM >= 120000 - buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, - true, true, BuildCallback, (void *) buildstate, NULL); + buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, BuildCallback, (void *) buildstate, NULL); #else - buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, - true, BuildCallback, (void *) buildstate, NULL); + buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, BuildCallback, (void *) buildstate, NULL); #endif + } + + /* End parallel build */ + if (buildstate->hnswleader) + HnswEndParallel(buildstate->hnswleader); } /* diff --git a/test/t/012_hnsw_build_recall.pl b/test/t/012_hnsw_build_recall.pl index e9074c6..e2a8653 100644 --- a/test/t/012_hnsw_build_recall.pl +++ b/test/t/012_hnsw_build_recall.pl @@ -83,11 +83,31 @@ for my $i (0 .. $#operators) push(@expected, $res); } - # Add index - $node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass);"); + # Build index serially + $node->safe_psql("postgres", qq( + SET max_parallel_maintenance_workers = 0; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + # Test approximate results my $min = $operator eq "<#>" ? 0.80 : 0.99; test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel + my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET client_min_messages = DEBUG; + SET min_parallel_table_scan_size = 1; + CREATE INDEX idx ON tst USING hnsw (v $opclass); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + + # Test approximate results + test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); } done_testing();