diff --git a/Makefile b/Makefile index ff26f56..09908b3 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ EXTVERSION = 0.4.4 MODULE_big = vector DATA = $(wildcard sql/*--*.sql) -OBJS = src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o +OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o TESTS = $(wildcard test/sql/*.sql) REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) diff --git a/Makefile.win b/Makefile.win index 8ceb572..e69810b 100644 --- a/Makefile.win +++ b/Makefile.win @@ -1,7 +1,7 @@ EXTENSION = vector EXTVERSION = 0.4.4 -OBJS = src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj +OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged REGRESS_OPTS = --inputdir=test --load-extension=vector diff --git a/sql/vector--0.4.4--0.5.0.sql b/sql/vector--0.4.4--0.5.0.sql index 3fe3365..48572bf 100644 --- a/sql/vector--0.4.4--0.5.0.sql +++ b/sql/vector--0.4.4--0.5.0.sql @@ -18,3 +18,26 @@ CREATE AGGREGATE sum(vector) ( COMBINEFUNC = vector_add, PARALLEL = SAFE ); + +CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler; + +COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; + +CREATE OPERATOR CLASS vector_l2_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_l2_squared_distance(vector, 
vector); + +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector), + FUNCTION 2 vector_norm(vector); diff --git a/sql/vector.sql b/sql/vector.sql index 91f594c..137931f 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -227,7 +227,7 @@ CREATE OPERATOR > ( RESTRICT = scalargtsel, JOIN = scalargtjoinsel ); --- access method +-- access methods CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler AS 'MODULE_PATHNAME' LANGUAGE C; @@ -236,6 +236,13 @@ CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler; COMMENT ON ACCESS METHOD ivfflat IS 'ivfflat index access method'; +CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler; + +COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; + -- opclasses CREATE OPERATOR CLASS vector_ops @@ -267,3 +274,19 @@ CREATE OPERATOR CLASS vector_cosine_ops FUNCTION 2 vector_norm(vector), FUNCTION 3 vector_spherical_distance(vector, vector), FUNCTION 4 vector_norm(vector); + +CREATE OPERATOR CLASS vector_l2_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_l2_squared_distance(vector, vector); + +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector), + FUNCTION 2 vector_norm(vector); diff --git a/src/hnsw.c 
b/src/hnsw.c new file mode 100644 index 0000000..0ec5592 --- /dev/null +++ b/src/hnsw.c @@ -0,0 +1,210 @@ +#include "postgres.h" + +#include + +#include "access/amapi.h" +#include "commands/vacuum.h" +#include "hnsw.h" +#include "utils/guc.h" +#include "utils/selfuncs.h" + +#if PG_VERSION_NUM >= 120000 +#include "commands/progress.h" +#endif + +int hnsw_ef_search; +static relopt_kind hnsw_relopt_kind; + +/* + * Initialize index options and variables + */ +void +HnswInit(void) +{ + hnsw_relopt_kind = add_reloption_kind(); + add_int_reloption(hnsw_relopt_kind, "m", "Number of connections", + HNSW_DEFAULT_M, HNSW_MIN_M, HNSW_MAX_M +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock +#endif + ); + add_int_reloption(hnsw_relopt_kind, "ef_construction", "Size of dynamic candidate list", + HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock +#endif + ); + + DefineCustomIntVariable("hnsw.ef_search", "Sets the size of dynamic candidate list", + "Valid range is 10..1000.", &hnsw_ef_search, + HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL); +} + +/* + * Get the name of index build phase + */ +#if PG_VERSION_NUM >= 120000 +static char * +hnswbuildphasename(int64 phasenum) +{ + switch (phasenum) + { + case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE: + return "initializing"; + case PROGRESS_HNSW_PHASE_LOAD: + return "loading tuples"; + default: + return NULL; + } +} +#endif + +/* + * Estimate the cost of an index scan + */ +static void +hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, + Cost *indexStartupCost, Cost *indexTotalCost, + Selectivity *indexSelectivity, double *indexCorrelation, + double *indexPages) +{ + GenericCosts costs; +#if PG_VERSION_NUM < 120000 + List *qinfos; +#endif + + /* Never use index without order */ + if (path->indexorderbys == NULL) + { + *indexStartupCost = DBL_MAX; + *indexTotalCost = DBL_MAX; + 
*indexSelectivity = 0; + *indexCorrelation = 0; + *indexPages = 0; + return; + } + + MemSet(&costs, 0, sizeof(costs)); + +#if PG_VERSION_NUM >= 120000 + genericcostestimate(root, path, loop_count, &costs); +#else + qinfos = deconstruct_indexquals(path); + genericcostestimate(root, path, loop_count, qinfos, &costs); +#endif + + /* TODO Improve cost estimate */ + + *indexStartupCost = costs.indexStartupCost; + *indexTotalCost = costs.indexTotalCost; + *indexSelectivity = costs.indexSelectivity; + *indexCorrelation = costs.indexCorrelation; + *indexPages = costs.numIndexPages; +} + +/* + * Parse and validate the reloptions + */ +static bytea * +hnswoptions(Datum reloptions, bool validate) +{ + static const relopt_parse_elt tab[] = { + {"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)}, + {"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)}, + }; + +#if PG_VERSION_NUM >= 130000 + return (bytea *) build_reloptions(reloptions, validate, + hnsw_relopt_kind, + sizeof(HnswOptions), + tab, lengthof(tab)); +#else + relopt_value *options; + int numoptions; + HnswOptions *rdopts; + + options = parseRelOptions(reloptions, validate, hnsw_relopt_kind, &numoptions); + rdopts = allocateReloptStruct(sizeof(HnswOptions), options, numoptions); + fillRelOptions((void *) rdopts, sizeof(HnswOptions), options, numoptions, + validate, tab, lengthof(tab)); + + return (bytea *) rdopts; +#endif +} + +/* + * Validate catalog entries for the specified operator class + */ +static bool +hnswvalidate(Oid opclassoid) +{ + return true; +} + +/* + * Define index handler + * + * See https://www.postgresql.org/docs/current/index-api.html + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnswhandler); +Datum +hnswhandler(PG_FUNCTION_ARGS) +{ + IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); + + amroutine->amstrategies = 0; + amroutine->amsupport = 2; +#if PG_VERSION_NUM >= 130000 + amroutine->amoptsprocnum = 0; +#endif + amroutine->amcanorder = false; + amroutine->amcanorderbyop = true; + 
amroutine->amcanbackward = false; /* can change direction mid-scan */ + amroutine->amcanunique = false; + amroutine->amcanmulticol = false; + amroutine->amoptionalkey = true; + amroutine->amsearcharray = false; + amroutine->amsearchnulls = false; + amroutine->amstorage = false; + amroutine->amclusterable = false; + amroutine->ampredlocks = false; + amroutine->amcanparallel = false; + amroutine->amcaninclude = false; +#if PG_VERSION_NUM >= 130000 + amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */ + amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL; +#endif + amroutine->amkeytype = InvalidOid; + + /* Interface functions */ + amroutine->ambuild = hnswbuild; + amroutine->ambuildempty = hnswbuildempty; + amroutine->aminsert = hnswinsert; + amroutine->ambulkdelete = hnswbulkdelete; + amroutine->amvacuumcleanup = hnswvacuumcleanup; + amroutine->amcanreturn = NULL; /* tuple not included in heapsort */ + amroutine->amcostestimate = hnswcostestimate; + amroutine->amoptions = hnswoptions; + amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */ +#if PG_VERSION_NUM >= 120000 + amroutine->ambuildphasename = hnswbuildphasename; +#endif + amroutine->amvalidate = hnswvalidate; +#if PG_VERSION_NUM >= 140000 + amroutine->amadjustmembers = NULL; +#endif + amroutine->ambeginscan = hnswbeginscan; + amroutine->amrescan = hnswrescan; + amroutine->amgettuple = hnswgettuple; + amroutine->amgetbitmap = NULL; + amroutine->amendscan = hnswendscan; + amroutine->ammarkpos = NULL; + amroutine->amrestrpos = NULL; + + /* Interface functions to support parallel index scans */ + amroutine->amestimateparallelscan = NULL; + amroutine->aminitparallelscan = NULL; + amroutine->amparallelrescan = NULL; + + PG_RETURN_POINTER(amroutine); +} diff --git a/src/hnsw.h b/src/hnsw.h new file mode 100644 index 0000000..28b9f38 --- /dev/null +++ b/src/hnsw.h @@ -0,0 +1,271 @@ +#ifndef HNSW_H +#define HNSW_H + +#include "postgres.h" + +#include 
"access/generic_xlog.h" +#include "access/reloptions.h" +#include "nodes/execnodes.h" +#include "port.h" /* for random() */ +#include "utils/sampling.h" +#include "vector.h" + +#if PG_VERSION_NUM < 110000 +#error "Requires PostgreSQL 11+" +#endif + +#define HNSW_MAX_DIM 2000 + +/* Support functions */ +#define HNSW_DISTANCE_PROC 1 +#define HNSW_NORM_PROC 2 + +#define HNSW_VERSION 1 +#define HNSW_MAGIC_NUMBER 0xA953A953 +#define HNSW_PAGE_ID 0xFF85 + +/* Preserved page numbers */ +#define HNSW_METAPAGE_BLKNO 0 +#define HNSW_HEAD_BLKNO 1 /* first element page */ + +#define HNSW_DEFAULT_M 16 +#define HNSW_MIN_M 4 +#define HNSW_MAX_M 100 +#define HNSW_DEFAULT_EF_CONSTRUCTION 40 +#define HNSW_MIN_EF_CONSTRUCTION 10 +#define HNSW_MAX_EF_CONSTRUCTION 1000 +#define HNSW_DEFAULT_EF_SEARCH 40 +#define HNSW_MIN_EF_SEARCH 10 +#define HNSW_MAX_EF_SEARCH 1000 + +#define HNSW_HEAPTIDS 10 + +/* Build phases */ +/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ +#define PROGRESS_HNSW_PHASE_LOAD 2 + +#define HNSW_ELEMENT_TUPLE_SIZE(_dim) (offsetof(HnswElementTupleData, vec) + VECTOR_SIZE(_dim)) + +#define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page)) +#define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page)) + +#if PG_VERSION_NUM >= 150000 +#define RandomDouble() pg_prng_double(&pg_global_prng_state) +#else +#define RandomDouble() (((double) random()) / MAX_RANDOM_VALUE) +#endif + +#if PG_VERSION_NUM < 130000 +#define list_delete_last(list) list_truncate(list, list_length(list) - 1) +#define list_sort(list, cmp) list_qsort(list, cmp) +#endif + +#define GetLayerM(m, layer) (layer == 0 ? 
m * 2 : m) +#define HnswGetMl(m) (1 / log(m)) + +/* Variables */ +extern int hnsw_ef_search; + +typedef struct HnswNeighborArray HnswNeighborArray; + +typedef struct HnswElementData +{ + List *heaptids; + uint8 level; + uint8 deleted; + HnswNeighborArray *neighbors; + BlockNumber blkno; + OffsetNumber offno; + BlockNumber neighborPage; + Vector *vec; +} HnswElementData; + +typedef HnswElementData * HnswElement; + +typedef struct HnswCandidate +{ + HnswElement element; + float distance; +} HnswCandidate; + +typedef struct HnswNeighborArray +{ + int length; + HnswCandidate *items; +} HnswNeighborArray; + +typedef struct HnswUpdate +{ + HnswCandidate hc; + int level; + int index; +} HnswUpdate; + +typedef struct HnswPairingHeapNode +{ + pairingheap_node ph_node; + HnswCandidate *inner; +} HnswPairingHeapNode; + +/* HNSW index options */ +typedef struct HnswOptions +{ + int32 vl_len_; /* varlena header (do not touch directly!) */ + int m; /* number of connections */ + int efConstruction; /* size of dynamic candidate list */ +} HnswOptions; + +typedef struct HnswBuildState +{ + /* Info */ + Relation heap; + Relation index; + IndexInfo *indexInfo; + ForkNumber forkNum; + + /* Settings */ + int dimensions; + int m; + int efConstruction; + + /* Statistics */ + double indtuples; + double reltuples; + + /* Support functions */ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; + + /* Variables */ + List *elements; + HnswElement entryPoint; + double ml; + int maxLevel; + double maxInMemoryElements; + bool flushed; + Vector *normvec; + + /* Memory */ + MemoryContext tmpCtx; +} HnswBuildState; + +typedef struct HnswMetaPageData +{ + uint32 magicNumber; + uint32 version; + uint32 dimensions; + uint32 m; + uint32 efConstruction; + BlockNumber entryBlkno; + OffsetNumber entryOffno; + BlockNumber insertPage; +} HnswMetaPageData; + +typedef HnswMetaPageData * HnswMetaPage; + +typedef struct HnswPageOpaqueData +{ + BlockNumber nextblkno; + uint16 unused; + uint16 
page_id; /* for identification of HNSW indexes */ +} HnswPageOpaqueData; + +typedef HnswPageOpaqueData * HnswPageOpaque; + +typedef struct HnswElementTupleData +{ + ItemPointerData heaptids[HNSW_HEAPTIDS]; + uint8 level; + uint8 deleted; + uint16 unused; + BlockNumber neighborPage; + Vector vec; +} HnswElementTupleData; + +typedef HnswElementTupleData * HnswElementTuple; + +typedef struct HnswNeighborTupleData +{ + ItemPointerData indextid; + uint16 unused; + float distance; +} HnswNeighborTupleData; + +typedef HnswNeighborTupleData * HnswNeighborTuple; + +typedef struct HnswScanOpaqueData +{ + bool first; + Buffer buf; + List *w; + + /* Support functions */ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; +} HnswScanOpaqueData; + +typedef HnswScanOpaqueData * HnswScanOpaque; + +typedef struct HnswVacuumState +{ + Relation index; + IndexBulkDeleteResult *stats; + IndexBulkDeleteCallback callback; + void *callback_state; + int m; + int efConstruction; + HTAB *deleted; + BufferAccessStrategy bas; + FmgrInfo *procinfo; + Oid collation; + HnswNeighborTuple ntup; + Size nsize; + HnswElementData highestPoint; + MemoryContext tmpCtx; +} HnswVacuumState; + +/* Methods */ +int HnswGetM(Relation index); +int HnswGetEfConstruction(Relation index); +FmgrInfo *HnswOptionalProcInfo(Relation rel, uint16 procnum); +bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); +void HnswCommitBuffer(Buffer buf, GenericXLogState *state); +Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); +void HnswInitPage(Buffer buf, Page page); +void HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); +void HnswInit(void); +List *SearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, BlockNumber *skipPage); +HnswElement GetEntryPoint(Relation index); +HnswElement HnswInitElement(ItemPointer tid, int m, double ml, int maxLevel); +void 
HnswFreeElement(HnswElement element); +HnswElement HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, List **updates, bool vacuuming); +HnswCandidate *EntryCandidate(HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadvec); +void UpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum); +void AddNeighborsToPage(Relation index, Page page, HnswElement e, HnswNeighborTuple neighbor, Size neighborsz, int m); +void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); +void HnswInitNeighbors(HnswElement element, int m); +bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel); +void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec); + +/* Index access methods */ +IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo); +void hnswbuildempty(Relation index); +bool hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique +#if PG_VERSION_NUM >= 140000 + ,bool indexUnchanged +#endif + ,IndexInfo *indexInfo +); +IndexBulkDeleteResult *hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state); +IndexBulkDeleteResult *hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats); +IndexScanDesc hnswbeginscan(Relation index, int nkeys, int norderbys); +void hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys); +bool hnswgettuple(IndexScanDesc scan, ScanDirection dir); +void hnswendscan(IndexScanDesc scan); + +/* Ensure fits in uint8 */ +#define HnswGetMaxLevel(m) Min(((BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData))) / 
(MAXALIGN(sizeof(HnswNeighborTupleData)) + sizeof(ItemIdData)) / m) - 2, 255) + +#endif diff --git a/src/hnswbuild.c b/src/hnswbuild.c new file mode 100644 index 0000000..90ff9c3 --- /dev/null +++ b/src/hnswbuild.c @@ -0,0 +1,484 @@ +#include "postgres.h" + +#include + +#include "catalog/index.h" +#include "hnsw.h" +#include "miscadmin.h" +#include "lib/pairingheap.h" +#include "nodes/pg_list.h" +#include "storage/bufmgr.h" +#include "utils/memutils.h" + +#if PG_VERSION_NUM >= 140000 +#include "utils/backend_progress.h" +#elif PG_VERSION_NUM >= 120000 +#include "pgstat.h" +#endif + +#if PG_VERSION_NUM >= 120000 +#include "access/tableam.h" +#include "commands/progress.h" +#else +#define PROGRESS_CREATEIDX_TUPLES_DONE 0 +#endif + +#if PG_VERSION_NUM >= 130000 +#define CALLBACK_ITEM_POINTER ItemPointer tid +#else +#define CALLBACK_ITEM_POINTER HeapTuple hup +#endif + +#if PG_VERSION_NUM >= 120000 +#define UpdateProgress(index, val) pgstat_progress_update_param(index, val) +#else +#define UpdateProgress(index, val) ((void)val) +#endif + +/* + * Create the metapage + */ +static void +CreateMetaPage(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + Buffer buf; + Page page; + GenericXLogState *state; + HnswMetaPage metap; + + buf = HnswNewBuffer(index, forkNum); + HnswInitRegisterPage(index, &buf, &page, &state); + + /* Set metapage data */ + metap = HnswPageGetMeta(page); + metap->magicNumber = HNSW_MAGIC_NUMBER; + metap->version = HNSW_VERSION; + metap->dimensions = buildstate->dimensions; + metap->m = buildstate->m; + metap->efConstruction = buildstate->efConstruction; + metap->entryBlkno = InvalidBlockNumber; + metap->entryOffno = InvalidOffsetNumber; + metap->insertPage = InvalidBlockNumber; + ((PageHeader) page)->pd_lower = + ((char *) metap + sizeof(HnswMetaPageData)) - (char *) page; + + HnswCommitBuffer(buf, state); +} + +/* + * Create element pages + */ +static void 
+CreateElementPages(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + int dimensions = buildstate->dimensions; + Size elementsz; + HnswElementTuple element; + int elementsPerPage; + BlockNumber neighborPage; + BlockNumber insertPage; + Buffer buf; + Page page; + GenericXLogState *state; + ListCell *lc; + + /* Allocate once */ + elementsz = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(dimensions)); + element = palloc0(elementsz); + + /* Calculate starting neighbor page */ + elementsPerPage = (BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData))) / (elementsz + sizeof(ItemIdData)); + neighborPage = HNSW_HEAD_BLKNO + (int) ceil(list_length(buildstate->elements) / (double) elementsPerPage); + + /* Prepare first page */ + buf = HnswNewBuffer(index, forkNum); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(buf, page); + + foreach(lc, buildstate->elements) + { + HnswElement e = lfirst(lc); + + /* Calculate neighbor page */ + /* Will be rechecked later */ + e->neighborPage = neighborPage++; + + /* Set item data */ + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + if (i < list_length(e->heaptids)) + element->heaptids[i] = *((ItemPointer) list_nth(e->heaptids, i)); + else + ItemPointerSetInvalid(&element->heaptids[i]); + } + element->level = e->level; + element->deleted = 0; + element->neighborPage = e->neighborPage; + memcpy(&element->vec, e->vec, VECTOR_SIZE(dimensions)); + + /* Ensure free space */ + if (PageGetFreeSpace(page) < elementsz) + { + /* Add a new page */ + Buffer newbuf = HnswNewBuffer(index, forkNum); + + /* Update previous page */ + HnswPageGetOpaque(page)->nextblkno = BufferGetBlockNumber(newbuf); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + + /* Can take a while, so ensure we can interrupt */ + /* Needs to be called when no buffer locks are held */ + 
CHECK_FOR_INTERRUPTS(); + + /* Prepare new page */ + buf = newbuf; + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(buf, page); + } + + /* Add the item */ + e->blkno = BufferGetBlockNumber(buf); + e->offno = PageAddItem(page, (Item) element, elementsz, InvalidOffsetNumber, false, false); + if (e->offno == InvalidOffsetNumber) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + if (e == buildstate->entryPoint) + UpdateMetaPage(index, true, e, InvalidBlockNumber, forkNum); + } + + insertPage = BufferGetBlockNumber(buf); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + + UpdateMetaPage(index, false, NULL, insertPage, forkNum); +} + +/* + * Create neighbor pages + */ +static void +CreateNeighborPages(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + Size neighborsz; + HnswNeighborTuple neighbor; + ListCell *lc; + + /* Allocate once */ + neighborsz = MAXALIGN(sizeof(HnswNeighborTupleData)); + neighbor = palloc0(neighborsz); + + foreach(lc, buildstate->elements) + { + HnswElement e = lfirst(lc); + Buffer buf; + Page page; + GenericXLogState *state; + + /* Can take a while, so ensure we can interrupt */ + /* Needs to be called when no buffer locks are held */ + CHECK_FOR_INTERRUPTS(); + + buf = HnswNewBuffer(index, forkNum); + + /* Check block number */ + if (BufferGetBlockNumber(buf) != e->neighborPage) + elog(ERROR, "expected neighbor page %d, got %d", e->neighborPage, BufferGetBlockNumber(buf)); + + /* Prepare page */ + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(buf, page); + + AddNeighborsToPage(index, page, e, neighbor, neighborsz, buildstate->m); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + } +} + +/* + * Free 
elements + */ +static void +FreeElements(HnswBuildState * buildstate) +{ + ListCell *lc; + + foreach(lc, buildstate->elements) + HnswFreeElement(lfirst(lc)); + + list_free(buildstate->elements); +} + +/* + * Flush pages + */ +static void +FlushPages(HnswBuildState * buildstate) +{ + CreateMetaPage(buildstate); + CreateElementPages(buildstate); + CreateNeighborPages(buildstate); + + buildstate->flushed = true; + FreeElements(buildstate); +} + +/* + * Insert tuple + */ +static bool +InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * buildstate, HnswElement * dup) +{ + FmgrInfo *procinfo = buildstate->procinfo; + Oid collation = buildstate->collation; + HnswElement entryPoint = buildstate->entryPoint; + int efConstruction = buildstate->efConstruction; + int m = buildstate->m; + + /* Detoast once for all calls */ + Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Normalize if needed */ + if (buildstate->normprocinfo != NULL) + { + if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec)) + return false; + } + + /* Copy value to element so accessible outside of memory context */ + memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions)); + + /* Insert element in graph */ + *dup = HnswInsertElement(element, entryPoint, NULL, procinfo, collation, m, efConstruction, NULL, false); + + /* Update entry point if needed */ + if (*dup == NULL && (entryPoint == NULL || element->level > entryPoint->level)) + buildstate->entryPoint = element; + + UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples); + + return *dup == NULL; +} + +/* + * Callback for table_index_build_scan + */ +static void +BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, + bool *isnull, bool tupleIsAlive, void *state) +{ + HnswBuildState *buildstate = (HnswBuildState *) state; + MemoryContext oldCtx; + HnswElement element; + HnswElement dup = NULL; + bool inserted; + +#if 
PG_VERSION_NUM < 130000 + ItemPointer tid = &hup->t_self; +#endif + + /* Skip nulls */ + if (isnull[0]) + return; + + if (buildstate->indtuples >= buildstate->maxInMemoryElements) + { + if (!buildstate->flushed) + { + /* TODO Improve message */ + ereport(NOTICE, + (errmsg("hnsw graph no longer fits into maintenance_work_mem"), + errdetail("Building will take significantly more time."), + errhint("Increase maintenance_work_mem to speed up future builds."))); + + FlushPages(buildstate); + } + + oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); + + if (HnswInsertTuple(buildstate->index, values, isnull, tid, buildstate->heap)) + UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(buildstate->tmpCtx); + + return; + } + + /* Allocate necessary memory outside of memory context */ + element = HnswInitElement(tid, buildstate->m, buildstate->ml, buildstate->maxLevel); + element->vec = palloc(VECTOR_SIZE(buildstate->dimensions)); + + /* Use memory context since detoast can allocate */ + oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); + + /* Insert tuple */ + inserted = InsertTuple(index, values, element, buildstate, &dup); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(buildstate->tmpCtx); + + /* Add outside memory context */ + if (dup != NULL) + HnswAddHeapTid(dup, tid); + + /* Add to buildstate or free */ + if (inserted) + buildstate->elements = lappend(buildstate->elements, element); + else + HnswFreeElement(element); +} + +/* + * Get the max number of elements that fit into maintenance_work_mem + */ +static double +HnswGetMaxInMemoryElements(int m, double ml, int dimensions) +{ + Size elementSize = sizeof(HnswElementData); + double avgLevel = -log(0.5) * ml; + + elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2)); + elementSize += sizeof(ItemPointerData); + elementSize += VECTOR_SIZE(dimensions); + return 
(maintenance_work_mem * 1024L) / elementSize; +} + +/* + * Initialize the build state + */ +static void +InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum) +{ + buildstate->heap = heap; + buildstate->index = index; + buildstate->indexInfo = indexInfo; + buildstate->forkNum = forkNum; + + buildstate->m = HnswGetM(index); + buildstate->efConstruction = HnswGetEfConstruction(index); + buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + + /* Require column to have dimensions to be indexed */ + if (buildstate->dimensions < 0) + elog(ERROR, "column does not have dimensions"); + + if (buildstate->dimensions > HNSW_MAX_DIM) + elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM); + + buildstate->reltuples = 0; + buildstate->indtuples = 0; + + /* Get support functions */ + buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + buildstate->collation = index->rd_indcollation[0]; + + buildstate->elements = NIL; + buildstate->entryPoint = NULL; + buildstate->ml = HnswGetMl(buildstate->m); + buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); + buildstate->maxInMemoryElements = HnswGetMaxInMemoryElements(buildstate->m, buildstate->ml, buildstate->dimensions); + buildstate->flushed = false; + + /* Reuse for each tuple */ + buildstate->normvec = InitVector(buildstate->dimensions); + + buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw build temporary context", + ALLOCSET_DEFAULT_SIZES); +} + +/* + * Free resources + */ +static void +FreeBuildState(HnswBuildState * buildstate) +{ + MemoryContextDelete(buildstate->tmpCtx); +} + +/* + * Build graph + */ +static void +BuildGraph(HnswBuildState * buildstate, ForkNumber forkNum) +{ + UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_HNSW_PHASE_LOAD); + +#if PG_VERSION_NUM >= 120000 + 
buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, BuildCallback, (void *) buildstate, NULL); +#else + buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, BuildCallback, (void *) buildstate, NULL); +#endif +} + +/* + * Build the index + */ +static void +BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo, + HnswBuildState * buildstate, ForkNumber forkNum) +{ + InitBuildState(buildstate, heap, index, indexInfo, forkNum); + + if (buildstate->heap != NULL) + BuildGraph(buildstate, forkNum); + + if (!buildstate->flushed) + FlushPages(buildstate); + + FreeBuildState(buildstate); +} + +/* + * Build the index for a logged table + */ +IndexBuildResult * +hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo) +{ + IndexBuildResult *result; + HnswBuildState buildstate; + + BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM); + + result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult)); + result->heap_tuples = buildstate.reltuples; + result->index_tuples = buildstate.indtuples; + + return result; +} + +/* + * Build the index for an unlogged table + */ +void +hnswbuildempty(Relation index) +{ + IndexInfo *indexInfo = BuildIndexInfo(index); + HnswBuildState buildstate; + + BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM); +} diff --git a/src/hnswinsert.c b/src/hnswinsert.c new file mode 100644 index 0000000..8fa9664 --- /dev/null +++ b/src/hnswinsert.c @@ -0,0 +1,416 @@ +#include "postgres.h" + +#include + +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "storage/lmgr.h" +#include "utils/memutils.h" + +/* + * Get the insert page + */ +static BlockNumber +GetInsertPage(Relation index) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + BlockNumber insertPage; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = 
HnswPageGetMeta(page); + + insertPage = metap->insertPage; + + UnlockReleaseBuffer(buf); + + return insertPage; +} + +/* + * Check for a free offset + */ +static bool +HnswFreeOffset(Page page, OffsetNumber *freeOffno, BlockNumber *neighborPage) +{ + OffsetNumber offno; + OffsetNumber maxoffno = PageGetMaxOffsetNumber(page); + + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple item = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + + if (item->deleted) + { + *freeOffno = offno; + *neighborPage = item->neighborPage; + return true; + } + } + + return false; +} + +/* + * Add to element and neighbor pages + */ +static void +WriteNewElementPages(Relation index, HnswElement e, int m) +{ + Buffer buf; + Page page; + GenericXLogState *state; + Size esize; + HnswElementTuple etup; + BlockNumber insertPage = GetInsertPage(index); + BlockNumber originalInsertPage = insertPage; + int dimensions = e->vec->dim; + Size nsize = MAXALIGN(sizeof(HnswNeighborTupleData)); + HnswNeighborTuple ntup = palloc0(nsize); + Buffer nbuf; + Page npage; + OffsetNumber freeOffno = InvalidOffsetNumber; + BlockNumber neighborPage = InvalidBlockNumber; + + /* Get tuple size */ + esize = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(dimensions)); + + /* Prepare tuple */ + etup = palloc0(esize); + etup->heaptids[0] = *((ItemPointer) linitial(e->heaptids)); + for (int i = 1; i < HNSW_HEAPTIDS; i++) + ItemPointerSetInvalid(&etup->heaptids[i]); + etup->level = e->level; + etup->deleted = 0; + memcpy(&etup->vec, e->vec, VECTOR_SIZE(dimensions)); + + /* Find a page to insert the item */ + for (;;) + { + buf = ReadBuffer(index, insertPage); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + if (HnswFreeOffset(page, &freeOffno, &neighborPage) || PageGetFreeSpace(page) >= esize) + break; + + insertPage = HnswPageGetOpaque(page)->nextblkno; + + if 
(BlockNumberIsValid(insertPage)) + { + /* Move to next page */ + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + } + else + { + Buffer newbuf; + Page newpage; + + /* + * From ReadBufferExtended: Caller is responsible for ensuring + * that only one backend tries to extend a relation at the same + * time! + */ + LockRelationForExtension(index, ExclusiveLock); + + /* Add a new page */ + newbuf = HnswNewBuffer(index, MAIN_FORKNUM); + + /* Unlock extend relation lock as early as possible */ + UnlockRelationForExtension(index, ExclusiveLock); + + /* Init new page */ + newpage = GenericXLogRegisterBuffer(state, newbuf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(newbuf, newpage); + + /* Update insert page */ + insertPage = BufferGetBlockNumber(newbuf); + + /* Update previous buffer */ + HnswPageGetOpaque(page)->nextblkno = insertPage; + + /* Commit */ + MarkBufferDirty(newbuf); + MarkBufferDirty(buf); + GenericXLogFinish(state); + + /* Unlock previous buffer */ + UnlockReleaseBuffer(buf); + + /* Prepare new buffer */ + state = GenericXLogStart(index); + buf = newbuf; + page = GenericXLogRegisterBuffer(state, buf, 0); + break; + } + } + + if (OffsetNumberIsValid(freeOffno)) + { + /* Reuse existing page */ + nbuf = ReadBuffer(index, neighborPage); + LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE); + } + else + { + /* Add new page */ + LockRelationForExtension(index, ExclusiveLock); + nbuf = HnswNewBuffer(index, MAIN_FORKNUM); + UnlockRelationForExtension(index, ExclusiveLock); + } + + npage = GenericXLogRegisterBuffer(state, nbuf, GENERIC_XLOG_FULL_IMAGE); + + /* Overwrites existing page via InitPage */ + HnswInitPage(nbuf, npage); + + /* Update neighbors */ + AddNeighborsToPage(index, npage, e, ntup, nsize, m); + + e->blkno = BufferGetBlockNumber(buf); + e->neighborPage = BufferGetBlockNumber(nbuf); + + /* Set neighbor page for element */ + etup->neighborPage = e->neighborPage; + + /* Add to next offset */ + if (OffsetNumberIsValid(freeOffno)) + { + e->offno = freeOffno; + if 
(!PageIndexTupleOverwrite(page, freeOffno, (Item) etup, esize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + else + { + e->offno = PageAddItem(page, (Item) etup, esize, InvalidOffsetNumber, false, false); + if (e->offno == InvalidOffsetNumber) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + + /* Commit */ + MarkBufferDirty(buf); + MarkBufferDirty(nbuf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + UnlockReleaseBuffer(nbuf); + + /* Update the insert page */ + if (insertPage != originalInsertPage) + UpdateMetaPage(index, false, NULL, insertPage, MAIN_FORKNUM); +} + +/* + * Calculate offset number for update + */ +static OffsetNumber +HnswGetOffsetNumber(HnswUpdate * update, int m) +{ + return FirstOffsetNumber + (update->hc.element->level - update->level) * m + update->index; +} + +/* + * Update neighbors + */ +static void +UpdateNeighborPages(Relation index, HnswElement e, int m, List *updates) +{ + Buffer buf; + Page page; + GenericXLogState *state; + ListCell *lc; + OffsetNumber offno; + Size neighborsz = MAXALIGN(sizeof(HnswNeighborTupleData)); + HnswNeighborTuple neighbor = palloc0(neighborsz); + + /* Could update multiple at once for same element */ + /* but should only happen a low percent of time, so keep simple for now */ + foreach(lc, updates) + { + HnswUpdate *update = lfirst(lc); + + /* Register page */ + buf = ReadBuffer(index, update->hc.element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + offno = HnswGetOffsetNumber(update, m); + + /* Make robust against issues */ + if (offno <= PageGetMaxOffsetNumber(page)) + { + /* Set item data */ + ItemPointerSet(&neighbor->indextid, e->blkno, e->offno); + neighbor->distance = update->hc.distance; + + /* Update connections */ + if (!PageIndexTupleOverwrite(page, offno, (Item) neighbor, neighborsz)) + elog(ERROR, 
"failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + } + else + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); + } +} + +/* + * Add a heap tid to an existing element + */ +static bool +HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup) +{ + Buffer buf; + Page page; + GenericXLogState *state; + Size esize = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(dup->vec->dim)); + HnswElementTuple etup; + int i; + + /* Read page */ + buf = ReadBuffer(index, dup->blkno); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Find space */ + etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, dup->offno)); + for (i = 0; i < HNSW_HEAPTIDS; i++) + { + if (!ItemPointerIsValid(&etup->heaptids[i])) + break; + } + + /* Either being deleted or we lost our chance to another backend */ + if (i == 0 || i == HNSW_HEAPTIDS) + { + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + return false; + } + + /* Add heap tid */ + etup->heaptids[i] = *((ItemPointer) linitial(element->heaptids)); + + /* Update index tuple */ + if (!PageIndexTupleOverwrite(page, dup->offno, (Item) etup, esize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + + return true; +} + +/* + * Write changes to disk + */ +static void +WriteElement(Relation index, HnswElement element, int m, List *updates, HnswElement dup, HnswElement entryPoint) +{ + /* Try to add to existing page */ + if (dup != NULL) + { + if (HnswAddDuplicate(index, element, dup)) + return; + } + + /* If fails, take this path */ + WriteNewElementPages(index, element, m); + UpdateNeighborPages(index, element, m, updates); + + /* Update metapage if needed */ + if (entryPoint == NULL || element->level > entryPoint->level) + 
UpdateMetaPage(index, true, element, InvalidBlockNumber, MAIN_FORKNUM); +} + +/* + * Insert a tuple into the index + */ +bool +HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel) +{ + Datum value; + FmgrInfo *normprocinfo; + HnswElement entryPoint; + HnswElement element; + int m = HnswGetM(index); + int efConstruction = HnswGetEfConstruction(index); + double ml = HnswGetMl(m); + FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + Oid collation = index->rd_indcollation[0]; + List *updates = NIL; + HnswElement dup; + + /* Detoast once for all calls */ + value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Normalize if needed */ + normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + if (normprocinfo != NULL) + { + if (!HnswNormValue(normprocinfo, collation, &value, NULL)) + return false; + } + + /* Create an element */ + element = HnswInitElement(heap_tid, m, ml, HnswGetMaxLevel(m)); + element->vec = DatumGetVector(value); + + /* Get entry point */ + entryPoint = GetEntryPoint(index); + + /* Insert element in graph */ + dup = HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, &updates, false); + + /* Write to disk */ + WriteElement(index, element, m, updates, dup, entryPoint); + + return true; +} + +/* + * Insert a tuple into the index + */ +bool +hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, + Relation heap, IndexUniqueCheck checkUnique +#if PG_VERSION_NUM >= 140000 + ,bool indexUnchanged +#endif + ,IndexInfo *indexInfo +) +{ + MemoryContext oldCtx; + MemoryContext insertCtx; + + /* Skip nulls */ + if (isnull[0]) + return false; + + /* Create memory context */ + insertCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw insert temporary context", + ALLOCSET_DEFAULT_SIZES); + oldCtx = MemoryContextSwitchTo(insertCtx); + + /* Insert tuple */ + HnswInsertTuple(index, values, isnull, heap_tid, heap); + + /* Delete 
memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextDelete(insertCtx); + + return false; +} diff --git a/src/hnswscan.c b/src/hnswscan.c new file mode 100644 index 0000000..dc95acc --- /dev/null +++ b/src/hnswscan.c @@ -0,0 +1,191 @@ +#include "postgres.h" + +#include "access/relscan.h" +#include "hnsw.h" +#include "pgstat.h" +#include "storage/bufmgr.h" + +/* + * Algorithm 5 from paper + */ +static void +GetScanItems(IndexScanDesc scan, Datum q) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + Relation index = scan->indexRelation; + FmgrInfo *procinfo = so->procinfo; + Oid collation = so->collation; + List *ep = NIL; + List *w; + HnswElement entryPoint = GetEntryPoint(index); + + if (entryPoint == NULL) + return; + + /* TODO Use memory context */ + ep = lappend(ep, EntryCandidate(entryPoint, q, index, procinfo, collation, false)); + + for (int lc = entryPoint->level; lc >= 1; lc--) + { + w = SearchLayer(q, ep, 1, lc, index, procinfo, collation, false, NULL); + ep = w; + } + + /* TODO Return all visited elements at level 0, not just ef search */ + so->w = SearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, false, NULL); +} + +/* + * Prepare for an index scan + */ +IndexScanDesc +hnswbeginscan(Relation index, int nkeys, int norderbys) +{ + IndexScanDesc scan; + HnswScanOpaque so; + + scan = RelationGetIndexScan(index, nkeys, norderbys); + + so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData)); + so->buf = InvalidBuffer; + so->first = true; + so->w = NIL; + + /* Set support functions */ + so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + so->collation = index->rd_indcollation[0]; + + scan->opaque = so; + + return scan; +} + +/* + * Start or restart an index scan + */ +void +hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + + so->first = true; + 
list_free(so->w); + so->w = NIL; + + if (keys && scan->numberOfKeys > 0) + memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData)); + + if (orderbys && scan->numberOfOrderBys > 0) + memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData)); +} + +/* + * Fetch the next tuple in the given scan + */ +bool +hnswgettuple(IndexScanDesc scan, ScanDirection dir) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + + /* + * Index can be used to scan backward, but Postgres doesn't support + * backward scan on operators + */ + Assert(ScanDirectionIsForward(dir)); + + if (so->first) + { + Datum value; + + /* Count index scan for stats */ + pgstat_count_index_scan(scan->indexRelation); + + /* Safety check */ + if (scan->orderByData == NULL) + elog(ERROR, "cannot scan hnsw index without order"); + + /* No items will match if null */ + if (scan->orderByData->sk_flags & SK_ISNULL) + return false; + + value = scan->orderByData->sk_argument; + + /* Value should not be compressed or toasted */ + Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); + Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); + + if (so->normprocinfo != NULL) + { + /* No items will match if normalization fails */ + if (!HnswNormValue(so->normprocinfo, so->collation, &value, NULL)) + return false; + } + + GetScanItems(scan, value); + so->first = false; + + /* Clean up if we allocated a new value */ + if (value != scan->orderByData->sk_argument) + pfree(DatumGetPointer(value)); + } + + while (list_length(so->w) > 0) + { + HnswCandidate *hc = llast(so->w); + ItemPointer tid; + BlockNumber indexblkno; + + /* Move to next element if no valid heap tids */ + if (list_length(hc->element->heaptids) == 0) + { + so->w = list_delete_last(so->w); + continue; + } + + tid = llast(hc->element->heaptids); + indexblkno = hc->element->blkno; + + hc->element->heaptids = list_delete_last(hc->element->heaptids); + +#if PG_VERSION_NUM >= 120000 + scan->xs_heaptid = *tid; +#else + 
scan->xs_ctup.t_self = *tid; +#endif + + if (BufferIsValid(so->buf)) + ReleaseBuffer(so->buf); + + /* + * An index scan must maintain a pin on the index page holding the + * item last returned by amgettuple + * + * https://www.postgresql.org/docs/current/index-locking.html + */ + so->buf = ReadBuffer(scan->indexRelation, indexblkno); + + scan->xs_recheckorderby = false; + return true; + } + + return false; +} + +/* + * End a scan and release resources + */ +void +hnswendscan(IndexScanDesc scan) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + + /* Release pin */ + if (BufferIsValid(so->buf)) + ReleaseBuffer(so->buf); + + list_free(so->w); + + pfree(so); + scan->opaque = NULL; +} diff --git a/src/hnswutils.c b/src/hnswutils.c new file mode 100644 index 0000000..3f70a2f --- /dev/null +++ b/src/hnswutils.c @@ -0,0 +1,927 @@ +#include "postgres.h" + +#include <math.h> + +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "vector.h" + +/* + * Get the number of connection in the index + */ +int +HnswGetM(Relation index) +{ + HnswOptions *opts = (HnswOptions *) index->rd_options; + + if (opts) + return opts->m; + + return HNSW_DEFAULT_M; +} + +/* + * Get the size of the dynamic candidate list in the index + */ +int +HnswGetEfConstruction(Relation index) +{ + HnswOptions *opts = (HnswOptions *) index->rd_options; + + if (opts) + return opts->efConstruction; + + return HNSW_DEFAULT_EF_CONSTRUCTION; +} + +/* + * Get proc + */ +FmgrInfo * +HnswOptionalProcInfo(Relation rel, uint16 procnum) +{ + if (!OidIsValid(index_getprocid(rel, 1, procnum))) + return NULL; + + return index_getprocinfo(rel, 1, procnum); +} + +/* + * Divide by the norm + * + * Returns false if value should not be indexed + * + * The caller needs to free the pointer stored in value + * if it's different than the original value + */ +bool +HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result) +{ + double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); + + 
if (norm > 0) + { + Vector *v = DatumGetVector(*value); + + if (result == NULL) + result = InitVector(v->dim); + + for (int i = 0; i < v->dim; i++) + result->x[i] = v->x[i] / norm; + + *value = PointerGetDatum(result); + + return true; + } + + return false; +} + +/* + * New buffer + */ +Buffer +HnswNewBuffer(Relation index, ForkNumber forkNum) +{ + Buffer buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL); + + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + return buf; +} + +/* + * Init page + */ +void +HnswInitPage(Buffer buf, Page page) +{ + PageInit(page, BufferGetPageSize(buf), sizeof(HnswPageOpaqueData)); + HnswPageGetOpaque(page)->nextblkno = InvalidBlockNumber; + HnswPageGetOpaque(page)->page_id = HNSW_PAGE_ID; +} + +/* + * Init and register page + */ +void +HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state) +{ + *state = GenericXLogStart(index); + *page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(*buf, *page); +} + +/* + * Commit buffer + */ +void +HnswCommitBuffer(Buffer buf, GenericXLogState *state) +{ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); +} + +/* + * Create an element from block and offset + */ +static HnswElement +CreateElementFromBlock(BlockNumber blkno, OffsetNumber offno) +{ + HnswElement element = palloc(sizeof(HnswElementData)); + + element->blkno = blkno; + element->offno = offno; + element->neighbors = NULL; + element->vec = NULL; + return element; +} + +/* + * Get the entry point + */ +HnswElement +GetEntryPoint(Relation index) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + HnswElement entryPoint = NULL; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = HnswPageGetMeta(page); + + if (BlockNumberIsValid(metap->entryBlkno)) + entryPoint = CreateElementFromBlock(metap->entryBlkno, metap->entryOffno); + + UnlockReleaseBuffer(buf); + + return 
entryPoint; +} + +/* + * Compare candidate distances + */ +static int +CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + return 1; + + if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + return -1; + + return 0; +} + +/* + * Compare candidate distances + */ +static int +CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + return -1; + + if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + return 1; + + return 0; +} + +/* + * Create a pairing heap node for a candidate + */ +static HnswPairingHeapNode * +CreatePairingHeapNode(HnswCandidate * c) +{ + HnswPairingHeapNode *node = palloc(sizeof(HnswPairingHeapNode)); + + node->inner = c; + return node; +} + +/* + * Calculate the distance between elements + */ +static float +HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid collation) +{ + /* Look for cached distance */ + if (a->neighbors != NULL) + { + for (int i = 0; i < a->neighbors[lc].length; i++) + { + if (a->neighbors[lc].items[i].element == b) + return a->neighbors[lc].items[i].distance; + } + } + + if (b->neighbors != NULL) + { + for (int i = 0; i < b->neighbors[lc].length; i++) + { + if (b->neighbors[lc].items[i].element == a) + return b->neighbors[lc].items[i].distance; + } + } + + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec))); +} + +/* + * Check if an element is closer to q than any element from R + */ +static bool +CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid collation) +{ + ListCell *lc2; + + foreach(lc2, r) + { + 
HnswCandidate *ri = lfirst(lc2); + float distance = HnswGetDistance(e->element, ri->element, lc, procinfo, collation); + + if (distance <= e->distance) + return false; + } + + return true; +} + +/* + * Algorithm 4 from paper + */ +static List * +SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswCandidate * *pruned) +{ + List *r = NIL; + List *w = list_copy(c); + pairingheap *wd; + + if (list_length(w) < m) + return w; + + wd = pairingheap_allocate(CompareNearestCandidates, NULL); + + while (list_length(w) > 0 && list_length(r) < m) + { + /* Assumes w is already ordered desc */ + HnswCandidate *e = llast(w); + bool closer; + + w = list_delete_last(w); + + closer = CheckElementCloser(e, r, lc, procinfo, collation); + + if (closer) + r = lappend(r, e); + else + pairingheap_add(wd, &(CreatePairingHeapNode(e)->ph_node)); + } + + /* Keep pruned connections */ + while (!pairingheap_is_empty(wd) && list_length(r) < m) + r = lappend(r, ((HnswPairingHeapNode *) pairingheap_remove_first(wd))->inner); + + /* Return pruned for update connections */ + if (pruned != NULL) + { + if (!pairingheap_is_empty(wd)) + *pruned = ((HnswPairingHeapNode *) pairingheap_first(wd))->inner; + else + *pruned = linitial(w); + } + + return r; +} + +/* + * Add connections + */ +static void +AddConnections(HnswElement element, List *neighbors, int m, int lc) +{ + ListCell *lc2; + HnswNeighborArray *a = &element->neighbors[lc]; + + foreach(lc2, neighbors) + a->items[a->length++] = *((HnswCandidate *) lfirst(lc2)); +} + +/* + * Create update + */ +static HnswUpdate * +CreateUpdate(HnswCandidate * hc, int level, int index) +{ + HnswUpdate *update = palloc(sizeof(HnswUpdate)); + + update->hc = *hc; + update->level = level; + update->index = index; + return update; +} + +/* + * Compare candidate distances + */ +static int +#if PG_VERSION_NUM >= 130000 +CompareCandidateDistances(const ListCell *a, const ListCell *b) +#else +CompareCandidateDistances(const void *a, const void 
*b) +#endif +{ + HnswCandidate *hca = lfirst((ListCell *) a); + HnswCandidate *hcb = lfirst((ListCell *) b); + + if (hca->distance < hcb->distance) + return 1; + + if (hca->distance > hcb->distance) + return -1; + + return 0; +} + +/* + * Add a heap TID to an element + */ +void +HnswAddHeapTid(HnswElement element, ItemPointer heaptid) +{ + ItemPointer copy = palloc(sizeof(ItemPointerData)); + + ItemPointerCopy(heaptid, copy); + element->heaptids = lappend(element->heaptids, copy); +} + +/* + * Load an element and optionally get its distance from q + */ +void +HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec) +{ + Buffer buf; + Page page; + HnswElementTuple item; + + /* Read vector */ + buf = ReadBuffer(index, element->blkno); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + item = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno)); + + /* Load element */ + element->heaptids = NIL; + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + /* Can stop at first invalid */ + if (!ItemPointerIsValid(&item->heaptids[i])) + break; + + HnswAddHeapTid(element, &item->heaptids[i]); + } + element->level = item->level; + element->neighborPage = item->neighborPage; + element->deleted = item->deleted; + + if (loadvec) + { + element->vec = palloc(VECTOR_SIZE(item->vec.dim)); + memcpy(element->vec, &item->vec, VECTOR_SIZE(item->vec.dim)); + } + + /* Calculate distance */ + if (distance != NULL) + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&item->vec))); + + UnlockReleaseBuffer(buf); +} + +/* + * Update connections + */ +static void +UpdateConnections(HnswElement element, List *neighbors, int m, int lc, List **updates, Relation index, FmgrInfo *procinfo, Oid collation) +{ + ListCell *lc2; + + foreach(lc2, neighbors) + { + HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); + HnswNeighborArray *currentNeighbors = 
&hc->element->neighbors[lc]; + + HnswCandidate hc2; + + hc2.element = element; + hc2.distance = hc->distance; + + if (currentNeighbors->length < m) + { + currentNeighbors->items[currentNeighbors->length++] = hc2; + + /* Track updates */ + if (updates != NULL) + *updates = lappend(*updates, CreateUpdate(hc, lc, currentNeighbors->length - 1)); + } + else + { + /* Shrink connections */ + HnswCandidate *pruned = NULL; + List *c = NIL; + + /* Add and sort candidates */ + for (int i = 0; i < currentNeighbors->length; i++) + c = lappend(c, &currentNeighbors->items[i]); + c = lappend(c, &hc2); + list_sort(c, CompareCandidateDistances); + + /* Load elements on insert */ + if (index != NULL) + { + for (int i = 0; i < currentNeighbors->length; i++) + { + if (currentNeighbors->items[i].element->vec == NULL) + { + HnswLoadElement(currentNeighbors->items[i].element, NULL, NULL, index, procinfo, collation, true); + + /* Prune deleted element */ + if (currentNeighbors->items[i].element->deleted) + { + pruned = &currentNeighbors->items[i]; + break; + } + } + } + } + + if (pruned == NULL) + { + SelectNeighbors(c, m, lc, procinfo, collation, &pruned); + + /* Should not happen */ + if (pruned == NULL) + continue; + } + + /* Find and replace the pruned element */ + for (int i = 0; i < currentNeighbors->length; i++) + { + if (currentNeighbors->items[i].element == pruned->element) + { + currentNeighbors->items[i] = hc2; + + /* Track updates */ + if (updates != NULL) + *updates = lappend(*updates, CreateUpdate(hc, lc, i)); + + break; + } + } + } + } +} + +/* + * Initialize neighbors + */ +void +HnswInitNeighbors(HnswElement element, int m) +{ + int level = element->level; + + element->neighbors = palloc(sizeof(HnswNeighborArray) * (level + 1)); + + for (int lc = 0; lc <= level; lc++) + { + HnswNeighborArray *a; + int lm = GetLayerM(m, lc); + + a = &element->neighbors[lc]; + a->length = 0; + a->items = palloc(sizeof(HnswCandidate) * lm); + } +} + +/* + * Load neighbors + */ +static void 
+LoadNeighbors(HnswCandidate * c, Relation index) +{ + Buffer buf; + Page page; + OffsetNumber offno; + OffsetNumber maxoffno; + HnswNeighborTuple neighbor; + HnswNeighborArray *neighbors; + int m = HnswGetM(index); + + buf = ReadBuffer(index, c->element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + maxoffno = PageGetMaxOffsetNumber(page); + + HnswInitNeighbors(c->element, m); + + /* If not, neighbor page represents new item */ + /* Only caught if item has a different level */ + /* TODO Use versioning to fix this? */ + if (maxoffno == (c->element->level + 2) * m) + { + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElement element; + int level; + HnswCandidate *hc; + + neighbor = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno)); + + if (!ItemPointerIsValid(&neighbor->indextid)) + continue; + + element = CreateElementFromBlock(ItemPointerGetBlockNumber(&neighbor->indextid), ItemPointerGetOffsetNumber(&neighbor->indextid)); + + /* Calculate level based on offset */ + level = c->element->level - (offno - FirstOffsetNumber) / m; + if (level < 0) + level = 0; + + neighbors = &c->element->neighbors[level]; + hc = &neighbors->items[neighbors->length]; + hc->element = element; + hc->distance = neighbor->distance; + neighbors->length++; + } + } + + UnlockReleaseBuffer(buf); +} + +/* + * Allocate an element + */ +HnswElement +HnswInitElement(ItemPointer heaptid, int m, double ml, int maxLevel) +{ + HnswElement element = palloc(sizeof(HnswElementData)); + + int level = (int) (-log(RandomDouble()) * ml); + + /* Cap level */ + if (level > maxLevel) + level = maxLevel; + + element->heaptids = NIL; + HnswAddHeapTid(element, heaptid); + + element->level = level; + element->deleted = 0; + + HnswInitNeighbors(element, m); + + return element; +} + +/* + * Free an element + */ +void +HnswFreeElement(HnswElement element) +{ + list_free(element->heaptids); + for (int lc = 0; lc <= 
element->level; lc++) + pfree(element->neighbors[lc].items); + pfree(element->neighbors); + pfree(element->vec); + pfree(element); +} + +/* + * Get the distance for a candidate + */ +static float +GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) +{ + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, PointerGetDatum(hc->element->vec))); +} + +/* + * Check if a candidate is a list member by pointer + */ +static bool +HnswListMemberByPointer(List *v, HnswCandidate * e) +{ + ListCell *lc2; + + foreach(lc2, v) + { + HnswCandidate *v2 = (HnswCandidate *) lfirst(lc2); + + if (v2->element == e->element) + return true; + } + + return false; +} + +/* + * Check if a candidate is a list member by block and offset + */ +static bool +HnswListMemberByBlock(List *v, HnswCandidate * e) +{ + ListCell *lc2; + + foreach(lc2, v) + { + HnswCandidate *v2 = (HnswCandidate *) lfirst(lc2); + + /* Neighbor page not set yet */ + if (v2->element->blkno == e->element->blkno && v2->element->offno == e->element->offno) + return true; + } + + return false; +} + +/* + * Algorithm 2 from paper + */ +List * +SearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, BlockNumber *skipPage) +{ + ListCell *lc2; + + List *w = NIL; + List *v = NIL; + pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); + pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); + int wlen = 0; + + /* Add entry points to v, C, and W */ + foreach(lc2, ep) + { + HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); + + v = lappend(v, hc); + pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node)); + pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node)); + + wlen++; + } + + while (!pairingheap_is_empty(C)) + { + HnswNeighborArray *neighborhood; + HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner; + HnswCandidate *f = ((HnswPairingHeapNode *) 
pairingheap_first(W))->inner; + + if (c->distance > f->distance) + break; + + if (c->element->neighbors == NULL) + LoadNeighbors(c, index); + + /* Get the neighborhood at layer lc */ + neighborhood = &c->element->neighbors[lc]; + + for (int i = 0; i < neighborhood->length; i++) + { + HnswCandidate *e = &neighborhood->items[i]; + bool visited; + + /* TODO Use hash for v? */ + visited = index == NULL ? HnswListMemberByPointer(v, e) : HnswListMemberByBlock(v, e); + + if (!visited) + { + float eDistance; + + v = lappend(v, e); + + f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; + + if (index == NULL) + eDistance = GetCandidateDistance(e, q, procinfo, collation); + else + HnswLoadElement(e->element, &eDistance, &q, index, procinfo, collation, inserting); + + /* Skip if fully deleted */ + if (e->element->deleted) + continue; + + /* Skip for inserts if deleting */ + if (inserting && list_length(e->element->heaptids) == 0) + continue; + + /* Skip self for vacuuming update */ + if (skipPage != NULL && e->element->neighborPage == *skipPage) + continue; + + /* Stale read */ + if (e->element->level < lc) + continue; + + if (eDistance < f->distance || wlen < ef) + { + /* copy e */ + HnswCandidate *e2 = palloc(sizeof(HnswCandidate)); + + e2->element = e->element; + e2->distance = eDistance; + + pairingheap_add(C, &(CreatePairingHeapNode(e2)->ph_node)); + pairingheap_add(W, &(CreatePairingHeapNode(e2)->ph_node)); + wlen++; + + if (wlen > ef) + { + pairingheap_remove_first(W); + wlen--; + } + } + } + } + } + + /* Add each element of W to w */ + while (!pairingheap_is_empty(W)) + { + HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner; + + w = lappend(w, hc); + } + + return w; +} + +/* + * Create a candidate for the entry point + */ +HnswCandidate * +EntryCandidate(HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec) +{ + HnswCandidate *hc = palloc(sizeof(HnswCandidate)); + + hc->element = 
entryPoint; + if (index == NULL) + hc->distance = GetCandidateDistance(hc, q, procinfo, collation); + else + HnswLoadElement(hc->element, &hc->distance, &q, index, procinfo, collation, loadvec); + return hc; +} + +/* + * Find duplicate element + */ +static HnswElement +HnswFindDuplicate(HnswElement e, List *neighbors) +{ + ListCell *lc; + + foreach(lc, neighbors) + { + HnswCandidate *neighbor = lfirst(lc); + + /* Exit early since ordered by distance */ + if (vector_cmp_internal(e->vec, neighbor->element->vec) != 0) + break; + + /* Check for space */ + if (list_length(neighbor->element->heaptids) < HNSW_HEAPTIDS) + return neighbor->element; + } + + return NULL; +} + +/* + * Algorithm 1 from paper + */ +HnswElement +HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, List **updates, bool vacuuming) +{ + List *ep = NIL; + List *w; + int level = element->level; + int entryLevel; + List **newNeighbors = palloc(sizeof(List *) * (level + 1)); + Datum q = PointerGetDatum(element->vec); + HnswElement dup; + BlockNumber *skipPage = vacuuming ? 
&element->neighborPage : NULL; + + /* Get entry point and level */ + if (entryPoint != NULL) + { + ep = lappend(ep, EntryCandidate(entryPoint, q, index, procinfo, collation, true)); + entryLevel = entryPoint->level; + } + else + entryLevel = -1; + + for (int lc = entryLevel; lc >= level + 1; lc--) + { + w = SearchLayer(q, ep, 1, lc, index, procinfo, collation, true, skipPage); + ep = w; + } + + if (level > entryLevel) + level = entryLevel; + + for (int lc = level; lc >= 0; lc--) + { + int lm = GetLayerM(m, lc); + + w = SearchLayer(q, ep, efConstruction, lc, index, procinfo, collation, true, skipPage); + newNeighbors[lc] = SelectNeighbors(w, lm, lc, procinfo, collation, NULL); + ep = w; + } + + if (level >= 0 && !vacuuming) + { + dup = HnswFindDuplicate(element, newNeighbors[0]); + if (dup != NULL) + return dup; + } + + /* Update connections */ + for (int lc = level; lc >= 0; lc--) + { + int lm = GetLayerM(m, lc); + + AddConnections(element, newNeighbors[lc], lm, lc); + + if (!vacuuming) + UpdateConnections(element, newNeighbors[lc], lm, lc, updates, index, procinfo, collation); + } + + return NULL; +} + +/* + * Update the metapage + */ +void +UpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum) +{ + Buffer buf; + Page page; + GenericXLogState *state; + HnswMetaPage metap; + + buf = ReadBufferExtended(index, forkNum, HNSW_METAPAGE_BLKNO, RBM_NORMAL, NULL); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + metap = HnswPageGetMeta(page); + + if (updateEntry) + { + metap->entryBlkno = entryPoint->blkno; + metap->entryOffno = entryPoint->offno; + } + + if (BlockNumberIsValid(insertPage)) + metap->insertPage = insertPage; + + HnswCommitBuffer(buf, state); +} + +/* + * Add neighbors to page + */ +void +AddNeighborsToPage(Relation index, Page page, HnswElement e, HnswNeighborTuple neighbor, Size neighborsz, int m) +{ + for (int 
lc = e->level; lc >= 0; lc--) + { + HnswNeighborArray *neighbors = &e->neighbors[lc]; + int lm = GetLayerM(m, lc); + + for (int i = 0; i < lm; i++) + { + if (i < neighbors->length) + { + HnswCandidate *hc = &neighbors->items[i]; + + ItemPointerSet(&neighbor->indextid, hc->element->blkno, hc->element->offno); + neighbor->distance = hc->distance; + } + else + { + ItemPointerSetInvalid(&neighbor->indextid); + neighbor->distance = NAN; + } + + if (PageAddItem(page, (Item) neighbor, neighborsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + } +} diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c new file mode 100644 index 0000000..cc15db1 --- /dev/null +++ b/src/hnswvacuum.c @@ -0,0 +1,517 @@ +#include "postgres.h" + +#include "commands/vacuum.h" +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "utils/memutils.h" + +/* + * Check if deleted list contains an index tid + */ +static bool +DeletedContains(HTAB *deleted, ItemPointer indextid) +{ + bool found; + + hash_search(deleted, indextid, HASH_FIND, &found); + return found; +} + +/* + * Remove deleted heap tids + * + * OK to remove for entry point, since always considered for searches and inserts + */ +static void +RemoveHeapTids(HnswVacuumState * vacuumstate) +{ + BlockNumber blkno = HNSW_HEAD_BLKNO; + HnswElement highestPoint = &vacuumstate->highestPoint; + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + HnswElement entryPoint = GetEntryPoint(vacuumstate->index); + + /* Store separately since highestPoint.level is uint8 */ + int highestLevel = -1; + + /* Initialize highest point */ + highestPoint->blkno = InvalidBlockNumber; + highestPoint->offno = InvalidOffsetNumber; + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + GenericXLogState *state; + OffsetNumber offno; + OffsetNumber maxoffno; + bool updated = false; + + vacuum_delay_point(); + + buf = 
ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Iterate over nodes */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple item = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + int idx = 0; + bool itemUpdated = false; + + if (ItemPointerIsValid(&item->heaptids[0])) + { + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + /* Stop at first unused */ + if (!ItemPointerIsValid(&item->heaptids[i])) + break; + + if (vacuumstate->callback(&item->heaptids[i], vacuumstate->callback_state)) + itemUpdated = true; + else + { + /* Move to front of list */ + item->heaptids[idx++] = item->heaptids[i]; + } + } + + if (itemUpdated) + { + Size itemsz = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(item->vec.dim)); + + /* Mark rest as invalid */ + for (int i = idx; i < HNSW_HEAPTIDS; i++) + ItemPointerSetInvalid(&item->heaptids[i]); + + if (!PageIndexTupleOverwrite(page, offno, (Item) item, itemsz)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + updated = true; + } + } + + if (!ItemPointerIsValid(&item->heaptids[0])) + { + ItemPointerData ip; + + /* Add to deleted list */ + ItemPointerSet(&ip, blkno, offno); + + (void) hash_search(vacuumstate->deleted, &ip, HASH_ENTER, NULL); + } + else if (item->level > highestLevel && !(highestPoint->blkno == entryPoint->blkno && highestPoint->offno == entryPoint->offno)) + { + /* Keep track of highest non-entry point */ + /* TODO Keep track of closest one to entry point? 
*/ + highestPoint->blkno = blkno; + highestPoint->offno = offno; + highestLevel = item->level; + } + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + if (updated) + { + MarkBufferDirty(buf); + GenericXLogFinish(state); + } + else + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); + } +} + +/* + * Check for deleted neighbors + */ +static bool +NeedsUpdated(HnswVacuumState * vacuumstate, HnswElement element) +{ + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + Buffer buf; + Page page; + OffsetNumber offno; + OffsetNumber maxoffno; + bool needsUpdated = false; + + buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Check neighbors */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno)); + + if (!ItemPointerIsValid(&ntup->indextid)) + continue; + + /* Check if in deleted list */ + if (DeletedContains(vacuumstate->deleted, &ntup->indextid)) + { + needsUpdated = true; + break; + } + } + + UnlockReleaseBuffer(buf); + + return needsUpdated; +} + +/* + * Repair graph for a single element + */ +static void +RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element) +{ + Relation index = vacuumstate->index; + Buffer buf; + Page page; + GenericXLogState *state; + int m = vacuumstate->m; + int efConstruction = vacuumstate->efConstruction; + FmgrInfo *procinfo = vacuumstate->procinfo; + Oid collation = vacuumstate->collation; + HnswElement entryPoint; + BufferAccessStrategy bas = vacuumstate->bas; + HnswNeighborTuple ntup = vacuumstate->ntup; + Size nsize = vacuumstate->nsize; + + /* Check if any neighbors point to deleted values */ + if (!NeedsUpdated(vacuumstate, element)) + return; + + /* Refresh entry point for each element */ + 
entryPoint = GetEntryPoint(index); + + /* Special case for entry point */ + if (element->blkno == entryPoint->blkno && element->offno == entryPoint->offno) + { + if (BlockNumberIsValid(vacuumstate->highestPoint.blkno)) + { + /* Already updated */ + if (vacuumstate->highestPoint.blkno == element->blkno && vacuumstate->highestPoint.offno == element->offno) + return; + + entryPoint = &vacuumstate->highestPoint; + } + else + entryPoint = NULL; + } + + HnswInitNeighbors(element, m); + element->heaptids = NIL; + + HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, NULL, true); + + /* Write out new neighbors on page */ + buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE); + + /* Overwrites existing page via InitPage */ + HnswInitPage(buf, page); + + /* Update neighbors */ + AddNeighborsToPage(index, page, element, ntup, nsize, m); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); +} + +/* + * Repair graph entry point + */ +static void +RepairGraphEntryPoint(HnswVacuumState * vacuumstate) +{ + Relation index = vacuumstate->index; + HnswElement highestPoint = &vacuumstate->highestPoint; + HnswElement entryPoint; + MemoryContext oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); + + /* Repair graph for highest non-entry point */ + /* This may not be the highest with new inserts, but should be fine */ + if (BlockNumberIsValid(highestPoint->blkno)) + { + HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + RepairGraphElement(vacuumstate, highestPoint); + } + + entryPoint = GetEntryPoint(index); + if (entryPoint != NULL) + { + ItemPointerData epData; + + ItemPointerSet(&epData, entryPoint->blkno, entryPoint->offno); + + if (DeletedContains(vacuumstate->deleted, 
&epData)) + UpdateMetaPage(index, true, highestPoint, InvalidBlockNumber, MAIN_FORKNUM); + else + { + /* Highest point will be used to repair */ + HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + RepairGraphElement(vacuumstate, entryPoint); + } + } + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(vacuumstate->tmpCtx); +} + +/* + * Repair graph for all elements + */ +static void +RepairGraph(HnswVacuumState * vacuumstate) +{ + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + BlockNumber blkno = HNSW_HEAD_BLKNO; + + RepairGraphEntryPoint(vacuumstate); + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + OffsetNumber offno; + OffsetNumber maxoffno; + List *elements = NIL; + ListCell *lc2; + MemoryContext oldCtx; + + vacuum_delay_point(); + + oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Load items into memory to minimize locking */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple item = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + HnswElement element; + + /* Skip updating neighbors if being deleted */ + if (!ItemPointerIsValid(&item->heaptids[0])) + continue; + + /* Create an element */ + element = palloc(sizeof(HnswElementData)); + element->neighborPage = item->neighborPage; + element->level = item->level; + element->blkno = blkno; + element->offno = offno; + element->vec = palloc(VECTOR_SIZE(item->vec.dim)); + memcpy(element->vec, &item->vec, VECTOR_SIZE(item->vec.dim)); + + elements = lappend(elements, element); + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + UnlockReleaseBuffer(buf); + + /* Update neighbor pages */ + foreach(lc2, 
elements) + RepairGraphElement(vacuumstate, (HnswElement) lfirst(lc2)); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(vacuumstate->tmpCtx); + } +} + +/* + * Mark items as deleted + */ +static void +MarkDeleted(HnswVacuumState * vacuumstate) +{ + BlockNumber blkno = HNSW_HEAD_BLKNO; + BlockNumber insertPage = InvalidBlockNumber; + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + GenericXLogState *state; + OffsetNumber offno; + OffsetNumber maxoffno; + + vacuum_delay_point(); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + + /* + * ambulkdelete cannot delete entries from pages that are pinned by + * other backends + * + * https://www.postgresql.org/docs/current/index-locking.html + */ + LockBufferForCleanup(buf); + + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Update element and neighbors together */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple item = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + Size itemsz; + Buffer nbuf; + Page npage; + + if (ItemPointerIsValid(&item->heaptids[0])) + continue; + + /* Overwrite element */ + /* TODO Increment version? 
*/ + item->deleted = 1; + MemSet(&item->vec.x, 0, item->vec.dim * sizeof(float)); + + itemsz = MAXALIGN(HNSW_ELEMENT_TUPLE_SIZE(item->vec.dim)); + if (!PageIndexTupleOverwrite(page, offno, (Item) item, itemsz)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Overwrite neighbors */ + nbuf = ReadBufferExtended(index, MAIN_FORKNUM, item->neighborPage, RBM_NORMAL, bas); + LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE); + npage = GenericXLogRegisterBuffer(state, nbuf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(nbuf, npage); + + /* Commit */ + MarkBufferDirty(buf); + MarkBufferDirty(nbuf); + GenericXLogFinish(state); + UnlockReleaseBuffer(nbuf); + + /* Set to first free page */ + if (!BlockNumberIsValid(insertPage)) + insertPage = blkno; + + /* Prepare new xlog */ + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + } + + UpdateMetaPage(index, false, NULL, insertPage, MAIN_FORKNUM); +} + +/* + * Initialize the vacuum state + */ +static void +InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state) +{ + Relation index = info->index; + HASHCTL hash_ctl; + + if (stats == NULL) + stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult)); + + vacuumstate->index = index; + vacuumstate->stats = stats; + vacuumstate->callback = callback; + vacuumstate->callback_state = callback_state; + vacuumstate->m = HnswGetM(index); + vacuumstate->efConstruction = HnswGetEfConstruction(index); + vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD); + vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + vacuumstate->collation = index->rd_indcollation[0]; + vacuumstate->nsize = MAXALIGN(sizeof(HnswNeighborTupleData)); + vacuumstate->ntup = palloc0(vacuumstate->nsize); + 
vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw vacuum temporary context", + ALLOCSET_DEFAULT_SIZES); + + /* Create hash table */ + hash_ctl.keysize = sizeof(ItemPointerData); + hash_ctl.entrysize = sizeof(ItemPointerData); + hash_ctl.hcxt = CurrentMemoryContext; + vacuumstate->deleted = hash_create("hnswbulkdelete indextids", 256, &hash_ctl, HASH_ELEM | HASH_BLOBS | HASH_CONTEXT); +} + +/* + * Free resources + */ +static void +FreeVacuumState(HnswVacuumState * vacuumstate) +{ + hash_destroy(vacuumstate->deleted); + FreeAccessStrategy(vacuumstate->bas); + pfree(vacuumstate->ntup); + MemoryContextDelete(vacuumstate->tmpCtx); +} + +/* + * Bulk delete tuples from the index + */ +IndexBulkDeleteResult * +hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, + IndexBulkDeleteCallback callback, void *callback_state) +{ + HnswVacuumState vacuumstate; + + InitVacuumState(&vacuumstate, info, stats, callback, callback_state); + + /* Pass 1: Remove heap tids */ + RemoveHeapTids(&vacuumstate); + + /* Pass 2: Repair graph */ + RepairGraph(&vacuumstate); + + /* Pass 3: Mark as deleted */ + MarkDeleted(&vacuumstate); + + FreeVacuumState(&vacuumstate); + + return vacuumstate.stats; +} + +/* + * Clean up after a VACUUM operation + */ +IndexBulkDeleteResult * +hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats) +{ + Relation rel = info->index; + + if (info->analyze_only) + return stats; + + /* stats is NULL if ambulkdelete not called */ + /* OK to return NULL if index not changed */ + if (stats == NULL) + return NULL; + + stats->num_pages = RelationGetNumberOfBlocks(rel); + + return stats; +} diff --git a/src/ivfflat.h b/src/ivfflat.h index 8dba5ef..2c18fd4 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -3,10 +3,6 @@ #include "postgres.h" -#if PG_VERSION_NUM < 110000 -#error "Requires PostgreSQL 11+" -#endif - #include "access/generic_xlog.h" #include "access/parallel.h" #include "access/reloptions.h" diff --git 
a/src/vector.c b/src/vector.c index 78dc4e4..091e051 100644 --- a/src/vector.c +++ b/src/vector.c @@ -4,6 +4,7 @@ #include "catalog/pg_type.h" #include "fmgr.h" +#include "hnsw.h" #include "ivfflat.h" #include "lib/stringinfo.h" #include "libpq/pqformat.h" @@ -37,6 +38,7 @@ PG_MODULE_MAGIC; void _PG_init(void) { + HnswInit(); IvfflatInit(); } diff --git a/test/expected/hnsw_cosine.out b/test/expected/hnsw_cosine.out new file mode 100644 index 0000000..eec40d9 --- /dev/null +++ b/test/expected/hnsw_cosine.out @@ -0,0 +1,19 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_cosine_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; + val +--------- + [1,1,1] + [1,2,3] + [1,2,4] +(3 rows) + +SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector); + val +----- +(0 rows) + +DROP TABLE t; diff --git a/test/expected/hnsw_ip.out b/test/expected/hnsw_ip.out new file mode 100644 index 0000000..85a4648 --- /dev/null +++ b/test/expected/hnsw_ip.out @@ -0,0 +1,20 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_ip_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; + val +--------- + [1,2,4] + [1,2,3] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector); + val +----- +(0 rows) + +DROP TABLE t; diff --git a/test/expected/hnsw_l2.out b/test/expected/hnsw_l2.out new file mode 100644 index 0000000..4136b82 --- /dev/null +++ b/test/expected/hnsw_l2.out @@ -0,0 +1,26 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER 
BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); + val +----- +(0 rows) + +SELECT COUNT(*) FROM t; + count +------- + 5 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_options.out b/test/expected/hnsw_options.out new file mode 100644 index 0000000..be10beb --- /dev/null +++ b/test/expected/hnsw_options.out @@ -0,0 +1,25 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 3); +ERROR: value 3 out of bounds for option "m" +DETAIL: Valid values are between "4" and "100". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101); +ERROR: value 101 out of bounds for option "m" +DETAIL: Valid values are between "4" and "100". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 9); +ERROR: value 9 out of bounds for option "ef_construction" +DETAIL: Valid values are between "10" and "1000". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001); +ERROR: value 1001 out of bounds for option "ef_construction" +DETAIL: Valid values are between "10" and "1000". +SHOW hnsw.ef_search; + hnsw.ef_search +---------------- + 40 +(1 row) + +SET hnsw.ef_search = 9; +ERROR: 9 is outside the valid range for parameter "hnsw.ef_search" (10 .. 1000) +SET hnsw.ef_search = 1001; +ERROR: 1001 is outside the valid range for parameter "hnsw.ef_search" (10 .. 
1000) +DROP TABLE t; diff --git a/test/expected/hnsw_unlogged.out b/test/expected/hnsw_unlogged.out new file mode 100644 index 0000000..0063773 --- /dev/null +++ b/test/expected/hnsw_unlogged.out @@ -0,0 +1,13 @@ +SET enable_seqscan = off; +CREATE UNLOGGED TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,1,1] + [0,0,0] +(3 rows) + +DROP TABLE t; diff --git a/test/sql/hnsw_cosine.sql b/test/sql/hnsw_cosine.sql new file mode 100644 index 0000000..4398150 --- /dev/null +++ b/test/sql/hnsw_cosine.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_cosine_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; +SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector); + +DROP TABLE t; diff --git a/test/sql/hnsw_ip.sql b/test/sql/hnsw_ip.sql new file mode 100644 index 0000000..8a1d8a0 --- /dev/null +++ b/test/sql/hnsw_ip.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_ip_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; +SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector); + +DROP TABLE t; diff --git a/test/sql/hnsw_l2.sql b/test/sql/hnsw_l2.sql new file mode 100644 index 0000000..8664cb3 --- /dev/null +++ b/test/sql/hnsw_l2.sql @@ -0,0 +1,13 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <-> 
'[3,3,3]'; +SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); +SELECT COUNT(*) FROM t; + +DROP TABLE t; diff --git a/test/sql/hnsw_options.sql b/test/sql/hnsw_options.sql new file mode 100644 index 0000000..c289922 --- /dev/null +++ b/test/sql/hnsw_options.sql @@ -0,0 +1,14 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 3); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 9); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001); + +SHOW hnsw.ef_search; + +SET hnsw.ef_search = 9; +SET hnsw.ef_search = 1001; + +DROP TABLE t; diff --git a/test/sql/hnsw_unlogged.sql b/test/sql/hnsw_unlogged.sql new file mode 100644 index 0000000..2efcc95 --- /dev/null +++ b/test/sql/hnsw_unlogged.sql @@ -0,0 +1,9 @@ +SET enable_seqscan = off; + +CREATE UNLOGGED TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +DROP TABLE t; diff --git a/test/t/010_hnsw_wal.pl b/test/t/010_hnsw_wal.pl new file mode 100644 index 0000000..725ad49 --- /dev/null +++ b/test/t/010_hnsw_wal.pl @@ -0,0 +1,99 @@ +# Based on postgres/contrib/bloom/t/001_wal.pl + +# Test that generic xlog records work for hnsw index replication. +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $dim = 32; + +my $node_primary; +my $node_replica; + +# Run a few queries on both primary and replica and check that their results match. 
+sub test_index_replay +{ + my ($test_name) = @_; + + # Wait for replica to catch up + my $applname = $node_replica->name; + + my $server_version_num = $node_primary->safe_psql("postgres", "SHOW server_version_num"); + my $caughtup_query = "SELECT pg_current_wal_lsn() <= replay_lsn FROM pg_stat_replication WHERE application_name = '$applname';"; + $node_primary->poll_query_until('postgres', $caughtup_query) + or die "Timed out while waiting for replica 1 to catch up"; + + my @r = (); + for (1 .. $dim) { + push(@r, rand()); + } + my $sql = join(",", @r); + + my $queries = qq( + SET enable_seqscan = off; + SELECT * FROM tst ORDER BY v <-> '[$sql]' LIMIT 10; + ); + + # Run test queries and compare their results + my $primary_result = $node_primary->safe_psql("postgres", $queries); + my $replica_result = $node_replica->safe_psql("postgres", $queries); + + is($primary_result, $replica_result, "$test_name: query result matches"); + return; +} + +# Use ARRAY[random(), random(), random(), ...] over +# SELECT array_agg(random()) FROM generate_series(1, $dim) +# to generate different values for each row +my $array_sql = join(",", ('random()') x $dim); + +# Initialize primary node +$node_primary = get_new_node('primary'); +$node_primary->init(allows_streaming => 1); +if ($dim > 32) { + # TODO use wal_keep_segments for Postgres < 13 + $node_primary->append_conf('postgresql.conf', qq(wal_keep_size = 1GB)); +} +if ($dim > 1500) { + $node_primary->append_conf('postgresql.conf', qq(maintenance_work_mem = 128MB)); +} +$node_primary->start; +my $backup_name = 'my_backup'; + +# Take backup +$node_primary->backup($backup_name); + +# Create streaming replica linking to primary +$node_replica = get_new_node('replica'); +$node_replica->init_from_backup($node_primary, $backup_name, + has_streaming => 1); +$node_replica->start; + +# Create hnsw index on primary +$node_primary->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node_primary->safe_psql("postgres", "CREATE TABLE tst (i 
int4, v vector($dim));"); +$node_primary->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); +$node_primary->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +# Test that queries give same result +test_index_replay('initial'); + +# Run 10 cycles of table modification. Run test queries after each modification. +for my $i (1 .. 10) +{ + $node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;"); + test_index_replay("delete $i"); + $node_primary->safe_psql("postgres", "VACUUM tst;"); + test_index_replay("vacuum $i"); + my ($start, $end) = (10001 + ($i - 1) * 1000, 10000 + $i * 1000); + $node_primary->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series($start, $end) i;" + ); + test_index_replay("insert $i"); +} + +done_testing(); diff --git a/test/t/011_hnsw_vacuum.pl b/test/t/011_hnsw_vacuum.pl new file mode 100644 index 0000000..b5e49da --- /dev/null +++ b/test/t/011_hnsw_vacuum.pl @@ -0,0 +1,43 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $dim = 3; + +my @r = (); +for (1 .. 
$dim) { + my $v = int(rand(1000)) + 1; + push(@r, "i % $v"); +} +my $array_sql = join(", ", @r); + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table and index +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +# Get size +my $size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); + +# Delete all, vacuum, and insert same data +$node->safe_psql("postgres", "DELETE FROM tst;"); +$node->safe_psql("postgres", "VACUUM tst;"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); + +# Check size +my $new_size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); +is($size, $new_size, "size does not change"); + +done_testing(); diff --git a/test/t/012_hnsw_build_recall.pl b/test/t/012_hnsw_build_recall.pl new file mode 100644 index 0000000..911decf --- /dev/null +++ b/test/t/012_hnsw_build_recall.pl @@ -0,0 +1,96 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. 
$#queries) { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) { + if (exists($actual_set{$_})) { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" +); + +# Generate queries +for (1..20) { + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); + +foreach (@operators) { + my $operator = $_; + + # Get exact results + @expected = (); + foreach (@queries) { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Add index + my $opclass; + if ($operator eq "<->") { + $opclass = "vector_l2_ops"; + } elsif ($operator eq "<#>") { + $opclass = "vector_ip_ops"; + } else { + $opclass = "vector_cosine_ops"; + } + $node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass);"); + + if ($operator eq "<#>") { + test_recall(0.85, $operator); + } else { + test_recall(0.99, $operator); + } +} + +done_testing(); diff --git a/test/t/013_hnsw_insert_recall.pl b/test/t/013_hnsw_insert_recall.pl new file mode 100644 index 0000000..280d5bc --- /dev/null +++ b/test/t/013_hnsw_insert_recall.pl @@ -0,0 +1,103 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my 
@expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) { + if (exists($actual_set{$_})) { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); + +# Generate queries +for (1..20) { + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); + +foreach (@operators) { + my $operator = $_; + + # Add index + my $opclass; + if ($operator eq "<->") { + $opclass = "vector_l2_ops"; + } elsif ($operator eq "<#>") { + $opclass = "vector_ip_ops"; + } else { + $opclass = "vector_cosine_ops"; + } + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v $opclass);"); + + $node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" + ); + + # Get exact results + @expected = (); + foreach (@queries) { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit; + )); + push(@expected, $res); + } + + if ($operator eq "<#>") { + test_recall(0.85, $operator); 
+ } else { + test_recall(0.99, $operator); + } + + $node->safe_psql("postgres", "DROP INDEX idx;"); + $node->safe_psql("postgres", "TRUNCATE tst;"); +} + +done_testing(); diff --git a/test/t/014_hnsw_inserts.pl b/test/t/014_hnsw_inserts.pl new file mode 100644 index 0000000..f6428be --- /dev/null +++ b/test/t/014_hnsw_inserts.pl @@ -0,0 +1,54 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $dim = 768; + +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table and index +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (v vector($dim));"); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +$node->pgbench( + "--no-vacuum --client=5 --transactions=100", + 0, + [qr{actually processed}], + [qr{^$}], + "concurrent INSERTs", + { + "007_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;" + } +); + +sub idx_scan +{ + # Stats do not update instantaneously + # https://www.postgresql.org/docs/current/monitoring-stats.html#MONITORING-STATS-VIEWS + sleep(1); + $node->safe_psql("postgres", "SELECT idx_scan FROM pg_stat_user_indexes WHERE indexrelid = 'tst_v_idx'::regclass;"); +} + +my $expected = 5 * 100 * 10; + +my $count = $node->safe_psql("postgres", "SELECT COUNT(*) FROM tst;"); +is($count, $expected); +is(idx_scan(), 0); + +$count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = 400; + SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t; +)); +is($count, 400); +is(idx_scan(), 1); + +done_testing();