Files
pgvector/src/hnsw.c
2024-09-28 15:24:21 -07:00

324 lines
9.8 KiB
C

#include "postgres.h"
#include <float.h>
#include <limits.h>
#include <math.h>
#include "access/amapi.h"
#include "access/reloptions.h"
#include "commands/progress.h"
#include "commands/vacuum.h"
#include "hnsw.h"
#include "miscadmin.h"
#include "utils/float.h"
#include "utils/guc.h"
#include "utils/selfuncs.h"
#include "utils/spccache.h"
/* When PG_VERSION_NUM is below 150000, map MarkGUCPrefixReserved onto EmitWarningsOnPlaceholders */
#if PG_VERSION_NUM < 150000
#define MarkGUCPrefixReserved(x) EmitWarningsOnPlaceholders(x)
#endif
/* Backing variable for the "hnsw.ef_search" setting registered in HnswInit */
int hnsw_ef_search;
/* Backing variable for the "hnsw.ef_stream" setting registered in HnswInit */
int hnsw_ef_stream;
/* Backing variable for the "hnsw.streaming" setting registered in HnswInit */
bool hnsw_streaming;
/* LWLock tranche ID, assigned in HnswInitLockTranche */
int hnsw_lock_tranche_id;
/* Reloption kind for HNSW indexes, assigned in HnswInit and used by hnswoptions */
static relopt_kind hnsw_relopt_kind;
/*
 * Set up the LWLock tranche used by HNSW.
 *
 * The tranche ID is kept in shared memory, so only the first backend to
 * reach this point has to allocate a new ID; later backends read the
 * stored value.
 *
 * The shared struct holds a single int, which is small enough to come out
 * of the "slop" PostgreSQL reserves for small allocations. If it ever
 * grows, switch to a shmem_request_hook with RequestAddinShmemSpace() to
 * reserve the space up front.
 */
void
HnswInitLockTranche(void)
{
	bool		alreadyInitialized;
	int		   *sharedIds;

	LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);

	sharedIds = ShmemInitStruct("hnsw LWLock ids", sizeof(int), &alreadyInitialized);

	/* Allocate a new tranche ID only if the struct was not found */
	if (!alreadyInitialized)
		*sharedIds = LWLockNewTrancheId();

	hnsw_lock_tranche_id = *sharedIds;

	LWLockRelease(AddinShmemInitLock);

	/* The tranche name has to be registered separately in each backend */
	LWLockRegisterTranche(hnsw_lock_tranche_id, "HnswBuild");
}
/*
 * Register the HNSW reloptions and configuration variables
 *
 * Also sets up the lock tranche, unless shared_preload_libraries is still
 * being processed.
 */
void
HnswInit(void)
{
	if (!process_shared_preload_libraries_in_progress)
		HnswInitLockTranche();

	/* Index reloptions: "m" and "ef_construction" */
	hnsw_relopt_kind = add_reloption_kind();

	add_int_reloption(hnsw_relopt_kind,
					  "m",
					  "Max number of connections",
					  HNSW_DEFAULT_M,
					  HNSW_MIN_M,
					  HNSW_MAX_M,
					  AccessExclusiveLock);

	add_int_reloption(hnsw_relopt_kind,
					  "ef_construction",
					  "Size of the dynamic candidate list for construction",
					  HNSW_DEFAULT_EF_CONSTRUCTION,
					  HNSW_MIN_EF_CONSTRUCTION,
					  HNSW_MAX_EF_CONSTRUCTION,
					  AccessExclusiveLock);

	/* User-settable search parameters */
	DefineCustomIntVariable("hnsw.ef_search",
							"Sets the size of the dynamic candidate list for search",
							"Valid range is 1..1000.",
							&hnsw_ef_search,
							HNSW_DEFAULT_EF_SEARCH,
							HNSW_MIN_EF_SEARCH,
							HNSW_MAX_EF_SEARCH,
							PGC_USERSET, 0,
							NULL, NULL, NULL);

	/* TODO Figure out name */
	DefineCustomBoolVariable("hnsw.streaming",
							 "Use streaming mode",
							 NULL,
							 &hnsw_streaming,
							 HNSW_DEFAULT_STREAMING,
							 PGC_USERSET, 0,
							 NULL, NULL, NULL);

	/* TODO Figure out name */
	/* TODO Use same value as ivfflat.max_probes for "all" */
	DefineCustomIntVariable("hnsw.ef_stream",
							"Sets the max number of additional candidates to visit for streaming search",
							"-1 means all",
							&hnsw_ef_stream,
							HNSW_DEFAULT_EF_STREAM,
							HNSW_MIN_EF_STREAM,
							HNSW_MAX_EF_STREAM,
							PGC_USERSET, 0,
							NULL, NULL, NULL);

	/* Reserve the "hnsw" prefix once all variables are defined */
	MarkGUCPrefixReserved("hnsw");
}
/*
 * Map an index build phase number to its display name
 *
 * Returns NULL for phase numbers this access method does not name.
 */
static char *
hnswbuildphasename(int64 phasenum)
{
	if (phasenum == PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE)
		return "initializing";

	if (phasenum == PROGRESS_HNSW_PHASE_LOAD)
		return "loading tuples";

	return NULL;
}
/*
 * Estimate ef needed for iterative scans
 *
 * Returns 0 when the query has no limit (root->limit_tuples < 0), since
 * there is nothing to base an estimate on. Otherwise returns the limit
 * divided by the combined selectivity of the restriction clauses (floored
 * at 0.00001), truncated to an int and clamped to INT_MAX.
 */
static int
EstimateEf(PlannerInfo *root, IndexPath *path)
{
	double		selectivity = 1;
	double		estimate;
	ListCell   *lc;

	/* Cannot estimate without limit */
	/* limit_tuples includes offset */
	if (root->limit_tuples < 0)
		return 0;

	/* Get the selectivity of non-index conditions */
	foreach(lc, path->indexinfo->indrestrictinfo)
	{
		RestrictInfo *rinfo = lfirst(lc);

		/* Skip DEFAULT_INEQ_SEL since it may be a distance filter */
		if (rinfo->norm_selec >= 0 && rinfo->norm_selec <= 1 && rinfo->norm_selec != (Selectivity) DEFAULT_INEQ_SEL)
			selectivity *= rinfo->norm_selec;
	}

	estimate = root->limit_tuples / Max(selectivity, 0.00001);

	/*
	 * A large limit combined with a small selectivity can exceed the range
	 * of int, and converting an out-of-range double to int is undefined
	 * behavior, so clamp first.
	 */
	if (estimate >= (double) INT_MAX)
		return INT_MAX;

	return (int) estimate;
}
/*
 * Estimate the cost of an index scan
 *
 * root, path and loop_count are passed through to genericcostestimate().
 * The results are written to indexStartupCost, indexTotalCost,
 * indexSelectivity, indexCorrelation, and indexPages.
 *
 * Paths without an ORDER BY operator get infinite cost so the index is
 * never chosen for them.
 */
static void
hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count,
				 Cost *indexStartupCost, Cost *indexTotalCost,
				 Selectivity *indexSelectivity, double *indexCorrelation,
				 double *indexPages)
{
	GenericCosts costs;
	int			m;
	int			ef;
	int			entryLevel;
	double		layer0TuplesMax;
	double		layer0Selectivity;
	double		scalingFactor = 0.55;
	double		spc_seq_page_cost;
	Relation	index;

	/* Never use index without order */
	if (path->indexorderbys == NULL)
	{
		*indexStartupCost = get_float8_infinity();
		*indexTotalCost = get_float8_infinity();
		*indexSelectivity = 0;
		*indexCorrelation = 0;
		*indexPages = 0;
		return;
	}

	MemSet(&costs, 0, sizeof(costs));

	index = index_open(path->indexinfo->indexoid, NoLock);
	HnswGetMetaPageInfo(index, &m, NULL);
	index_close(index, NoLock);

	/* TODO Separate startup and total cost */
	ef = hnsw_streaming ? Max(hnsw_ef_search, EstimateEf(root, path)) : hnsw_ef_search;

	/*
	 * HNSW cost estimation follows a formula that accounts for the total
	 * number of tuples indexed combined with the parameters that most
	 * influence the duration of the index scan, namely: m - the number of
	 * tuples that are scanned in each step of the HNSW graph traversal
	 * ef_search - which influences the total number of steps taken at layer 0
	 *
	 * The source of the vector data can impact how many steps it takes to
	 * converge on the set of vectors to return to the executor. Currently, we
	 * use a hardcoded scaling factor (scalingFactor) to help influence that,
	 * but this could later become a configurable parameter based on the cost
	 * estimations.
	 *
	 * The tuple estimator formula is below:
	 *
	 * numIndexTuples = entryLevel * m + layer0TuplesMax * layer0Selectivity
	 *
	 * "entryLevel * m" represents the floor of tuples we need to scan to get
	 * to layer 0 (L0).
	 *
	 * "layer0TuplesMax" is the estimated total number of tuples we'd scan at
	 * L0 if we weren't discarding already visited tuples as part of the scan.
	 *
	 * "layer0Selectivity" estimates the percentage of tuples that are scanned
	 * at L0, accounting for previously visited tuples, multiplied by the
	 * "scalingFactor" (currently hardcoded).
	 */
	entryLevel = (int) (log(path->indexinfo->tuples + 1) * HnswGetMl(m));

	/*
	 * Multiply in double: in streaming mode ef comes from EstimateEf() and
	 * can be large enough for an int product to overflow.
	 */
	layer0TuplesMax = (double) HnswGetLayerM(m, 0) * ef;

	layer0Selectivity = (scalingFactor * log(path->indexinfo->tuples + 1)) /
		(log(m) * (1 + log(ef)));

	costs.numIndexTuples = (entryLevel * m) +
		(layer0TuplesMax * layer0Selectivity);

	genericcostestimate(root, path, loop_count, &costs);

	get_tablespace_page_costs(path->indexinfo->reltablespace, NULL, &spc_seq_page_cost);

	/* Adjust cost if needed since TOAST not included in seq scan cost */
	if (costs.numIndexPages > path->indexinfo->rel->pages && costs.numIndexTuples / (path->indexinfo->tuples + 1) < 0.5)
	{
		/* Change all page cost from random to sequential */
		costs.indexTotalCost -= costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost);

		/* Remove cost of extra pages */
		costs.indexTotalCost -= (costs.numIndexPages - path->indexinfo->rel->pages) * spc_seq_page_cost;
	}

	/* Use total cost since most work happens before first tuple is returned */
	*indexStartupCost = costs.indexTotalCost;
	*indexTotalCost = costs.indexTotalCost;
	*indexSelectivity = costs.indexSelectivity;
	*indexCorrelation = costs.indexCorrelation;
	*indexPages = costs.numIndexPages;
}
/*
 * Parse and validate the HNSW reloptions
 *
 * Hands the "m" and "ef_construction" options to build_reloptions() to
 * fill an HnswOptions struct, and returns the result as a bytea pointer.
 */
static bytea *
hnswoptions(Datum reloptions, bool validate)
{
	/* Maps each option name to its field in HnswOptions */
	static const relopt_parse_elt optionTable[] = {
		{"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)},
		{"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)},
	};
	void	   *parsed;

	parsed = build_reloptions(reloptions, validate,
							  hnsw_relopt_kind,
							  sizeof(HnswOptions),
							  optionTable, lengthof(optionTable));

	return (bytea *) parsed;
}
/*
 * Validate catalog entries for the specified operator class
 *
 * No checks are performed: every operator class is reported as valid.
 */
static bool
hnswvalidate(Oid opclassoid)
{
	const bool	valid = true;

	return valid;
}
/*
 * Index access method handler for HNSW
 *
 * Allocates an IndexAmRoutine node, fills in the capability flags and
 * callbacks, and returns it.
 *
 * See https://www.postgresql.org/docs/current/index-api.html
 */
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(hnswhandler);
Datum
hnswhandler(PG_FUNCTION_ARGS)
{
	IndexAmRoutine *routine = makeNode(IndexAmRoutine);

	/* Strategy and support function counts */
	routine->amstrategies = 0;
	routine->amsupport = 3;
	routine->amoptsprocnum = 0;

	/* Capability flags: ordering by operator is the only scan feature enabled */
	routine->amcanorder = false;
	routine->amcanorderbyop = true;
	routine->amcanbackward = false;
	routine->amcanunique = false;
	routine->amcanmulticol = false;
	routine->amoptionalkey = true;
	routine->amsearcharray = false;
	routine->amsearchnulls = false;
	routine->amstorage = false;
	routine->amclusterable = false;
	routine->ampredlocks = false;
	routine->amcanparallel = false;
#if PG_VERSION_NUM >= 170000
	routine->amcanbuildparallel = true;
#endif
	routine->amcaninclude = false;
	routine->amusemaintenanceworkmem = false;
#if PG_VERSION_NUM >= 160000
	routine->amsummarizing = false;
#endif
	routine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL;
	routine->amkeytype = InvalidOid;

	/* Build, insert, and vacuum callbacks */
	routine->ambuild = hnswbuild;
	routine->ambuildempty = hnswbuildempty;
	routine->aminsert = hnswinsert;
#if PG_VERSION_NUM >= 170000
	routine->aminsertcleanup = NULL;
#endif
	routine->ambulkdelete = hnswbulkdelete;
	routine->amvacuumcleanup = hnswvacuumcleanup;

	/* Planner, options, and validation callbacks */
	routine->amcanreturn = NULL;
	routine->amcostestimate = hnswcostestimate;
	routine->amoptions = hnswoptions;
	routine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */
	routine->ambuildphasename = hnswbuildphasename;
	routine->amvalidate = hnswvalidate;
#if PG_VERSION_NUM >= 140000
	routine->amadjustmembers = NULL;
#endif

	/* Scan callbacks: bitmap scans and mark/restore are not provided */
	routine->ambeginscan = hnswbeginscan;
	routine->amrescan = hnswrescan;
	routine->amgettuple = hnswgettuple;
	routine->amgetbitmap = NULL;
	routine->amendscan = hnswendscan;
	routine->ammarkpos = NULL;
	routine->amrestrpos = NULL;

	/* Parallel index scan callbacks are not provided */
	routine->amestimateparallelscan = NULL;
	routine->aminitparallelscan = NULL;
	routine->amparallelrescan = NULL;

	PG_RETURN_POINTER(routine);
}