From 51d292c93dff82f617c7a00df62eaaa4f6d73455 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Tue, 8 Aug 2023 16:42:47 -0700 Subject: [PATCH] Added HNSW index type - #181 --- CHANGELOG.md | 1 + Makefile | 2 +- Makefile.win | 2 +- sql/vector--0.4.4--0.5.0.sql | 23 + sql/vector.sql | 25 +- src/hnsw.c | 224 +++++++ src/hnsw.h | 301 ++++++++++ src/hnswbuild.c | 506 ++++++++++++++++ src/hnswinsert.c | 491 ++++++++++++++++ src/hnswscan.c | 212 +++++++ src/hnswutils.c | 982 +++++++++++++++++++++++++++++++ src/hnswvacuum.c | 584 ++++++++++++++++++ src/ivfflat.h | 4 - src/vector.c | 2 + test/expected/hnsw_cosine.out | 26 + test/expected/hnsw_ip.out | 21 + test/expected/hnsw_l2.out | 30 + test/expected/hnsw_options.out | 25 + test/expected/hnsw_unlogged.out | 13 + test/sql/hnsw_cosine.sql | 13 + test/sql/hnsw_ip.sql | 12 + test/sql/hnsw_l2.sql | 13 + test/sql/hnsw_options.sql | 14 + test/sql/hnsw_unlogged.sql | 9 + test/t/010_hnsw_wal.pl | 99 ++++ test/t/011_hnsw_vacuum.pl | 43 ++ test/t/012_hnsw_build_recall.pl | 96 +++ test/t/013_hnsw_insert_recall.pl | 103 ++++ test/t/014_hnsw_inserts.pl | 58 ++ 29 files changed, 3927 insertions(+), 7 deletions(-) create mode 100644 src/hnsw.c create mode 100644 src/hnsw.h create mode 100644 src/hnswbuild.c create mode 100644 src/hnswinsert.c create mode 100644 src/hnswscan.c create mode 100644 src/hnswutils.c create mode 100644 src/hnswvacuum.c create mode 100644 test/expected/hnsw_cosine.out create mode 100644 test/expected/hnsw_ip.out create mode 100644 test/expected/hnsw_l2.out create mode 100644 test/expected/hnsw_options.out create mode 100644 test/expected/hnsw_unlogged.out create mode 100644 test/sql/hnsw_cosine.sql create mode 100644 test/sql/hnsw_ip.sql create mode 100644 test/sql/hnsw_l2.sql create mode 100644 test/sql/hnsw_options.sql create mode 100644 test/sql/hnsw_unlogged.sql create mode 100644 test/t/010_hnsw_wal.pl create mode 100644 test/t/011_hnsw_vacuum.pl create mode 100644 test/t/012_hnsw_build_recall.pl create mode 100644 test/t/013_hnsw_insert_recall.pl create mode 100644 test/t/014_hnsw_inserts.pl diff --git a/CHANGELOG.md b/CHANGELOG.md index 532f40b..27c4bbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.5.0 (unreleased) +- Added HNSW index type - Added support for parallel index builds - Added `l1_distance` function - Added element-wise multiplication for vectors diff --git a/Makefile b/Makefile index ff26f56..09908b3 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ EXTVERSION = 0.4.4 MODULE_big = vector DATA = $(wildcard sql/*--*.sql) -OBJS = src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o +OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o TESTS = $(wildcard test/sql/*.sql) REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) diff --git a/Makefile.win b/Makefile.win index 8ceb572..e69810b 100644 --- a/Makefile.win +++ b/Makefile.win @@ -1,7 +1,7 @@ EXTENSION = vector EXTVERSION = 0.4.4 -OBJS = src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj +OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged REGRESS_OPTS = --inputdir=test --load-extension=vector diff --git a/sql/vector--0.4.4--0.5.0.sql b/sql/vector--0.4.4--0.5.0.sql index 3fe3365..48572bf 100644 --- a/sql/vector--0.4.4--0.5.0.sql +++ b/sql/vector--0.4.4--0.5.0.sql @@ -18,3 +18,26 @@ CREATE AGGREGATE sum(vector) ( COMBINEFUNC = vector_add, PARALLEL = SAFE ); + +CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler; + +COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; + +CREATE OPERATOR CLASS vector_l2_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_l2_squared_distance(vector, vector); + +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector), + FUNCTION 2 vector_norm(vector); diff --git a/sql/vector.sql b/sql/vector.sql index 91f594c..137931f 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -227,7 +227,7 @@ CREATE OPERATOR > ( RESTRICT = scalargtsel, JOIN = scalargtjoinsel ); --- access method +-- access methods CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler AS 'MODULE_PATHNAME' LANGUAGE C; @@ -236,6 +236,13 @@ CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler; COMMENT ON ACCESS METHOD ivfflat IS 'ivfflat index access method'; +CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler; + +COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; + -- opclasses CREATE OPERATOR CLASS vector_ops @@ -267,3 +274,19 @@ CREATE OPERATOR CLASS vector_cosine_ops FUNCTION 2 vector_norm(vector), FUNCTION 3 vector_spherical_distance(vector, vector), FUNCTION 4 vector_norm(vector); + +CREATE OPERATOR CLASS vector_l2_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_l2_squared_distance(vector, vector); + +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector), + FUNCTION 2 vector_norm(vector); diff --git a/src/hnsw.c b/src/hnsw.c new file mode 100644 index 0000000..6248aa8 --- /dev/null +++ b/src/hnsw.c @@ -0,0 +1,224 @@ +#include "postgres.h" + +#include +#include + +#include "access/amapi.h" +#include "commands/vacuum.h" +#include "hnsw.h" +#include "utils/guc.h" +#include "utils/selfuncs.h" + +#if PG_VERSION_NUM >= 120000 +#include "commands/progress.h" +#endif + +int hnsw_ef_search; +static relopt_kind hnsw_relopt_kind; + +/* + * Initialize index options and variables + */ +void +HnswInit(void) +{ + hnsw_relopt_kind = add_reloption_kind(); + add_int_reloption(hnsw_relopt_kind, "m", "Max number of connections", + HNSW_DEFAULT_M, HNSW_MIN_M, HNSW_MAX_M +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock +#endif + ); + add_int_reloption(hnsw_relopt_kind, "ef_construction", "Size of the dynamic candidate list for construction", + HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock +#endif + ); + + DefineCustomIntVariable("hnsw.ef_search", "Sets the size of the dynamic candidate list for search", + "Valid range is 10..1000.", &hnsw_ef_search, + HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL); +} + +/* + * Get the name of index build phase + */ +#if PG_VERSION_NUM >= 120000 +static char * +hnswbuildphasename(int64 phasenum) +{ + switch (phasenum) + { + case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE: + return "initializing"; + case PROGRESS_HNSW_PHASE_LOAD: + return "loading tuples"; + default: + return NULL; + } +} +#endif + +/* + * Estimate the cost of an index scan + */ +static void +hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, + Cost *indexStartupCost, Cost *indexTotalCost, + Selectivity *indexSelectivity, double *indexCorrelation, + double *indexPages) +{ + GenericCosts costs; + int m; + int entryLevel; + Relation index; +#if PG_VERSION_NUM < 120000 + List *qinfos; +#endif + + /* Never use index without order */ + if (path->indexorderbys == NULL) + { + *indexStartupCost = DBL_MAX; + *indexTotalCost = DBL_MAX; + *indexSelectivity = 0; + *indexCorrelation = 0; + *indexPages = 0; + return; + } + + MemSet(&costs, 0, sizeof(costs)); + + index = index_open(path->indexinfo->indexoid, NoLock); + m = HnswGetM(index); + index_close(index, NoLock); + + /* Approximate entry level */ + entryLevel = (int) -log(1.0 / path->indexinfo->tuples) * HnswGetMl(m); + + /* TODO Improve estimate of visited tuples (currently underestimates) */ + /* Account for number of tuples (or entry level), m, and ef_search */ + costs.numIndexTuples = (entryLevel + 2) * m; + +#if PG_VERSION_NUM >= 120000 + genericcostestimate(root, path, loop_count, &costs); +#else + qinfos = deconstruct_indexquals(path); + genericcostestimate(root, path, loop_count, qinfos, &costs); +#endif + + /* Use total cost since most work happens before first tuple is returned */ + *indexStartupCost = costs.indexTotalCost; + *indexTotalCost = costs.indexTotalCost; + *indexSelectivity = costs.indexSelectivity; + *indexCorrelation = costs.indexCorrelation; + *indexPages = costs.numIndexPages; +} + +/* + * Parse and validate the reloptions + */ +static bytea * +hnswoptions(Datum reloptions, bool validate) +{ + static const relopt_parse_elt tab[] = { + {"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)}, + {"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)}, + }; + +#if PG_VERSION_NUM >= 130000 + return (bytea *) build_reloptions(reloptions, validate, + hnsw_relopt_kind, + sizeof(HnswOptions), + tab, lengthof(tab)); +#else + relopt_value *options; + int numoptions; + HnswOptions *rdopts; + + options = parseRelOptions(reloptions, validate, hnsw_relopt_kind, &numoptions); + rdopts = allocateReloptStruct(sizeof(HnswOptions), options, numoptions); + fillRelOptions((void *) rdopts, sizeof(HnswOptions), options, numoptions, + validate, tab, lengthof(tab)); + + return (bytea *) rdopts; +#endif +} + +/* + * Validate catalog entries for the specified operator class + */ +static bool +hnswvalidate(Oid opclassoid) +{ + return true; +} + +/* + * Define index handler + * + * See https://www.postgresql.org/docs/current/index-api.html + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnswhandler); +Datum +hnswhandler(PG_FUNCTION_ARGS) +{ + IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); + + amroutine->amstrategies = 0; + amroutine->amsupport = 2; +#if PG_VERSION_NUM >= 130000 + amroutine->amoptsprocnum = 0; +#endif + amroutine->amcanorder = false; + amroutine->amcanorderbyop = true; + amroutine->amcanbackward = false; /* can change direction mid-scan */ + amroutine->amcanunique = false; + amroutine->amcanmulticol = false; + amroutine->amoptionalkey = true; + amroutine->amsearcharray = false; + amroutine->amsearchnulls = false; + amroutine->amstorage = false; + amroutine->amclusterable = false; + amroutine->ampredlocks = false; + amroutine->amcanparallel = false; + amroutine->amcaninclude = false; +#if PG_VERSION_NUM >= 130000 + amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */ + amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL; +#endif + amroutine->amkeytype = InvalidOid; + + /* Interface functions */ + amroutine->ambuild = hnswbuild; + amroutine->ambuildempty = hnswbuildempty; + amroutine->aminsert = hnswinsert; + amroutine->ambulkdelete = hnswbulkdelete; + amroutine->amvacuumcleanup = hnswvacuumcleanup; + amroutine->amcanreturn = NULL; /* tuple not included in heapsort */ + amroutine->amcostestimate = hnswcostestimate; + amroutine->amoptions = hnswoptions; + amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */ +#if PG_VERSION_NUM >= 120000 + amroutine->ambuildphasename = hnswbuildphasename; +#endif + amroutine->amvalidate = hnswvalidate; +#if PG_VERSION_NUM >= 140000 + amroutine->amadjustmembers = NULL; +#endif + amroutine->ambeginscan = hnswbeginscan; + amroutine->amrescan = hnswrescan; + amroutine->amgettuple = hnswgettuple; + amroutine->amgetbitmap = NULL; + amroutine->amendscan = hnswendscan; + amroutine->ammarkpos = NULL; + amroutine->amrestrpos = NULL; + + /* Interface functions to support parallel index scans */ + amroutine->amestimateparallelscan = NULL; + amroutine->aminitparallelscan = NULL; + amroutine->amparallelrescan = NULL; + + PG_RETURN_POINTER(amroutine); +} diff --git a/src/hnsw.h b/src/hnsw.h new file mode 100644 index 0000000..56f2ccf --- /dev/null +++ b/src/hnsw.h @@ -0,0 +1,301 @@ +#ifndef HNSW_H +#define HNSW_H + +#include "postgres.h" + +#include "access/generic_xlog.h" +#include "access/reloptions.h" +#include "nodes/execnodes.h" +#include "port.h" /* for random() */ +#include "utils/sampling.h" +#include "vector.h" + +#if PG_VERSION_NUM < 110000 +#error "Requires PostgreSQL 11+" +#endif + +#define HNSW_MAX_DIM 2000 + +/* Support functions */ +#define HNSW_DISTANCE_PROC 1 +#define HNSW_NORM_PROC 2 + +#define HNSW_VERSION 1 +#define HNSW_MAGIC_NUMBER 0xA953A953 +#define HNSW_PAGE_ID 0xFF85 + +/* Preserved page numbers */ +#define HNSW_METAPAGE_BLKNO 0 +#define HNSW_HEAD_BLKNO 1 /* first element page */ + +#define HNSW_DEFAULT_M 16 +#define HNSW_MIN_M 4 +#define HNSW_MAX_M 100 +#define HNSW_DEFAULT_EF_CONSTRUCTION 40 +#define HNSW_MIN_EF_CONSTRUCTION 10 +#define HNSW_MAX_EF_CONSTRUCTION 1000 +#define HNSW_DEFAULT_EF_SEARCH 40 +#define HNSW_MIN_EF_SEARCH 10 +#define HNSW_MAX_EF_SEARCH 1000 + +#define HNSW_ELEMENT_TUPLE_TYPE 1 +#define HNSW_NEIGHBOR_TUPLE_TYPE 2 + +/* Make graph robust against non-HOT updates */ +#define HNSW_HEAPTIDS 10 + +/* Build phases */ +/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ +#define PROGRESS_HNSW_PHASE_LOAD 2 + +#define HNSW_ELEMENT_TUPLE_SIZE(_dim) MAXALIGN(offsetof(HnswElementTupleData, vec) + VECTOR_SIZE(_dim)) +#define HNSW_NEIGHBOR_TUPLE_SIZE(level, m) MAXALIGN(offsetof(HnswNeighborTupleData, neighbors) + ((level) + 2) * (m) * sizeof(HnswNeighborTupleItem)) + +#define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page)) +#define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page)) + +#if PG_VERSION_NUM >= 150000 +#define RandomDouble() pg_prng_double(&pg_global_prng_state) +#else +#define RandomDouble() (((double) random()) / MAX_RANDOM_VALUE) +#endif + +#if PG_VERSION_NUM < 130000 +#define list_delete_last(list) list_truncate(list, list_length(list) - 1) +#define list_sort(list, cmp) list_qsort(list, cmp) +#endif + +#define HnswIsElementTuple(tup) ((tup)->type == HNSW_ELEMENT_TUPLE_TYPE) +#define HnswIsNeighborTuple(tup) ((tup)->type == HNSW_NEIGHBOR_TUPLE_TYPE) + +#define HnswGetLayerM(m, layer) (layer == 0 ? m * 2 : m) +#define HnswGetMl(m) (1 / log(m)) + +/* Variables */ +extern int hnsw_ef_search; + +typedef struct HnswNeighborArray HnswNeighborArray; + +typedef struct HnswElementData +{ + List *heaptids; + uint8 level; + uint8 deleted; + HnswNeighborArray *neighbors; + BlockNumber blkno; + OffsetNumber offno; + OffsetNumber neighborOffno; + BlockNumber neighborPage; + Vector *vec; +} HnswElementData; + +typedef HnswElementData * HnswElement; + +typedef struct HnswCandidate +{ + HnswElement element; + float distance; +} HnswCandidate; + +typedef struct HnswNeighborArray +{ + int length; + HnswCandidate *items; +} HnswNeighborArray; + +typedef struct HnswUpdate +{ + HnswCandidate hc; + int level; + int index; +} HnswUpdate; + +typedef struct HnswPairingHeapNode +{ + pairingheap_node ph_node; + HnswCandidate *inner; +} HnswPairingHeapNode; + +/* HNSW index options */ +typedef struct HnswOptions +{ + int32 vl_len_; /* varlena header (do not touch directly!) */ + int m; /* number of connections */ + int efConstruction; /* size of dynamic candidate list */ +} HnswOptions; + +typedef struct HnswBuildState +{ + /* Info */ + Relation heap; + Relation index; + IndexInfo *indexInfo; + ForkNumber forkNum; + + /* Settings */ + int dimensions; + int m; + int efConstruction; + + /* Statistics */ + double indtuples; + double reltuples; + + /* Support functions */ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; + + /* Variables */ + List *elements; + HnswElement entryPoint; + double ml; + int maxLevel; + double maxInMemoryElements; + bool flushed; + Vector *normvec; + + /* Memory */ + MemoryContext tmpCtx; +} HnswBuildState; + +typedef struct HnswMetaPageData +{ + uint32 magicNumber; + uint32 version; + uint32 dimensions; + uint16 m; + uint16 efConstruction; + BlockNumber entryBlkno; + OffsetNumber entryOffno; + int16 entryLevel; + BlockNumber insertPage; +} HnswMetaPageData; + +typedef HnswMetaPageData * HnswMetaPage; + +typedef struct HnswPageOpaqueData +{ + BlockNumber nextblkno; + uint16 unused; + uint16 page_id; /* for identification of HNSW indexes */ +} HnswPageOpaqueData; + +typedef HnswPageOpaqueData * HnswPageOpaque; + +typedef struct HnswElementTupleData +{ + uint8 type; + uint8 level; + uint8 deleted; + uint8 unused; + ItemPointerData heaptids[HNSW_HEAPTIDS]; + ItemPointerData neighbortid; + uint16 unused2; + Vector vec; +} HnswElementTupleData; + +typedef HnswElementTupleData * HnswElementTuple; + +typedef struct HnswNeighborTupleItem +{ + ItemPointerData indextid; + uint16 unused; + float distance; /* improves performance of inserts */ +} HnswNeighborTupleItem; + +typedef struct HnswNeighborTupleData +{ + uint8 type; + uint8 unused; + uint16 count; + HnswNeighborTupleItem neighbors[FLEXIBLE_ARRAY_MEMBER]; +} HnswNeighborTupleData; + +typedef HnswNeighborTupleData * HnswNeighborTuple; + +typedef struct HnswScanOpaqueData +{ + bool first; + Buffer buf; + List *w; + MemoryContext tmpCtx; + + /* Support functions */ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; +} HnswScanOpaqueData; + +typedef HnswScanOpaqueData * HnswScanOpaque; + +typedef struct HnswVacuumState +{ + /* Info */ + Relation index; + IndexBulkDeleteResult *stats; + IndexBulkDeleteCallback callback; + void *callback_state; + + /* Settings */ + int m; + int efConstruction; + + /* Support functions */ + FmgrInfo *procinfo; + Oid collation; + + /* Variables */ + HTAB *deleted; + BufferAccessStrategy bas; + HnswNeighborTuple ntup; + HnswElementData highestPoint; + + /* Memory */ + MemoryContext tmpCtx; +} HnswVacuumState; + +/* Methods */ +int HnswGetM(Relation index); +int HnswGetEfConstruction(Relation index); +FmgrInfo *HnswOptionalProcInfo(Relation rel, uint16 procnum); +bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); +void HnswCommitBuffer(Buffer buf, GenericXLogState *state); +Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); +void HnswInitPage(Buffer buf, Page page); +void HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); +void HnswInit(void); +List *HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, BlockNumber *skipPage, OffsetNumber *skipOffno); +HnswElement HnswGetEntryPoint(Relation index); +HnswElement HnswInitElement(ItemPointer tid, int m, double ml, int maxLevel); +void HnswFreeElement(HnswElement element); +HnswElement HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, List **updates, bool vacuuming); +HnswCandidate *HnswEntryCandidate(HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadvec); +void HnswUpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum); +void HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m); +void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); +void HnswInitNeighbors(HnswElement element, int m); +bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel); +void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec); +void HnswSetElementTuple(HnswElementTuple etup, HnswElement element); + +/* Index access methods */ +IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo); +void hnswbuildempty(Relation index); +bool hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique +#if PG_VERSION_NUM >= 140000 + ,bool indexUnchanged +#endif + ,IndexInfo *indexInfo +); +IndexBulkDeleteResult *hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state); +IndexBulkDeleteResult *hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats); +IndexScanDesc hnswbeginscan(Relation index, int nkeys, int norderbys); +void hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys); +bool hnswgettuple(IndexScanDesc scan, ScanDirection dir); +void hnswendscan(IndexScanDesc scan); + +/* Ensure fits in uint8 */ +#define HnswGetMaxLevel(m) Min(((BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - offsetof(HnswNeighborTupleData, neighbors) - sizeof(ItemIdData)) / (sizeof(HnswNeighborTupleItem)) / m) - 2, 255) + +#endif diff --git a/src/hnswbuild.c b/src/hnswbuild.c new file mode 100644 index 0000000..2341cdb --- /dev/null +++ b/src/hnswbuild.c @@ -0,0 +1,506 @@ +#include "postgres.h" + +#include + +#include "catalog/index.h" +#include "hnsw.h" +#include "miscadmin.h" +#include "lib/pairingheap.h" +#include "nodes/pg_list.h" +#include "storage/bufmgr.h" +#include "utils/memutils.h" + +#if PG_VERSION_NUM >= 140000 +#include "utils/backend_progress.h" +#elif PG_VERSION_NUM >= 120000 +#include "pgstat.h" +#endif + +#if PG_VERSION_NUM >= 120000 +#include "access/tableam.h" +#include "commands/progress.h" +#else +#define PROGRESS_CREATEIDX_TUPLES_DONE 0 +#endif + +#if PG_VERSION_NUM >= 130000 +#define CALLBACK_ITEM_POINTER ItemPointer tid +#else +#define CALLBACK_ITEM_POINTER HeapTuple hup +#endif + +#if PG_VERSION_NUM >= 120000 +#define UpdateProgress(index, val) pgstat_progress_update_param(index, val) +#else +#define UpdateProgress(index, val) ((void)val) +#endif + +/* + * Create the metapage + */ +static void +CreateMetaPage(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + Buffer buf; + Page page; + GenericXLogState *state; + HnswMetaPage metap; + + buf = HnswNewBuffer(index, forkNum); + HnswInitRegisterPage(index, &buf, &page, &state); + + /* Set metapage data */ + metap = HnswPageGetMeta(page); + metap->magicNumber = HNSW_MAGIC_NUMBER; + metap->version = HNSW_VERSION; + metap->dimensions = buildstate->dimensions; + metap->m = buildstate->m; + metap->efConstruction = buildstate->efConstruction; + metap->entryBlkno = InvalidBlockNumber; + metap->entryOffno = InvalidOffsetNumber; + metap->insertPage = InvalidBlockNumber; + ((PageHeader) page)->pd_lower = + ((char *) metap + sizeof(HnswMetaPageData)) - (char *) page; + + HnswCommitBuffer(buf, state); +} + +/* + * Add a new page + */ +static void +HnswBuildAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum) +{ + /* Add a new page */ + Buffer newbuf = HnswNewBuffer(index, forkNum); + + /* Update previous page */ + HnswPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(newbuf); + + /* Commit */ + MarkBufferDirty(*buf); + GenericXLogFinish(*state); + UnlockReleaseBuffer(*buf); + + /* Can take a while, so ensure we can interrupt */ + /* Needs to be called when no buffer locks are held */ + LockBuffer(newbuf, BUFFER_LOCK_UNLOCK); + CHECK_FOR_INTERRUPTS(); + LockBuffer(newbuf, BUFFER_LOCK_EXCLUSIVE); + + /* Prepare new page */ + *buf = newbuf; + *state = GenericXLogStart(index); + *page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(*buf, *page); +} + +/* + * Create element pages + */ +static void +CreateElementPages(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + int dimensions = buildstate->dimensions; + Size etupSize; + Size maxSize; + HnswElementTuple etup; + HnswNeighborTuple ntup; + BlockNumber insertPage; + Buffer buf; + Page page; + GenericXLogState *state; + ListCell *lc; + + /* Calculate sizes */ + maxSize = BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)); + etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions); + + /* Allocate once */ + etup = palloc0(etupSize); + ntup = palloc0(maxSize); + + /* Prepare first page */ + buf = HnswNewBuffer(index, forkNum); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(buf, page); + + foreach(lc, buildstate->elements) + { + HnswElement element = lfirst(lc); + Size ntupSize; + Size combinedSize; + + HnswSetElementTuple(etup, element); + + /* Calculate sizes */ + ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); + combinedSize = etupSize + ntupSize + sizeof(ItemIdData); + + /* Keep element and neighbors on the same page if possible */ + if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize)) + HnswBuildAppendPage(index, &buf, &page, &state, forkNum); + + /* Calculate offsets */ + element->blkno = BufferGetBlockNumber(buf); + element->offno = OffsetNumberNext(PageGetMaxOffsetNumber(page)); + if (combinedSize <= maxSize) + { + element->neighborPage = element->blkno; + element->neighborOffno = OffsetNumberNext(element->offno); + } + else + { + element->neighborPage = element->blkno + 1; + element->neighborOffno = FirstOffsetNumber; + } + + ItemPointerSet(&etup->neighbortid, element->neighborPage, element->neighborOffno); + + /* Add element */ + if (PageAddItem(page, (Item) etup, etupSize, InvalidOffsetNumber, false, false) != element->offno) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Add new page if needed */ + if (PageGetFreeSpace(page) < ntupSize) + HnswBuildAppendPage(index, &buf, &page, &state, forkNum); + + /* Add placeholder for neighbors */ + if (PageAddItem(page, (Item) ntup, ntupSize, InvalidOffsetNumber, false, false) != element->neighborOffno) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + + insertPage = BufferGetBlockNumber(buf); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + + HnswUpdateMetaPage(index, true, buildstate->entryPoint, insertPage, forkNum); + + pfree(etup); + pfree(ntup); +} + +/* + * Create neighbor pages + */ +static void +CreateNeighborPages(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + int m = buildstate->m; + ListCell *lc; + HnswNeighborTuple ntup; + + /* Allocate once */ + ntup = palloc0(BLCKSZ); + + foreach(lc, buildstate->elements) + { + HnswElement e = lfirst(lc); + Buffer buf; + Page page; + GenericXLogState *state; + Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); + + /* Can take a while, so ensure we can interrupt */ + /* Needs to be called when no buffer locks are held */ + CHECK_FOR_INTERRUPTS(); + + buf = ReadBufferExtended(index, forkNum, e->neighborPage, RBM_NORMAL, NULL); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + HnswSetNeighborTuple(ntup, e, m); + + if (!PageIndexTupleOverwrite(page, e->neighborOffno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + } + + pfree(ntup); +} + +/* + * Free elements + */ +static void +FreeElements(HnswBuildState * buildstate) +{ + ListCell *lc; + + foreach(lc, buildstate->elements) + HnswFreeElement(lfirst(lc)); + + list_free(buildstate->elements); +} + +/* + * Flush pages + */ +static void +FlushPages(HnswBuildState * buildstate) +{ + CreateMetaPage(buildstate); + CreateElementPages(buildstate); + CreateNeighborPages(buildstate); + + buildstate->flushed = true; + FreeElements(buildstate); +} + +/* + * Insert tuple + */ +static bool +InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * buildstate, HnswElement * dup) +{ + FmgrInfo *procinfo = buildstate->procinfo; + Oid collation = buildstate->collation; + HnswElement entryPoint = buildstate->entryPoint; + int efConstruction = buildstate->efConstruction; + int m = buildstate->m; + + /* Detoast once for all calls */ + Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Normalize if needed */ + if (buildstate->normprocinfo != NULL) + { + if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec)) + return false; + } + + /* Copy value to element so accessible outside of memory context */ + memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions)); + + /* Insert element in graph */ + *dup = HnswInsertElement(element, entryPoint, NULL, procinfo, collation, m, efConstruction, NULL, false); + + /* Update entry point if needed */ + if (*dup == NULL && (entryPoint == NULL || element->level > entryPoint->level)) + buildstate->entryPoint = element; + + UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples); + + return *dup == NULL; +} + +/* + * Callback for table_index_build_scan + */ +static void +BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, + bool *isnull, bool tupleIsAlive, void *state) +{ + HnswBuildState *buildstate = (HnswBuildState *) state; + MemoryContext oldCtx; + HnswElement element; + HnswElement dup = NULL; + bool inserted; + +#if PG_VERSION_NUM < 130000 + ItemPointer tid = &hup->t_self; +#endif + + /* Skip nulls */ + if (isnull[0]) + return; + + if (buildstate->indtuples >= buildstate->maxInMemoryElements) + { + if (!buildstate->flushed) + { + ereport(NOTICE, + (errmsg("hnsw graph no longer fits into maintenance_work_mem after " INT64_FORMAT " tuples", (int64) buildstate->indtuples), + errdetail("Building will take significantly more time."), + errhint("Increase maintenance_work_mem to speed up builds."))); + + FlushPages(buildstate); + } + + oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); + + if (HnswInsertTuple(buildstate->index, values, isnull, tid, buildstate->heap)) + UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(buildstate->tmpCtx); + + return; + } + + /* Allocate necessary memory outside of memory context */ + element = HnswInitElement(tid, buildstate->m, buildstate->ml, buildstate->maxLevel); + element->vec = palloc(VECTOR_SIZE(buildstate->dimensions)); + + /* Use memory context since detoast can allocate */ + oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); + + /* Insert tuple */ + inserted = InsertTuple(index, values, element, buildstate, &dup); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(buildstate->tmpCtx); + + /* Add outside memory context */ + if (dup != NULL) + HnswAddHeapTid(dup, tid); + + /* Add to buildstate or free */ + if (inserted) + buildstate->elements = lappend(buildstate->elements, element); + else + HnswFreeElement(element); +} + +/* + * Get the max number of elements that fit into maintenance_work_mem + */ +static double +HnswGetMaxInMemoryElements(int m, double ml, int dimensions) +{ + Size elementSize = sizeof(HnswElementData); + double avgLevel = -log(0.5) * ml; + + elementSize += sizeof(HnswNeighborArray) * (avgLevel + 1); + elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2)); + elementSize += sizeof(ItemPointerData); + elementSize += VECTOR_SIZE(dimensions); + return (maintenance_work_mem * 1024L) / elementSize; +} + +/* + * Initialize the build state + */ +static void +InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum) +{ + buildstate->heap = heap; + buildstate->index = index; + buildstate->indexInfo = indexInfo; + buildstate->forkNum = forkNum; + + buildstate->m = HnswGetM(index); + buildstate->efConstruction = HnswGetEfConstruction(index); + buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + + /* Require column to have dimensions to be indexed */ + if (buildstate->dimensions < 0) + elog(ERROR, "column does not have dimensions"); + + if (buildstate->dimensions > HNSW_MAX_DIM) + elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM); + + buildstate->reltuples = 0; + buildstate->indtuples = 0; + + /* Get support functions */ + buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + buildstate->collation = index->rd_indcollation[0]; + + buildstate->elements = NIL; + buildstate->entryPoint = NULL; + buildstate->ml = HnswGetMl(buildstate->m); + buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); + buildstate->maxInMemoryElements = HnswGetMaxInMemoryElements(buildstate->m, buildstate->ml, buildstate->dimensions); + buildstate->flushed = false; + + /* Reuse for each tuple */ + buildstate->normvec = InitVector(buildstate->dimensions); + + buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw build temporary context", + ALLOCSET_DEFAULT_SIZES); +} + +/* + * Free resources + */ +static void +FreeBuildState(HnswBuildState * buildstate) +{ + pfree(buildstate->normvec); + MemoryContextDelete(buildstate->tmpCtx); +} + +/* + * Build graph + */ +static void +BuildGraph(HnswBuildState * buildstate, ForkNumber forkNum) +{ + UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_HNSW_PHASE_LOAD); + +#if PG_VERSION_NUM >= 120000 + buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, BuildCallback, (void *) buildstate, NULL); +#else + buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, BuildCallback, (void *) buildstate, NULL); +#endif +} + +/* + * Build the index + */ +static void +BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo, + HnswBuildState * buildstate, ForkNumber forkNum) +{ + InitBuildState(buildstate, heap, index, indexInfo, forkNum); + + if (buildstate->heap != NULL) + BuildGraph(buildstate, forkNum); + + if (!buildstate->flushed) + FlushPages(buildstate); + + FreeBuildState(buildstate); +} + +/* + * Build the index for a logged table + */ +IndexBuildResult * +hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo) +{ + IndexBuildResult *result; + HnswBuildState buildstate; + + BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM); + + result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult)); + result->heap_tuples = buildstate.reltuples; + result->index_tuples = buildstate.indtuples; + + return result; +} + +/* + * Build the index for an unlogged table + */ +void +hnswbuildempty(Relation index) +{ + IndexInfo *indexInfo = BuildIndexInfo(index); + HnswBuildState buildstate; + + BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM); +} diff --git a/src/hnswinsert.c b/src/hnswinsert.c new file mode 100644 index 0000000..c2a17ff --- /dev/null +++ b/src/hnswinsert.c @@ -0,0 +1,491 @@ +#include "postgres.h" + +#include + +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "storage/lmgr.h" +#include "utils/memutils.h" + +/* + * Get the insert page + */ +static BlockNumber +GetInsertPage(Relation index) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + BlockNumber insertPage; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = HnswPageGetMeta(page); + + insertPage = metap->insertPage; + + UnlockReleaseBuffer(buf); + + return insertPage; +} + +/* + * Check for a free offset + */ +static bool +HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *firstFreePage) +{ + OffsetNumber offno; + OffsetNumber maxoffno = PageGetMaxOffsetNumber(page); + + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + + /* Skip neighbor tuples */ + if (!HnswIsElementTuple(etup)) + continue; + + if (etup->deleted) + { + BlockNumber neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); + OffsetNumber neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); + ItemId itemid; + + if (!BlockNumberIsValid(*firstFreePage)) + *firstFreePage = neighborPage; + + if (neighborPage == BufferGetBlockNumber(buf)) + { + *nbuf = buf; + *npage = page; + } + else + { + *nbuf = ReadBuffer(index, neighborPage); + LockBuffer(*nbuf, BUFFER_LOCK_EXCLUSIVE); + + /* Skip WAL for now */ + *npage = BufferGetPage(*nbuf); + } + + itemid = PageGetItemId(*npage, neighborOffno); + + /* Check for space on neighbor tuple page */ + if (PageGetFreeSpace(*npage) + ItemIdGetLength(itemid) - sizeof(ItemIdData) >= ntupSize) + { + *freeOffno = offno; + *freeNeighborOffno = neighborOffno; + return true; + } + else if (*nbuf != buf) + UnlockReleaseBuffer(*nbuf); + } + } + + return false; +} + +/* + * Add a new page + */ +static void +HnswInsertAppendPage(Relation index, Buffer *nbuf, Page *npage, GenericXLogState *state, Page page) +{ + /* Add a new page */ + LockRelationForExtension(index, ExclusiveLock); + *nbuf = HnswNewBuffer(index, MAIN_FORKNUM); + UnlockRelationForExtension(index, ExclusiveLock); + + /* Init new page */ + *npage = GenericXLogRegisterBuffer(state, *nbuf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(*nbuf, *npage); + + /* Update previous buffer */ + HnswPageGetOpaque(page)->nextblkno = BufferGetBlockNumber(*nbuf); +} + +/* + * Add to element and neighbor pages + */ +static void +WriteNewElementPages(Relation index, HnswElement e, int m) +{ + Buffer buf; + Page page; + GenericXLogState *state; + Size etupSize; + Size ntupSize; + Size combinedSize; + HnswElementTuple etup; + BlockNumber insertPage = GetInsertPage(index); + BlockNumber originalInsertPage = insertPage; + int dimensions = e->vec->dim; + HnswNeighborTuple ntup; + Buffer nbuf; + Page npage; + OffsetNumber freeOffno = InvalidOffsetNumber; + OffsetNumber freeNeighborOffno = InvalidOffsetNumber; + BlockNumber firstFreePage = InvalidBlockNumber; + + /* Calculate sizes */ + etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions); + ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); + combinedSize = etupSize + ntupSize + sizeof(ItemIdData); + + /* Prepare element tuple */ + etup = palloc0(etupSize); + HnswSetElementTuple(etup, e); + + /* Prepare neighbor tuple */ + ntup = palloc0(ntupSize); + HnswSetNeighborTuple(ntup, e, m); + + /* Find a page to insert the item */ + for (;;) + { + buf = ReadBuffer(index, insertPage); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Space for both */ + if (PageGetFreeSpace(page) >= combinedSize) + { + nbuf = buf; + npage = page; + break; + } + + /* Space for element but not neighbors and last page */ + if (PageGetFreeSpace(page) >= etupSize && !BlockNumberIsValid(HnswPageGetOpaque(page)->nextblkno)) + { + HnswInsertAppendPage(index, &nbuf, &npage, state, page); + break; + } + + /* Space from deleted item */ + if (HnswFreeOffset(index, buf, page, e, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &firstFreePage)) + { + if (nbuf != buf) + npage = GenericXLogRegisterBuffer(state, nbuf, 0); + + break; + } + + insertPage = HnswPageGetOpaque(page)->nextblkno; + + if (BlockNumberIsValid(insertPage)) + { + /* Move to next page */ + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + } + else + { + Buffer newbuf; + Page newpage; + + HnswInsertAppendPage(index, &newbuf, &newpage, state, page); + + /* Commit */ + MarkBufferDirty(newbuf); + MarkBufferDirty(buf); + GenericXLogFinish(state); + + /* Unlock previous buffer */ + UnlockReleaseBuffer(buf); + + /* Prepare new buffer */ + state = GenericXLogStart(index); + buf = newbuf; + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Create new page for neighbors if needed */ + if (PageGetFreeSpace(page) < combinedSize) + HnswInsertAppendPage(index, &nbuf, &npage, state, page); + else + { + nbuf = buf; + npage = page; + } + + break; + } + } + + e->blkno = BufferGetBlockNumber(buf); + e->neighborPage = BufferGetBlockNumber(nbuf); + + insertPage = e->neighborPage; + + if (OffsetNumberIsValid(freeOffno)) + { + e->offno = freeOffno; + e->neighborOffno = freeNeighborOffno; + } + else + { + e->offno = OffsetNumberNext(PageGetMaxOffsetNumber(page)); + if (nbuf == buf) + e->neighborOffno = OffsetNumberNext(e->offno); + else + e->neighborOffno = FirstOffsetNumber; + } + + ItemPointerSet(&etup->neighbortid, e->neighborPage, e->neighborOffno); + + /* Add element and neighbors */ + if (OffsetNumberIsValid(freeOffno)) + { + if (!PageIndexTupleOverwrite(page, e->offno, (Item) etup, etupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + if (!PageIndexTupleOverwrite(npage, e->neighborOffno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + else + { + if (PageAddItem(page, (Item) etup, etupSize, InvalidOffsetNumber, false, false) != e->offno) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + if (PageAddItem(npage, (Item) ntup, ntupSize, InvalidOffsetNumber, false, false) != e->neighborOffno) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + + /* Commit */ + MarkBufferDirty(buf); + if (nbuf != buf) + MarkBufferDirty(nbuf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + if (nbuf != buf) + UnlockReleaseBuffer(nbuf); + + /* Update the insert page */ + if (insertPage != originalInsertPage && (!OffsetNumberIsValid(freeOffno) || firstFreePage == insertPage)) + HnswUpdateMetaPage(index, false, NULL, insertPage, MAIN_FORKNUM); +} + +/* + * Calculate index for update + */ +static int +HnswGetIndex(HnswUpdate * update, int m) +{ + return (update->hc.element->level - update->level) * m + update->index; +} + +/* + * Update neighbors + */ +static void +UpdateNeighborPages(Relation index, HnswElement e, int m, List *updates) +{ + ListCell *lc; + + /* Could update multiple at once for same element */ + /* but should only happen a low percent of time, so keep simple for now */ + foreach(lc, updates) + { + Buffer buf; + Page page; + GenericXLogState *state; + HnswUpdate *update = lfirst(lc); + ItemId itemid; + HnswNeighborTuple ntup; + Size ntupSize; + int idx; + OffsetNumber offno = update->hc.element->neighborOffno; + + /* Register page */ + buf = ReadBuffer(index, update->hc.element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Get tuple */ + itemid = PageGetItemId(page, offno); + ntup = (HnswNeighborTuple) PageGetItem(page, itemid); + ntupSize = ItemIdGetLength(itemid); + + /* Calculate index */ + idx = HnswGetIndex(update, m); + + /* Make robust to issues */ + if (idx < ntup->count) + { + HnswNeighborTupleItem *neighbor = &ntup->neighbors[idx]; + + /* Update neighbor */ + ItemPointerSet(&neighbor->indextid, e->blkno, e->offno); + neighbor->distance = update->hc.distance; + + /* Overwrite tuple */ + if (!PageIndexTupleOverwrite(page, offno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + } + else + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); + } +} + +/* + * Add a heap TID to an existing element + */ +static bool +HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup) +{ + Buffer buf; + Page page; + GenericXLogState *state; + Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(dup->vec->dim); + HnswElementTuple etup; + int i; + + /* Read page */ + buf = ReadBuffer(index, dup->blkno); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Find space */ + etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, dup->offno)); + for (i = 0; i < HNSW_HEAPTIDS; i++) + { + if (!ItemPointerIsValid(&etup->heaptids[i])) + break; + } + + /* Either being deleted or we lost our chance to another backend */ + if (i == 0 || i == HNSW_HEAPTIDS) + { + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + return false; + } + + /* Add heap TID */ + etup->heaptids[i] = *((ItemPointer) linitial(element->heaptids)); + + /* Overwrite tuple */ + if (!PageIndexTupleOverwrite(page, dup->offno, (Item) etup, etupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + + return true; +} + +/* + * Write changes to disk + */ +static void +WriteElement(Relation index, HnswElement element, int m, List *updates, HnswElement dup, HnswElement entryPoint) +{ + /* Try to add to existing page */ + if (dup != NULL) + { + if (HnswAddDuplicate(index, element, dup)) + return; + } + + /* If fails, take this path */ + WriteNewElementPages(index, element, m); + UpdateNeighborPages(index, element, m, updates); + + /* Update metapage if needed */ + if (entryPoint == NULL || element->level > entryPoint->level) + HnswUpdateMetaPage(index, true, element, InvalidBlockNumber, MAIN_FORKNUM); +} + +/* + * Insert a tuple into the index + */ +bool +HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel) +{ + Datum value; + FmgrInfo *normprocinfo; + HnswElement entryPoint; + HnswElement element; + int m = HnswGetM(index); + int efConstruction = HnswGetEfConstruction(index); + double ml = HnswGetMl(m); + FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + Oid collation = index->rd_indcollation[0]; + List *updates = NIL; + HnswElement dup; + + /* Detoast once for all calls */ + value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Normalize if needed */ + normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + if (normprocinfo != NULL) + { + if (!HnswNormValue(normprocinfo, collation, &value, NULL)) + return false; + } + + /* Create an element */ + element = HnswInitElement(heap_tid, m, ml, HnswGetMaxLevel(m)); + element->vec = DatumGetVector(value); + + /* Get entry point */ + entryPoint = HnswGetEntryPoint(index); + + /* Insert element in graph */ + dup = HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, &updates, false); + + /* Write to disk */ + WriteElement(index, element, m, updates, dup, entryPoint); + + return true; +} + +/* + * Insert a tuple into the index + */ +bool +hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, + Relation heap, IndexUniqueCheck checkUnique +#if PG_VERSION_NUM >= 140000 + ,bool indexUnchanged +#endif + ,IndexInfo *indexInfo +) +{ + MemoryContext oldCtx; + MemoryContext insertCtx; + + /* Skip nulls */ + if (isnull[0]) + return false; + + /* Create memory context */ + insertCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw insert temporary context", + ALLOCSET_DEFAULT_SIZES); + oldCtx = MemoryContextSwitchTo(insertCtx); + + /* Insert tuple */ + HnswInsertTuple(index, values, isnull, heap_tid, heap); + + /* Delete memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextDelete(insertCtx); + + return false; +} diff --git a/src/hnswscan.c b/src/hnswscan.c new file mode 100644 index 0000000..365c6e3 --- /dev/null +++ b/src/hnswscan.c @@ -0,0 +1,212 @@ +#include "postgres.h" + +#include "access/relscan.h" +#include "hnsw.h" +#include "pgstat.h" +#include "storage/bufmgr.h" +#include "utils/memutils.h" + +/* + * Algorithm 5 from paper + */ +static void +GetScanItems(IndexScanDesc scan, Datum q) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + Relation index = scan->indexRelation; + FmgrInfo *procinfo = so->procinfo; + Oid collation = so->collation; + List *ep = NIL; + List *w; + HnswElement entryPoint = HnswGetEntryPoint(index); + + if (entryPoint == NULL) + return; + + ep = lappend(ep, HnswEntryCandidate(entryPoint, q, index, procinfo, collation, false)); + + for (int lc = entryPoint->level; lc >= 1; lc--) + { + w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, false, NULL, NULL); + ep = w; + } + + so->w = HnswSearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, false, NULL, NULL); +} + +/* + * Get dimensions from metapage + */ +static int +GetDimensions(Relation index) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + int dimensions; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = HnswPageGetMeta(page); + + dimensions = metap->dimensions; + + UnlockReleaseBuffer(buf); + + return dimensions; +} + +/* + * Prepare for an index scan + */ +IndexScanDesc +hnswbeginscan(Relation index, int nkeys, int norderbys) +{ + IndexScanDesc scan; + HnswScanOpaque so; + + scan = RelationGetIndexScan(index, nkeys, norderbys); + + so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData)); + so->buf = InvalidBuffer; + so->first = true; + so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw scan temporary context", + ALLOCSET_DEFAULT_SIZES); + + /* Set support functions */ + so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + so->collation = index->rd_indcollation[0]; + + scan->opaque = so; + + return scan; +} + +/* + * Start or restart an index scan + */ +void +hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + + so->first = true; + MemoryContextReset(so->tmpCtx); + + if (keys && scan->numberOfKeys > 0) + memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData)); + + if (orderbys && scan->numberOfOrderBys > 0) + memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData)); +} + +/* + * Fetch the next tuple in the given scan + */ +bool +hnswgettuple(IndexScanDesc scan, ScanDirection dir) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + MemoryContext oldCtx = MemoryContextSwitchTo(so->tmpCtx); + + /* + * Index can be used to scan backward, but Postgres doesn't support + * backward scan on operators + */ + Assert(ScanDirectionIsForward(dir)); + + if (so->first) + { + Datum value; + + /* Count index scan for stats */ + pgstat_count_index_scan(scan->indexRelation); + + /* Safety check */ + if (scan->orderByData == NULL) + elog(ERROR, "cannot scan hnsw index without order"); + + if (scan->orderByData->sk_flags & SK_ISNULL) + value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation))); + else + { + value = scan->orderByData->sk_argument; + + /* Value should not be compressed or toasted */ + Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); + Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); + + /* Fine if normalization fails */ + if (so->normprocinfo != NULL) + HnswNormValue(so->normprocinfo, so->collation, &value, NULL); + } + + GetScanItems(scan, value); + so->first = false; + } + + while (list_length(so->w) > 0) + { + HnswCandidate *hc = llast(so->w); + ItemPointer tid; + BlockNumber indexblkno; + + /* Move to next element if no valid heap tids */ + if (list_length(hc->element->heaptids) == 0) + { + so->w = list_delete_last(so->w); + continue; + } + + tid = llast(hc->element->heaptids); + indexblkno = hc->element->blkno; + + hc->element->heaptids = list_delete_last(hc->element->heaptids); + + MemoryContextSwitchTo(oldCtx); + +#if PG_VERSION_NUM >= 120000 + scan->xs_heaptid = *tid; +#else + scan->xs_ctup.t_self = *tid; +#endif + + if (BufferIsValid(so->buf)) + ReleaseBuffer(so->buf); + + /* + * An index scan must maintain a pin on the index page holding the + * item last returned by amgettuple + * + * https://www.postgresql.org/docs/current/index-locking.html + */ + so->buf = ReadBuffer(scan->indexRelation, indexblkno); + + scan->xs_recheckorderby = false; + return true; + } + + MemoryContextSwitchTo(oldCtx); + return false; +} + +/* + * End a scan and release resources + */ +void +hnswendscan(IndexScanDesc scan) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + + /* Release pin */ + if (BufferIsValid(so->buf)) + ReleaseBuffer(so->buf); + + MemoryContextDelete(so->tmpCtx); + + pfree(so); + scan->opaque = NULL; +} diff --git a/src/hnswutils.c b/src/hnswutils.c new file mode 100644 index 0000000..1d3409e --- /dev/null +++ b/src/hnswutils.c @@ -0,0 +1,982 @@ +#include "postgres.h" + +#include + +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "vector.h" + +/* + * Get the number of connection in the index + */ +int +HnswGetM(Relation index) +{ + HnswOptions *opts = (HnswOptions *) index->rd_options; + + if (opts) + return opts->m; + + return HNSW_DEFAULT_M; +} + +/* + * Get the size of the dynamic candidate list in the index + */ +int +HnswGetEfConstruction(Relation index) +{ + HnswOptions *opts = (HnswOptions *) index->rd_options; + + if (opts) + return opts->efConstruction; + + return HNSW_DEFAULT_EF_CONSTRUCTION; +} + +/* + * Get proc + */ +FmgrInfo * +HnswOptionalProcInfo(Relation rel, uint16 procnum) +{ + if (!OidIsValid(index_getprocid(rel, 1, procnum))) + return NULL; + + return index_getprocinfo(rel, 1, procnum); +} + +/* + * Divide by the norm + * + * Returns false if value should not be indexed + * + * The caller needs to free the pointer stored in value + * if it's different than the original value + */ +bool +HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result) +{ + double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); + + if (norm > 0) + { + Vector *v = DatumGetVector(*value); + + if (result == NULL) + result = InitVector(v->dim); + + for (int i = 0; i < v->dim; i++) + result->x[i] = v->x[i] / norm; + + *value = PointerGetDatum(result); + + return true; + } + + return false; +} + +/* + * New buffer + */ +Buffer +HnswNewBuffer(Relation index, ForkNumber forkNum) +{ + Buffer buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL); + + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + return buf; +} + +/* + * Init page + */ +void +HnswInitPage(Buffer buf, Page page) +{ + PageInit(page, BufferGetPageSize(buf), sizeof(HnswPageOpaqueData)); + HnswPageGetOpaque(page)->nextblkno = InvalidBlockNumber; + HnswPageGetOpaque(page)->page_id = HNSW_PAGE_ID; +} + +/* + * Init and register page + */ +void +HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state) +{ + *state = GenericXLogStart(index); + *page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(*buf, *page); +} + +/* + * Commit buffer + */ +void +HnswCommitBuffer(Buffer buf, GenericXLogState *state) +{ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); +} + +/* + * Allocate neighbors + */ +void +HnswInitNeighbors(HnswElement element, int m) +{ + int level = element->level; + + element->neighbors = palloc(sizeof(HnswNeighborArray) * (level + 1)); + + for (int lc = 0; lc <= level; lc++) + { + HnswNeighborArray *a; + int lm = HnswGetLayerM(m, lc); + + a = &element->neighbors[lc]; + a->length = 0; + a->items = palloc(sizeof(HnswCandidate) * lm); + } +} + +/* + * Allocate an element + */ +HnswElement +HnswInitElement(ItemPointer heaptid, int m, double ml, int maxLevel) +{ + HnswElement element = palloc(sizeof(HnswElementData)); + + int level = (int) (-log(RandomDouble()) * ml); + + /* Cap level */ + if (level > maxLevel) + level = maxLevel; + + element->heaptids = NIL; + HnswAddHeapTid(element, heaptid); + + element->level = level; + element->deleted = 0; + + HnswInitNeighbors(element, m); + + return element; +} + +/* + * Free an element + */ +void +HnswFreeElement(HnswElement element) +{ + list_free_deep(element->heaptids); + for (int lc = 0; lc <= element->level; lc++) + pfree(element->neighbors[lc].items); + pfree(element->neighbors); + pfree(element->vec); + pfree(element); +} + +/* + * Add a heap TID to an element + */ +void +HnswAddHeapTid(HnswElement element, ItemPointer heaptid) +{ + ItemPointer copy = palloc(sizeof(ItemPointerData)); + + ItemPointerCopy(heaptid, copy); + element->heaptids = lappend(element->heaptids, copy); +} + +/* + * Allocate an element from block and offset numbers + */ +static HnswElement +InitElementFromBlock(BlockNumber blkno, OffsetNumber offno) +{ + HnswElement element = palloc(sizeof(HnswElementData)); + + element->blkno = blkno; + element->offno = offno; + element->neighbors = NULL; + element->vec = NULL; + return element; +} + +/* + * Get the entry point + */ +HnswElement +HnswGetEntryPoint(Relation index) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + HnswElement entryPoint = NULL; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = HnswPageGetMeta(page); + + if (BlockNumberIsValid(metap->entryBlkno)) + entryPoint = InitElementFromBlock(metap->entryBlkno, metap->entryOffno); + + UnlockReleaseBuffer(buf); + + return entryPoint; +} + +/* + * Update the metapage + */ +void +HnswUpdateMetaPage(Relation index, bool updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum) +{ + Buffer buf; + Page page; + GenericXLogState *state; + HnswMetaPage metap; + + buf = ReadBufferExtended(index, forkNum, HNSW_METAPAGE_BLKNO, RBM_NORMAL, NULL); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + metap = HnswPageGetMeta(page); + + if (updateEntry) + { + if (entryPoint == NULL) + { + metap->entryBlkno = InvalidBlockNumber; + metap->entryOffno = InvalidOffsetNumber; + metap->entryLevel = -1; + } + else + { + metap->entryBlkno = entryPoint->blkno; + metap->entryOffno = entryPoint->offno; + metap->entryLevel = entryPoint->level; + } + } + + if (BlockNumberIsValid(insertPage)) + metap->insertPage = insertPage; + + HnswCommitBuffer(buf, state); +} + +/* + * Set element tuple, except for neighbor info + */ +void +HnswSetElementTuple(HnswElementTuple etup, HnswElement element) +{ + etup->type = HNSW_ELEMENT_TUPLE_TYPE; + etup->level = element->level; + etup->deleted = 0; + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + if (i < list_length(element->heaptids)) + etup->heaptids[i] = *((ItemPointer) list_nth(element->heaptids, i)); + else + ItemPointerSetInvalid(&etup->heaptids[i]); + } + memcpy(&etup->vec, element->vec, VECTOR_SIZE(element->vec->dim)); +} + +/* + * Set neighbor tuple + */ +void +HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m) +{ + int idx = 0; + + ntup->type = HNSW_NEIGHBOR_TUPLE_TYPE; + + for (int lc = e->level; lc >= 0; lc--) + { + HnswNeighborArray *neighbors = &e->neighbors[lc]; + int lm = HnswGetLayerM(m, lc); + + for (int i = 0; i < lm; i++) + { + HnswNeighborTupleItem *neighbor = &ntup->neighbors[idx++]; + + if (i < neighbors->length) + { + HnswCandidate *hc = &neighbors->items[i]; + + ItemPointerSet(&neighbor->indextid, hc->element->blkno, hc->element->offno); + neighbor->distance = hc->distance; + } + else + { + ItemPointerSetInvalid(&neighbor->indextid); + neighbor->distance = NAN; + } + } + } + + ntup->count = idx; +} + +/* + * Load neighbors from page + */ +static void +LoadNeighborsFromPage(HnswElement element, Relation index, Page page) +{ + HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); + int m = HnswGetM(index); + int neighborCount = (element->level + 2) * m; + + Assert(HnswIsNeighborTuple(ntup)); + + HnswInitNeighbors(element, m); + + /* Ensure expected neighbors */ + if (ntup->count != neighborCount) + return; + + for (int i = 0; i < neighborCount; i++) + { + HnswElement e; + int level; + HnswCandidate *hc; + HnswNeighborTupleItem *neighbor; + HnswNeighborArray *neighbors; + + neighbor = &ntup->neighbors[i]; + + if (!ItemPointerIsValid(&neighbor->indextid)) + continue; + + e = InitElementFromBlock(ItemPointerGetBlockNumber(&neighbor->indextid), ItemPointerGetOffsetNumber(&neighbor->indextid)); + + /* Calculate level based on offset */ + level = element->level - i / m; + if (level < 0) + level = 0; + + neighbors = &element->neighbors[level]; + hc = &neighbors->items[neighbors->length++]; + hc->element = e; + hc->distance = neighbor->distance; + } +} + +/* + * Load neighbors + */ +static void +LoadNeighbors(HnswElement element, Relation index) +{ + Buffer buf; + Page page; + + buf = ReadBuffer(index, element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + LoadNeighborsFromPage(element, index, page); + + UnlockReleaseBuffer(buf); +} + +/* + * Load an element and optionally get its distance from q + */ +void +HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +{ + Buffer buf; + Page page; + HnswElementTuple etup; + + /* Read vector */ + buf = ReadBuffer(index, element->blkno); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno)); + + Assert(HnswIsElementTuple(etup)); + + /* Load element */ + element->heaptids = NIL; + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + /* Can stop at first invalid */ + if (!ItemPointerIsValid(&etup->heaptids[i])) + break; + + HnswAddHeapTid(element, &etup->heaptids[i]); + } + element->level = etup->level; + element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); + element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); + element->deleted = etup->deleted; + + if (loadVec) + { + element->vec = palloc(VECTOR_SIZE(etup->vec.dim)); + memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim)); + } + + /* Calculate distance */ + if (distance != NULL) + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->vec))); + + UnlockReleaseBuffer(buf); +} + +/* + * Get the distance for a candidate + */ +static float +GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) +{ + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, PointerGetDatum(hc->element->vec))); +} + +/* + * Create a candidate for the entry point + */ +HnswCandidate * +HnswEntryCandidate(HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadvec) +{ + HnswCandidate *hc = palloc(sizeof(HnswCandidate)); + + hc->element = entryPoint; + if (index == NULL) + hc->distance = GetCandidateDistance(hc, q, procinfo, collation); + else + HnswLoadElement(hc->element, &hc->distance, &q, index, procinfo, collation, loadvec); + return hc; +} + +/* + * Compare candidate distances + */ +static int +CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + return 1; + + if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + return -1; + + return 0; +} + +/* + * Compare candidate distances + */ +static int +CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + return -1; + + if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + return 1; + + return 0; +} + +/* + * Create a pairing heap node for a candidate + */ +static HnswPairingHeapNode * +CreatePairingHeapNode(HnswCandidate * c) +{ + HnswPairingHeapNode *node = palloc(sizeof(HnswPairingHeapNode)); + + node->inner = c; + return node; +} + +/* + * Add to visited + */ +static inline void +AddToVisited(HTAB *v, HnswCandidate * hc, Relation index, bool *found) +{ + if (index == NULL) + hash_search(v, &hc->element, HASH_ENTER, found); + else + { + ItemPointerData indextid; + + ItemPointerSet(&indextid, hc->element->blkno, hc->element->offno); + hash_search(v, &indextid, HASH_ENTER, found); + } +} + +/* + * Algorithm 2 from paper + */ +List * +HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, bool inserting, BlockNumber *skipPage, OffsetNumber *skipOffno) +{ + ListCell *lc2; + + List *w = NIL; + pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); + pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); + int wlen = 0; + HASHCTL hash_ctl; + HTAB *v; + + /* Create hash table */ + if (index == NULL) + { + hash_ctl.keysize = sizeof(HnswElement *); + hash_ctl.entrysize = sizeof(HnswElement *); + } + else + { + hash_ctl.keysize = sizeof(ItemPointerData); + hash_ctl.entrysize = sizeof(ItemPointerData); + } + + hash_ctl.hcxt = CurrentMemoryContext; + v = hash_create("hnsw visited", 256, &hash_ctl, HASH_ELEM | HASH_BLOBS | HASH_CONTEXT); + + /* Add entry points to v, C, and W */ + foreach(lc2, ep) + { + HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); + + AddToVisited(v, hc, index, NULL); + + pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node)); + pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node)); + + wlen++; + } + + while (!pairingheap_is_empty(C)) + { + HnswNeighborArray *neighborhood; + HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner; + HnswCandidate *f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; + + if (c->distance > f->distance) + break; + + if (c->element->neighbors == NULL) + LoadNeighbors(c->element, index); + + /* Get the neighborhood at layer lc */ + neighborhood = &c->element->neighbors[lc]; + + for (int i = 0; i < neighborhood->length; i++) + { + HnswCandidate *e = &neighborhood->items[i]; + bool visited; + + AddToVisited(v, e, index, &visited); + + if (!visited) + { + float eDistance; + + f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; + + if (index == NULL) + eDistance = GetCandidateDistance(e, q, procinfo, collation); + else + HnswLoadElement(e->element, &eDistance, &q, index, procinfo, collation, inserting); + + /* Skip if fully deleted */ + if (e->element->deleted) + continue; + + /* Skip for inserts if deleting */ + if (inserting && list_length(e->element->heaptids) == 0) + continue; + + /* Skip self for vacuuming update */ + if (skipPage != NULL && e->element->neighborPage == *skipPage && e->element->neighborOffno == *skipOffno) + continue; + + /* Make robust to issues */ + if (e->element->level < lc) + continue; + + if (eDistance < f->distance || wlen < ef) + { + /* Copy e */ + HnswCandidate *ec = palloc(sizeof(HnswCandidate)); + + ec->element = e->element; + ec->distance = eDistance; + + pairingheap_add(C, &(CreatePairingHeapNode(ec)->ph_node)); + pairingheap_add(W, &(CreatePairingHeapNode(ec)->ph_node)); + wlen++; + + /* No need to decrement wlen */ + if (wlen > ef) + pairingheap_remove_first(W); + } + } + } + } + + /* Add each element of W to w */ + while (!pairingheap_is_empty(W)) + { + HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner; + + w = lappend(w, hc); + } + + return w; +} + +/* + * Calculate the distance between elements + */ +static float +HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid collation) +{ + /* Look for cached distance */ + if (a->neighbors != NULL) + { + Assert(a->level >= lc); + + for (int i = 0; i < a->neighbors[lc].length; i++) + { + if (a->neighbors[lc].items[i].element == b) + return a->neighbors[lc].items[i].distance; + } + } + + if (b->neighbors != NULL) + { + Assert(b->level >= lc); + + for (int i = 0; i < b->neighbors[lc].length; i++) + { + if (b->neighbors[lc].items[i].element == a) + return b->neighbors[lc].items[i].distance; + } + } + + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec))); +} + +/* + * Check if an element is closer to q than any element from R + */ +static bool +CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid collation) +{ + ListCell *lc2; + + foreach(lc2, r) + { + HnswCandidate *ri = lfirst(lc2); + float distance = HnswGetDistance(e->element, ri->element, lc, procinfo, collation); + + if (distance <= e->distance) + return false; + } + + return true; +} + +/* + * Algorithm 4 from paper + */ +static List * +SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswCandidate * *pruned) +{ + List *r = NIL; + List *w = list_copy(c); + pairingheap *wd; + + if (list_length(w) < m) + return w; + + wd = pairingheap_allocate(CompareNearestCandidates, NULL); + + while (list_length(w) > 0 && list_length(r) < m) + { + /* Assumes w is already ordered desc */ + HnswCandidate *e = llast(w); + bool closer; + + w = list_delete_last(w); + + closer = CheckElementCloser(e, r, lc, procinfo, collation); + + if (closer) + r = lappend(r, e); + else + pairingheap_add(wd, &(CreatePairingHeapNode(e)->ph_node)); + } + + /* Keep pruned connections */ + while (!pairingheap_is_empty(wd) && list_length(r) < m) + r = lappend(r, ((HnswPairingHeapNode *) pairingheap_remove_first(wd))->inner); + + /* Return pruned for update connections */ + if (pruned != NULL) + { + if (!pairingheap_is_empty(wd)) + *pruned = ((HnswPairingHeapNode *) pairingheap_first(wd))->inner; + else + *pruned = linitial(w); + } + + return r; +} + +/* + * Find duplicate element + */ +static HnswElement +HnswFindDuplicate(HnswElement e, List *neighbors) +{ + ListCell *lc; + + foreach(lc, neighbors) + { + HnswCandidate *neighbor = lfirst(lc); + + /* Exit early since ordered by distance */ + if (vector_cmp_internal(e->vec, neighbor->element->vec) != 0) + break; + + /* Check for space */ + if (list_length(neighbor->element->heaptids) < HNSW_HEAPTIDS) + return neighbor->element; + } + + return NULL; +} + +/* + * Add connections + */ +static void +AddConnections(HnswElement element, List *neighbors, int m, int lc) +{ + ListCell *lc2; + HnswNeighborArray *a = &element->neighbors[lc]; + + foreach(lc2, neighbors) + a->items[a->length++] = *((HnswCandidate *) lfirst(lc2)); +} + +/* + * Compare candidate distances + */ +static int +#if PG_VERSION_NUM >= 130000 +CompareCandidateDistances(const ListCell *a, const ListCell *b) +#else +CompareCandidateDistances(const void *a, const void *b) +#endif +{ + HnswCandidate *hca = lfirst((ListCell *) a); + HnswCandidate *hcb = lfirst((ListCell *) b); + + if (hca->distance < hcb->distance) + return 1; + + if (hca->distance > hcb->distance) + return -1; + + return 0; +} + +/* + * Create update + */ +static HnswUpdate * +CreateUpdate(HnswCandidate * hc, int level, int index) +{ + HnswUpdate *update = palloc(sizeof(HnswUpdate)); + + update->hc = *hc; + update->level = level; + update->index = index; + return update; +} + +/* + * Update connections + */ +static void +UpdateConnections(HnswElement element, List *neighbors, int m, int lc, List **updates, Relation index, FmgrInfo *procinfo, Oid collation) +{ + ListCell *lc2; + + foreach(lc2, neighbors) + { + HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); + HnswNeighborArray *currentNeighbors = &hc->element->neighbors[lc]; + + HnswCandidate hc2; + + hc2.element = element; + hc2.distance = hc->distance; + + if (currentNeighbors->length < m) + { + currentNeighbors->items[currentNeighbors->length++] = hc2; + + /* Track updates */ + if (updates != NULL) + *updates = lappend(*updates, CreateUpdate(hc, lc, currentNeighbors->length - 1)); + } + else + { + /* Shrink connections */ + HnswCandidate *pruned = NULL; + List *c = NIL; + + /* Add and sort candidates */ + for (int i = 0; i < currentNeighbors->length; i++) + c = lappend(c, ¤tNeighbors->items[i]); + c = lappend(c, &hc2); + list_sort(c, CompareCandidateDistances); + + /* Load elements on insert */ + if (index != NULL) + { + for (int i = 0; i < currentNeighbors->length; i++) + { + if (currentNeighbors->items[i].element->vec == NULL) + { + HnswLoadElement(currentNeighbors->items[i].element, NULL, NULL, index, procinfo, collation, true); + + /* Prune deleted element */ + if (currentNeighbors->items[i].element->deleted) + { + pruned = ¤tNeighbors->items[i]; + break; + } + } + } + } + + if (pruned == NULL) + { + SelectNeighbors(c, m, lc, procinfo, collation, &pruned); + + /* Should not happen */ + if (pruned == NULL) + continue; + } + + /* Find and replace the pruned element */ + for (int i = 0; i < currentNeighbors->length; i++) + { + if (currentNeighbors->items[i].element == pruned->element) + { + currentNeighbors->items[i] = hc2; + + /* Track updates */ + if (updates != NULL) + *updates = lappend(*updates, CreateUpdate(hc, lc, i)); + + break; + } + } + } + } +} + +/* + * Algorithm 1 from paper + */ +HnswElement +HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, List **updates, bool vacuuming) +{ + List *ep = NIL; + List *w; + int level = element->level; + int entryLevel; + List **newNeighbors = palloc(sizeof(List *) * (level + 1)); + Datum q = PointerGetDatum(element->vec); + HnswElement dup; + BlockNumber *skipPage = vacuuming ? &element->neighborPage : NULL; + OffsetNumber *skipOffno = vacuuming ? &element->neighborOffno : NULL; + bool removeEntryPoint; + HnswCandidate *entryCandidate; + + /* Get entry point and level */ + if (entryPoint != NULL) + { + entryCandidate = HnswEntryCandidate(entryPoint, q, index, procinfo, collation, true); + ep = lappend(ep, entryCandidate); + entryLevel = entryPoint->level; + removeEntryPoint = vacuuming && list_length(entryPoint->heaptids) == 0; + } + else + { + entryLevel = -1; + removeEntryPoint = false; + } + + /* 1st phase: greedy search to insert level */ + for (int lc = entryLevel; lc >= level + 1; lc--) + { + w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, true, skipPage, skipOffno); + ep = w; + } + + if (level > entryLevel) + level = entryLevel; + + /* 2nd phase */ + for (int lc = level; lc >= 0; lc--) + { + int lm = HnswGetLayerM(m, lc); + + w = HnswSearchLayer(q, ep, efConstruction, lc, index, procinfo, collation, true, skipPage, skipOffno); + + /* Remove entry point if it's being deleted */ + if (removeEntryPoint) + w = list_delete_ptr(w, entryCandidate); + + newNeighbors[lc] = SelectNeighbors(w, lm, lc, procinfo, collation, NULL); + ep = w; + } + + /* Look for duplicate */ + if (level >= 0 && !vacuuming) + { + dup = HnswFindDuplicate(element, newNeighbors[0]); + if (dup != NULL) + return dup; + } + + /* Update connections */ + for (int lc = level; lc >= 0; lc--) + { + int lm = HnswGetLayerM(m, lc); + + AddConnections(element, newNeighbors[lc], lm, lc); + + if (!vacuuming) + UpdateConnections(element, newNeighbors[lc], lm, lc, updates, index, procinfo, collation); + } + + return NULL; +} diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c new file mode 100644 index 0000000..b37c362 --- /dev/null +++ b/src/hnswvacuum.c @@ -0,0 +1,584 @@ +#include "postgres.h" + +#include + +#include "commands/vacuum.h" +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "utils/memutils.h" + +/* + * Check if deleted list contains an index tid + */ +static bool +DeletedContains(HTAB *deleted, ItemPointer indextid) +{ + bool found; + + hash_search(deleted, indextid, HASH_FIND, &found); + return found; +} + +/* + * Remove deleted heap TIDs + * + * OK to remove for entry point, since always considered for searches and inserts + */ +static void +RemoveHeapTids(HnswVacuumState * vacuumstate) +{ + BlockNumber blkno = HNSW_HEAD_BLKNO; + HnswElement highestPoint = &vacuumstate->highestPoint; + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + HnswElement entryPoint = HnswGetEntryPoint(vacuumstate->index); + + /* Store separately since highestPoint.level is uint8 */ + int highestLevel = -1; + + /* Initialize highest point */ + highestPoint->blkno = InvalidBlockNumber; + highestPoint->offno = InvalidOffsetNumber; + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + GenericXLogState *state; + OffsetNumber offno; + OffsetNumber maxoffno; + bool updated = false; + + vacuum_delay_point(); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Iterate over nodes */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + int idx = 0; + bool itemUpdated = false; + + /* Skip neighbor tuples */ + if (!HnswIsElementTuple(etup)) + continue; + + if (ItemPointerIsValid(&etup->heaptids[0])) + { + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + /* Stop at first unused */ + if (!ItemPointerIsValid(&etup->heaptids[i])) + break; + + if (vacuumstate->callback(&etup->heaptids[i], vacuumstate->callback_state)) + itemUpdated = true; + else + { + /* Move to front of list */ + etup->heaptids[idx++] = etup->heaptids[i]; + } + } + + if (itemUpdated) + { + Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); + + /* Mark rest as invalid */ + for (int i = idx; i < HNSW_HEAPTIDS; i++) + ItemPointerSetInvalid(&etup->heaptids[i]); + + if (!PageIndexTupleOverwrite(page, offno, (Item) etup, etupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + updated = true; + } + } + + if (!ItemPointerIsValid(&etup->heaptids[0])) + { + ItemPointerData ip; + + /* Add to deleted list */ + ItemPointerSet(&ip, blkno, offno); + + (void) hash_search(vacuumstate->deleted, &ip, HASH_ENTER, NULL); + } + else if (etup->level > highestLevel && !(blkno == entryPoint->blkno && offno == entryPoint->offno)) + { + /* Keep track of highest non-entry point */ + highestPoint->blkno = blkno; + highestPoint->offno = offno; + highestPoint->level = etup->level; + highestLevel = etup->level; + } + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + if (updated) + { + MarkBufferDirty(buf); + GenericXLogFinish(state); + } + else + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); + } +} + +/* + * Check for deleted neighbors + */ +static bool +NeedsUpdated(HnswVacuumState * vacuumstate, HnswElement element) +{ + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + Buffer buf; + Page page; + HnswNeighborTuple ntup; + bool needsUpdated = false; + + buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); + + Assert(HnswIsNeighborTuple(ntup)); + + /* Check neighbors */ + for (int i = 0; i < ntup->count; i++) + { + HnswNeighborTupleItem *neighbor = &ntup->neighbors[i]; + + if (!ItemPointerIsValid(&neighbor->indextid)) + continue; + + /* Check if in deleted list */ + if (DeletedContains(vacuumstate->deleted, &neighbor->indextid)) + { + needsUpdated = true; + break; + } + } + + UnlockReleaseBuffer(buf); + + return needsUpdated; +} + +/* + * Repair graph for a single element + */ +static void +RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element) +{ + Relation index = vacuumstate->index; + Buffer buf; + Page page; + GenericXLogState *state; + int m = vacuumstate->m; + int efConstruction = vacuumstate->efConstruction; + FmgrInfo *procinfo = vacuumstate->procinfo; + Oid collation = vacuumstate->collation; + HnswElement entryPoint; + BufferAccessStrategy bas = vacuumstate->bas; + HnswNeighborTuple ntup = vacuumstate->ntup; + Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m); + + /* Check if any neighbors point to deleted values */ + if (!NeedsUpdated(vacuumstate, element)) + return; + + /* Refresh entry point for each element */ + entryPoint = HnswGetEntryPoint(index); + + /* Special case for entry point */ + if (element->blkno == entryPoint->blkno && element->offno == entryPoint->offno) + { + if (BlockNumberIsValid(vacuumstate->highestPoint.blkno)) + { + /* Already updated */ + if (vacuumstate->highestPoint.blkno == element->blkno && vacuumstate->highestPoint.offno == element->offno) + return; + + entryPoint = &vacuumstate->highestPoint; + + /* Reset neighbors from previous update */ + entryPoint->neighbors = NULL; + } + else + entryPoint = NULL; + } + + /* Init fields */ + HnswInitNeighbors(element, m); + element->heaptids = NIL; + + /* Add element to graph, skipping itself */ + HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, NULL, true); + + /* Update neighbor tuple */ + /* Do this before getting page to minimize locking */ + HnswSetNeighborTuple(ntup, element, m); + + /* Get neighbor page */ + buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Overwrite tuple */ + if (!PageIndexTupleOverwrite(page, element->neighborOffno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + MarkBufferDirty(buf); + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); +} + +/* + * Repair graph entry point + */ +static void +RepairGraphEntryPoint(HnswVacuumState * vacuumstate) +{ + Relation index = vacuumstate->index; + HnswElement highestPoint = &vacuumstate->highestPoint; + HnswElement entryPoint; + MemoryContext oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); + + /* Repair graph for highest non-entry point */ + /* This may not be the highest with new inserts, but should be fine */ + if (BlockNumberIsValid(highestPoint->blkno)) + { + HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + RepairGraphElement(vacuumstate, highestPoint); + } + + /* See if entry point needs updated */ + entryPoint = HnswGetEntryPoint(index); + if (entryPoint != NULL) + { + ItemPointerData epData; + + ItemPointerSet(&epData, entryPoint->blkno, entryPoint->offno); + + if (DeletedContains(vacuumstate->deleted, &epData)) + HnswUpdateMetaPage(index, true, highestPoint, InvalidBlockNumber, MAIN_FORKNUM); + else + { + /* Highest point will be used to repair */ + HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + RepairGraphElement(vacuumstate, entryPoint); + } + } + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(vacuumstate->tmpCtx); +} + +/* + * Repair graph for all elements + */ +static void +RepairGraph(HnswVacuumState * vacuumstate) +{ + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + BlockNumber blkno = HNSW_HEAD_BLKNO; + + RepairGraphEntryPoint(vacuumstate); + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + OffsetNumber offno; + OffsetNumber maxoffno; + List *elements = NIL; + ListCell *lc2; + MemoryContext oldCtx; + + vacuum_delay_point(); + + oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Load items into memory to minimize locking */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + HnswElement element; + + /* Skip neighbor tuples */ + if (!HnswIsElementTuple(etup)) + continue; + + /* Skip updating neighbors if being deleted */ + if (!ItemPointerIsValid(&etup->heaptids[0])) + continue; + + /* Create an element */ + element = palloc(sizeof(HnswElementData)); + element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); + element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); + element->level = etup->level; + element->blkno = blkno; + element->offno = offno; + element->vec = palloc(VECTOR_SIZE(etup->vec.dim)); + memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim)); + + elements = lappend(elements, element); + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + UnlockReleaseBuffer(buf); + + /* Update neighbor pages */ + foreach(lc2, elements) + RepairGraphElement(vacuumstate, (HnswElement) lfirst(lc2)); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(vacuumstate->tmpCtx); + } +} + +/* + * Mark items as deleted + */ +static void +MarkDeleted(HnswVacuumState * vacuumstate) +{ + BlockNumber blkno = HNSW_HEAD_BLKNO; + BlockNumber insertPage = InvalidBlockNumber; + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + IndexBulkDeleteResult *stats = vacuumstate->stats; + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + GenericXLogState *state; + OffsetNumber offno; + OffsetNumber maxoffno; + + vacuum_delay_point(); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + + /* + * ambulkdelete cannot delete entries from pages that are pinned by + * other backends + * + * https://www.postgresql.org/docs/current/index-locking.html + */ + LockBufferForCleanup(buf); + + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Update element and neighbors together */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + HnswNeighborTuple ntup; + Size etupSize; + Size ntupSize; + Buffer nbuf; + Page npage; + BlockNumber neighborPage; + OffsetNumber neighborOffno; + + /* Skip neighbor tuples */ + if (!HnswIsElementTuple(etup)) + continue; + + /* Skip deleted tuples */ + if (etup->deleted) + continue; + + /* Skip live tuples */ + if (ItemPointerIsValid(&etup->heaptids[0])) + { + stats->num_index_tuples++; + continue; + } + + /* Update stats */ + stats->tuples_removed++; + + /* Calculate sizes */ + etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); + ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(etup->level, vacuumstate->m); + + /* Get neighbor page */ + neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); + neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); + + if (neighborPage == blkno) + { + nbuf = buf; + npage = page; + } + else + { + nbuf = ReadBufferExtended(index, MAIN_FORKNUM, neighborPage, RBM_NORMAL, bas); + LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE); + npage = GenericXLogRegisterBuffer(state, nbuf, 0); + } + + ntup = (HnswNeighborTuple) PageGetItem(npage, PageGetItemId(npage, neighborOffno)); + + /* Overwrite element */ + etup->deleted = 1; + MemSet(&etup->vec.x, 0, etup->vec.dim * sizeof(float)); + + /* Overwrite neighbors */ + for (int i = 0; i < ntup->count; i++) + { + ItemPointerSetInvalid(&ntup->neighbors[i].indextid); + ntup->neighbors[i].distance = NAN; + } + + /* Overwrite element tuple */ + if (!PageIndexTupleOverwrite(page, offno, (Item) etup, etupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Overwrite neighbor tuple */ + if (!PageIndexTupleOverwrite(npage, neighborOffno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + MarkBufferDirty(buf); + if (nbuf != buf) + MarkBufferDirty(nbuf); + GenericXLogFinish(state); + if (nbuf != buf) + UnlockReleaseBuffer(nbuf); + + /* Set to first free page */ + if (!BlockNumberIsValid(insertPage)) + insertPage = blkno; + + /* Prepare new xlog */ + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + } + + HnswUpdateMetaPage(index, false, NULL, insertPage, MAIN_FORKNUM); +} + +/* + * Initialize the vacuum state + */ +static void +InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state) +{ + Relation index = info->index; + HASHCTL hash_ctl; + + if (stats == NULL) + stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult)); + + vacuumstate->index = index; + vacuumstate->stats = stats; + vacuumstate->callback = callback; + vacuumstate->callback_state = callback_state; + vacuumstate->m = HnswGetM(index); + vacuumstate->efConstruction = HnswGetEfConstruction(index); + vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD); + vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + vacuumstate->collation = index->rd_indcollation[0]; + vacuumstate->ntup = palloc0(BLCKSZ); + vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw vacuum temporary context", + ALLOCSET_DEFAULT_SIZES); + + /* Create hash table */ + hash_ctl.keysize = sizeof(ItemPointerData); + hash_ctl.entrysize = sizeof(ItemPointerData); + hash_ctl.hcxt = CurrentMemoryContext; + vacuumstate->deleted = hash_create("hnswbulkdelete indextids", 256, &hash_ctl, HASH_ELEM | HASH_BLOBS | HASH_CONTEXT); +} + +/* + * Free resources + */ +static void +FreeVacuumState(HnswVacuumState * vacuumstate) +{ + hash_destroy(vacuumstate->deleted); + FreeAccessStrategy(vacuumstate->bas); + pfree(vacuumstate->ntup); + MemoryContextDelete(vacuumstate->tmpCtx); +} + +/* + * Bulk delete tuples from the index + */ +IndexBulkDeleteResult * +hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, + IndexBulkDeleteCallback callback, void *callback_state) +{ + HnswVacuumState vacuumstate; + + InitVacuumState(&vacuumstate, info, stats, callback, callback_state); + + /* Pass 1: Remove heap TIDs */ + RemoveHeapTids(&vacuumstate); + + /* Pass 2: Repair graph */ + RepairGraph(&vacuumstate); + + /* Pass 3: Mark as deleted */ + MarkDeleted(&vacuumstate); + + FreeVacuumState(&vacuumstate); + + return vacuumstate.stats; +} + +/* + * Clean up after a VACUUM operation + */ +IndexBulkDeleteResult * +hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats) +{ + Relation rel = info->index; + + if (info->analyze_only) + return stats; + + /* stats is NULL if ambulkdelete not called */ + /* OK to return NULL if index not changed */ + if (stats == NULL) + return NULL; + + stats->num_pages = RelationGetNumberOfBlocks(rel); + + return stats; +} diff --git a/src/ivfflat.h b/src/ivfflat.h index 8dba5ef..2c18fd4 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -3,10 +3,6 @@ #include "postgres.h" -#if PG_VERSION_NUM < 110000 -#error "Requires PostgreSQL 11+" -#endif - #include "access/generic_xlog.h" #include "access/parallel.h" #include "access/reloptions.h" diff --git a/src/vector.c b/src/vector.c index 277bd2d..02964d6 100644 --- a/src/vector.c +++ b/src/vector.c @@ -4,6 +4,7 @@ #include "catalog/pg_type.h" #include "fmgr.h" +#include "hnsw.h" #include "ivfflat.h" #include "lib/stringinfo.h" #include "libpq/pqformat.h" @@ -37,6 +38,7 @@ PG_MODULE_MAGIC; void _PG_init(void) { + HnswInit(); IvfflatInit(); } diff --git a/test/expected/hnsw_cosine.out b/test/expected/hnsw_cosine.out new file mode 100644 index 0000000..df9eb81 --- /dev/null +++ b/test/expected/hnsw_cosine.out @@ -0,0 +1,26 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_cosine_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; + val +--------- + [1,1,1] + [1,2,3] + [1,2,4] +(3 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; + count +------- + 3 +(1 row) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; + count +------- + 3 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_ip.out b/test/expected/hnsw_ip.out new file mode 100644 index 0000000..92a5072 --- /dev/null +++ b/test/expected/hnsw_ip.out @@ -0,0 +1,21 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_ip_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; + val +--------- + [1,2,4] + [1,2,3] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_l2.out b/test/expected/hnsw_l2.out new file mode 100644 index 0000000..e8a16c8 --- /dev/null +++ b/test/expected/hnsw_l2.out @@ -0,0 +1,30 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); + val +--------- + [0,0,0] + [1,1,1] + [1,2,3] + [1,2,4] +(4 rows) + +SELECT COUNT(*) FROM t; + count +------- + 5 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_options.out b/test/expected/hnsw_options.out new file mode 100644 index 0000000..be10beb --- /dev/null +++ b/test/expected/hnsw_options.out @@ -0,0 +1,25 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 3); +ERROR: value 3 out of bounds for option "m" +DETAIL: Valid values are between "4" and "100". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101); +ERROR: value 101 out of bounds for option "m" +DETAIL: Valid values are between "4" and "100". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 9); +ERROR: value 9 out of bounds for option "ef_construction" +DETAIL: Valid values are between "10" and "1000". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001); +ERROR: value 1001 out of bounds for option "ef_construction" +DETAIL: Valid values are between "10" and "1000". +SHOW hnsw.ef_search; + hnsw.ef_search +---------------- + 40 +(1 row) + +SET hnsw.ef_search = 9; +ERROR: 9 is outside the valid range for parameter "hnsw.ef_search" (10 .. 1000) +SET hnsw.ef_search = 1001; +ERROR: 1001 is outside the valid range for parameter "hnsw.ef_search" (10 .. 1000) +DROP TABLE t; diff --git a/test/expected/hnsw_unlogged.out b/test/expected/hnsw_unlogged.out new file mode 100644 index 0000000..0063773 --- /dev/null +++ b/test/expected/hnsw_unlogged.out @@ -0,0 +1,13 @@ +SET enable_seqscan = off; +CREATE UNLOGGED TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,1,1] + [0,0,0] +(3 rows) + +DROP TABLE t; diff --git a/test/sql/hnsw_cosine.sql b/test/sql/hnsw_cosine.sql new file mode 100644 index 0000000..d23f4f3 --- /dev/null +++ b/test/sql/hnsw_cosine.sql @@ -0,0 +1,13 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_cosine_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_ip.sql b/test/sql/hnsw_ip.sql new file mode 100644 index 0000000..5a616a1 --- /dev/null +++ b/test/sql/hnsw_ip.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_ip_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_l2.sql b/test/sql/hnsw_l2.sql new file mode 100644 index 0000000..8664cb3 --- /dev/null +++ b/test/sql/hnsw_l2.sql @@ -0,0 +1,13 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; +SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); +SELECT COUNT(*) FROM t; + +DROP TABLE t; diff --git a/test/sql/hnsw_options.sql b/test/sql/hnsw_options.sql new file mode 100644 index 0000000..c289922 --- /dev/null +++ b/test/sql/hnsw_options.sql @@ -0,0 +1,14 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 3); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 9); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001); + +SHOW hnsw.ef_search; + +SET hnsw.ef_search = 9; +SET hnsw.ef_search = 1001; + +DROP TABLE t; diff --git a/test/sql/hnsw_unlogged.sql b/test/sql/hnsw_unlogged.sql new file mode 100644 index 0000000..2efcc95 --- /dev/null +++ b/test/sql/hnsw_unlogged.sql @@ -0,0 +1,9 @@ +SET enable_seqscan = off; + +CREATE UNLOGGED TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +DROP TABLE t; diff --git a/test/t/010_hnsw_wal.pl b/test/t/010_hnsw_wal.pl new file mode 100644 index 0000000..1e1e3f0 --- /dev/null +++ b/test/t/010_hnsw_wal.pl @@ -0,0 +1,99 @@ +# Based on postgres/contrib/bloom/t/001_wal.pl + +# Test generic xlog record work for hnsw index replication. +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $dim = 32; + +my $node_primary; +my $node_replica; + +# Run few queries on both primary and replica and check their results match. +sub test_index_replay +{ + my ($test_name) = @_; + + # Wait for replica to catch up + my $applname = $node_replica->name; + + my $server_version_num = $node_primary->safe_psql("postgres", "SHOW server_version_num"); + my $caughtup_query = "SELECT pg_current_wal_lsn() <= replay_lsn FROM pg_stat_replication WHERE application_name = '$applname';"; + $node_primary->poll_query_until('postgres', $caughtup_query) + or die "Timed out while waiting for replica 1 to catch up"; + + my @r = (); + for (1 .. $dim) { + push(@r, rand()); + } + my $sql = join(",", @r); + + my $queries = qq( + SET enable_seqscan = off; + SELECT * FROM tst ORDER BY v <-> '[$sql]' LIMIT 10; + ); + + # Run test queries and compare their result + my $primary_result = $node_primary->safe_psql("postgres", $queries); + my $replica_result = $node_replica->safe_psql("postgres", $queries); + + is($primary_result, $replica_result, "$test_name: query result matches"); + return; +} + +# Use ARRAY[random(), random(), random(), ...] over +# SELECT array_agg(random()) FROM generate_series(1, $dim) +# to generate different values for each row +my $array_sql = join(",", ('random()') x $dim); + +# Initialize primary node +$node_primary = get_new_node('primary'); +$node_primary->init(allows_streaming => 1); +if ($dim > 32) { + # TODO use wal_keep_segments for Postgres < 13 + $node_primary->append_conf('postgresql.conf', qq(wal_keep_size = 1GB)); +} +if ($dim > 1500) { + $node_primary->append_conf('postgresql.conf', qq(maintenance_work_mem = 128MB)); +} +$node_primary->start; +my $backup_name = 'my_backup'; + +# Take backup +$node_primary->backup($backup_name); + +# Create streaming replica linking to primary +$node_replica = get_new_node('replica'); +$node_replica->init_from_backup($node_primary, $backup_name, + has_streaming => 1); +$node_replica->start; + +# Create hnsw index on primary +$node_primary->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node_primary->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node_primary->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 1000) i;" +); +$node_primary->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +# Test that queries give same result +test_index_replay('initial'); + +# Run 10 cycles of table modification. Run test queries after each modification. +for my $i (1 .. 10) +{ + $node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;"); + test_index_replay("delete $i"); + $node_primary->safe_psql("postgres", "VACUUM tst;"); + test_index_replay("vacuum $i"); + my ($start, $end) = (1001 + ($i - 1) * 100, 1000 + $i * 100); + $node_primary->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series($start, $end) i;" + ); + test_index_replay("insert $i"); +} + +done_testing(); diff --git a/test/t/011_hnsw_vacuum.pl b/test/t/011_hnsw_vacuum.pl new file mode 100644 index 0000000..67379e1 --- /dev/null +++ b/test/t/011_hnsw_vacuum.pl @@ -0,0 +1,43 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $dim = 3; + +my @r = (); +for (1 .. $dim) { + my $v = int(rand(1000)) + 1; + push(@r, "i % $v"); +} +my $array_sql = join(", ", @r); + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table and index +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +# Get size +my $size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); + +# Delete all, vacuum, and insert same data +$node->safe_psql("postgres", "DELETE FROM tst;"); +$node->safe_psql("postgres", "VACUUM tst;"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); + +# Check size +my $new_size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); +cmp_ok($new_size, "<=", $size * 1.01, "size does not increase too much"); + +done_testing(); diff --git a/test/t/012_hnsw_build_recall.pl b/test/t/012_hnsw_build_recall.pl new file mode 100644 index 0000000..d8fa2ab --- /dev/null +++ b/test/t/012_hnsw_build_recall.pl @@ -0,0 +1,96 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) { + if (exists($actual_set{$_})) { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" +); + +# Generate queries +for (1..20) { + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); + +foreach (@operators) { + my $operator = $_; + + # Get exact results + @expected = (); + foreach (@queries) { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Add index + my $opclass; + if ($operator eq "<->") { + $opclass = "vector_l2_ops"; + } elsif ($operator eq "<#>") { + $opclass = "vector_ip_ops"; + } else { + $opclass = "vector_cosine_ops"; + } + $node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass);"); + + if ($operator eq "<#>") { + test_recall(0.80, $operator); + } else { + test_recall(0.99, $operator); + } +} + +done_testing(); diff --git a/test/t/013_hnsw_insert_recall.pl b/test/t/013_hnsw_insert_recall.pl new file mode 100644 index 0000000..7e5ffc5 --- /dev/null +++ b/test/t/013_hnsw_insert_recall.pl @@ -0,0 +1,103 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) { + if (exists($actual_set{$_})) { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); + +# Generate queries +for (1..20) { + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); + +foreach (@operators) { + my $operator = $_; + + # Add index + my $opclass; + if ($operator eq "<->") { + $opclass = "vector_l2_ops"; + } elsif ($operator eq "<#>") { + $opclass = "vector_ip_ops"; + } else { + $opclass = "vector_cosine_ops"; + } + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v $opclass);"); + + $node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" + ); + + # Get exact results + @expected = (); + foreach (@queries) { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit; + )); + push(@expected, $res); + } + + if ($operator eq "<#>") { + test_recall(0.80, $operator); + } else { + test_recall(0.99, $operator); + } + + $node->safe_psql("postgres", "DROP INDEX idx;"); + $node->safe_psql("postgres", "TRUNCATE tst;"); +} + +done_testing(); diff --git a/test/t/014_hnsw_inserts.pl b/test/t/014_hnsw_inserts.pl new file mode 100644 index 0000000..5478fe4 --- /dev/null +++ b/test/t/014_hnsw_inserts.pl @@ -0,0 +1,58 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +# Ensures elements and neighbors on both same and different pages +my $dim = 1900; + +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table and index +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 100) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +$node->pgbench( + "--no-vacuum --client=5 --transactions=100", + 0, + [qr{actually processed}], + [qr{^$}], + "concurrent INSERTs", + { + "007_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;" + } +); + +sub idx_scan +{ + # Stats do not update instantaneously + # https://www.postgresql.org/docs/current/monitoring-stats.html#MONITORING-STATS-VIEWS + sleep(1); + $node->safe_psql("postgres", "SELECT idx_scan FROM pg_stat_user_indexes WHERE indexrelid = 'tst_v_idx'::regclass;"); +} + +my $expected = 100 + 5 * 100 * 10; + +my $count = $node->safe_psql("postgres", "SELECT COUNT(*) FROM tst;"); +is($count, $expected); +is(idx_scan(), 0); + +$count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = 400; + SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t; +)); +is($count, 400); +is(idx_scan(), 1); + +done_testing();